Files
2026-04-25 19:25:22 +08:00

65 lines
2.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memory middleware — injects persistent facts into state, saves new facts after."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from app.memory.store import MemoryStore
from .base import Middleware
# Module-level logger named after this module (used for the injection log line).
logger = logging.getLogger(__name__)
class MemoryMiddleware(Middleware):
    """Inject persistent memory facts before a run and record new ones after.

    ``before`` loads up to ``client_limit`` client-specific facts (when the
    state carries a ``client_id``) followed by up to ``global_limit`` global
    facts, caps the combined list at ``max_facts`` and stores it on
    ``state.memory_facts``.

    ``after`` runs only when the state has a draft and no error. It then
    records a ``"context"`` fact naming the generated report. When the
    decomposition lists research tracks with titles, it also records a
    ``"knowledge"`` fact listing those titles. Facts are filed under the
    state's ``client_id``, or under the global scope when there is none.
    """

    name = "memory"

    # Scope key for facts that are not tied to a particular client. Shared by
    # ``before`` (reading) and ``after`` (writing) so the two stay consistent.
    GLOBAL_SCOPE = "global"

    def __init__(
        self,
        store: MemoryStore | None = None,
        *,
        global_limit: int = 10,
        client_limit: int = 10,
        max_facts: int | None = 15,
    ):
        """Create the middleware.

        Args:
            store: Fact store to read from and write to. Defaults to a fresh
                ``MemoryStore()``, as before; pass one explicitly to share a
                store between components or to substitute a fake in tests.
            global_limit: Maximum number of global facts requested per run.
            client_limit: Maximum number of client-specific facts requested
                per run.
            max_facts: Cap on the total number of facts injected into the
                state; ``None`` disables the cap.
        """
        self.store = store if store is not None else MemoryStore()
        self.global_limit = global_limit
        self.client_limit = client_limit
        self.max_facts = max_facts

    async def before(self, state: ReportState) -> ReportState:
        """Populate ``state.memory_facts`` from the store and return ``state``.

        Client-specific facts are placed ahead of global ones, so they are
        the ones kept when the combined list is truncated to ``max_facts``.
        """
        facts = self.store.get_top_facts(
            self.GLOBAL_SCOPE, limit=self.global_limit
        )
        # If the client id *is* the global scope, the lookup below would fetch
        # the same global facts again and inject them twice, so skip it.
        if state.client_id and state.client_id != self.GLOBAL_SCOPE:
            client_facts = self.store.get_top_facts(
                state.client_id, limit=self.client_limit
            )
            facts = client_facts + facts  # client facts first
        state.memory_facts = facts[: self.max_facts]
        if state.memory_facts:
            logger.info("[memory] injected %d facts", len(state.memory_facts))
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Persist facts about a successfully generated report; return ``state``.

        Does nothing when the run produced no draft or recorded an error.
        """
        if not state.draft or state.error:
            return state

        title = state.draft.get("title", "")
        client_id = state.client_id or self.GLOBAL_SCOPE

        # Record that this report (and its type) was generated.
        self.store.add_fact(
            content=f"生成过报告:{title}(类型:{state.report_type})",
            client_id=client_id,
            category="context",
            confidence=0.9,
            source=state.id,
        )

        # Record the decomposition's research-track titles as knowledge.
        # Tolerate a falsy decomposition / "tracks" value instead of crashing,
        # and skip untitled tracks so the fact never lists empty entries.
        tracks = (state.decomposition or {}).get("tracks") or []
        track_titles = "、".join(
            track.get("title", "") for track in tracks if track.get("title")
        )
        if track_titles:
            self.store.add_fact(
                content=f"报告《{title}》的研究轨道:{track_titles}",
                client_id=client_id,
                category="knowledge",
                confidence=0.6,
                source=state.id,
            )
        return state