"""Memory middleware: inject persisted facts into state, save new facts afterwards."""

from __future__ import annotations

import logging

from app.graph.state import ReportState
from app.memory.store import MemoryStore

from .base import Middleware

logger = logging.getLogger(__name__)


class MemoryMiddleware(Middleware):
    """Inject stored memory facts before execution and record new ones after.

    ``before`` loads the top facts for the global scope and, when the state
    carries a client id, for that client as well, then writes a capped list to
    ``state.memory_facts``. ``after`` stores report metadata (title, type and
    decomposition track titles) as facts once a draft was produced without
    error.
    """

    name = "memory"

    # Scope key used for facts that are not tied to a specific client.
    GLOBAL_SCOPE = "global"
    # Maximum number of facts fetched from each scope.
    PER_SCOPE_LIMIT = 10
    # Maximum number of facts injected into the state in total.
    MAX_INJECTED_FACTS = 15

    def __init__(self, store: MemoryStore | None = None):
        """Create the middleware.

        Args:
            store: Fact store to use. A fresh ``MemoryStore()`` is created
                when omitted, which matches the previous no-argument behavior.
        """
        self.store = MemoryStore() if store is None else store

    async def before(self, state: ReportState) -> ReportState:
        """Populate ``state.memory_facts`` with client facts, then global facts.

        Args:
            state: The report state about to be processed.

        Returns:
            The same ``state`` object with ``memory_facts`` set.
        """
        facts = self.store.get_top_facts(
            self.GLOBAL_SCOPE, limit=self.PER_SCOPE_LIMIT
        )

        client_id = state.client_id
        # A client id equal to the global scope would fetch the same facts a
        # second time and inject every one of them twice, so skip it.
        if client_id and client_id != self.GLOBAL_SCOPE:
            client_facts = self.store.get_top_facts(
                client_id, limit=self.PER_SCOPE_LIMIT
            )
            # Client-specific facts go first so they survive the cap below.
            facts = client_facts + facts

        state.memory_facts = facts[: self.MAX_INJECTED_FACTS]
        if state.memory_facts:
            logger.info("[memory] injected %d facts", len(state.memory_facts))
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Persist report metadata as facts when a draft was produced cleanly.

        Nothing is stored when ``state.draft`` is empty or ``state.error`` is
        set.

        Args:
            state: The report state after execution.

        Returns:
            The same ``state`` object, unmodified.
        """
        if not state.draft or state.error:
            return state

        title = state.draft.get("title", "")
        client_id = state.client_id or self.GLOBAL_SCOPE

        self.store.add_fact(
            content=f"生成过报告:{title}(类型:{state.report_type})",
            client_id=client_id,
            category="context",
            confidence=0.9,
            source=state.id,
        )

        # ``decomposition`` (or its "tracks" entry) may be unset; treat that
        # as "no tracks" instead of raising AttributeError/TypeError.
        decomposition = state.decomposition or {}
        tracks = decomposition.get("tracks") or []
        # Skip tracks without a usable title so the stored fact never
        # degenerates into a run of bare separators.
        track_titles = "、".join(
            track_title
            for track_title in (track.get("title") for track in tracks)
            if track_title
        )
        if track_titles:
            self.store.add_fact(
                content=f"报告《{title}》的研究轨道:{track_titles}",
                client_id=client_id,
                category="knowledge",
                confidence=0.6,
                source=state.id,
            )
        return state