47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
"""TokenBudget middleware — tracks and limits token usage across the pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from app.graph.state import ReportState
|
|
from .base import Middleware
|
|
|
|
# One logger per module, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TokenBudgetMiddleware(Middleware):
    """Track the estimated token size of the report draft against the budget.

    Before: logs the token budget carried on ``state`` (no adjustment is made
    here yet).

    After: estimates the token count of all chapter contents plus the
    executive summary in ``state.draft`` and logs it alongside the budget.

    Token counts are a heuristic: character count divided by
    ``CHARS_PER_TOKEN``, truncated to an int.
    """

    name = "token_budget"

    # Rough char-to-token ratio for Chinese text
    CHARS_PER_TOKEN = 1.5

    def _tokens_for_chars(self, char_count: int) -> int:
        """Convert a character count into an estimated (truncated) token count."""
        return int(char_count / self.CHARS_PER_TOKEN)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the number of tokens in ``text`` from its character length."""
        return self._tokens_for_chars(len(text))

    async def before(self, state: ReportState) -> ReportState:
        """Log the configured token budget and return ``state`` unchanged."""
        # The default budget is set on ReportState (originally noted as 120k);
        # this hook is where a model-specific adjustment could go later.
        logger.info("[token_budget] budget = %s tokens", state.token_budget)
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Log the estimated token size of the draft and return ``state`` unchanged.

        Counts the characters of every chapter's ``content`` plus the
        ``executive_summary`` in ``state.draft``; missing or ``None`` values
        count as empty.
        """
        draft = state.draft
        # `or` fallbacks treat keys explicitly set to None like missing keys,
        # so len() and iteration never receive None.
        total_chars = sum(
            len(chapter.get("content") or "")
            for chapter in draft.get("chapters") or []
        )
        total_chars += len(draft.get("executive_summary") or "")

        # Estimate from the character count itself, not from its decimal string.
        estimated = self._tokens_for_chars(total_chars)
        logger.info(
            "[token_budget] estimated output tokens: %s, budget: %s",
            estimated,
            state.token_budget,
        )
        return state
|