"""TokenBudget middleware — tracks and limits token usage across the pipeline.""" from __future__ import annotations import logging from app.graph.state import ReportState from .base import Middleware logger = logging.getLogger(__name__) class TokenBudgetMiddleware(Middleware): """Manages token budget to prevent context overflow. Before: sets token budget based on model limits. After: logs total estimated token usage. """ name = "token_budget" # Rough char-to-token ratio for Chinese text CHARS_PER_TOKEN = 1.5 def _estimate_tokens(self, text: str) -> int: return int(len(text) / self.CHARS_PER_TOKEN) async def before(self, state: ReportState) -> ReportState: # Default budget is already set in ReportState (120k) # Could adjust based on model selection logger.info(f"[token_budget] budget = {state.token_budget} tokens") return state async def after(self, state: ReportState) -> ReportState: # Estimate total tokens used in the draft total_chars = 0 for chapter in state.draft.get("chapters", []): total_chars += len(chapter.get("content", "")) total_chars += len(state.draft.get("executive_summary", "")) estimated = self._estimate_tokens(str(total_chars)) logger.info( f"[token_budget] estimated output tokens: {estimated}, " f"budget: {state.token_budget}" ) return state