65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
"""Memory middleware — injects persistent facts into state, saves new facts after."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
|
||
from app.graph.state import ReportState
|
||
from app.memory.store import MemoryStore
|
||
from .base import Middleware
|
||
|
||
# Logger for this module, registered under the module's own dotted name.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class MemoryMiddleware(Middleware):
    """Injects top-N memory facts into state before execution.

    After execution, extracts key facts from the report for future reference.

    Facts are read from and written to a ``MemoryStore``. Client-specific facts
    are keyed by ``state.client_id``; facts not tied to a client live under the
    ``"global"`` scope.
    """

    name = "memory"

    # Scope key for facts that are not tied to a specific client.
    GLOBAL_SCOPE = "global"

    def __init__(
        self,
        store: MemoryStore | None = None,
        *,
        per_scope_limit: int = 10,
        max_facts: int = 15,
    ):
        """Create the middleware.

        Args:
            store: Fact store to use. Defaults to a fresh ``MemoryStore()``
                (the previous, hard-wired behavior); pass one in to share a
                store between components or to substitute a fake in tests.
            per_scope_limit: Maximum number of facts fetched from each scope
                (the client scope and the global scope) in ``before``.
            max_facts: Cap on the total number of facts injected into state.
        """
        # ``is not None`` rather than truthiness: a store object might define
        # ``__len__`` and be falsy when empty.
        self.store = store if store is not None else MemoryStore()
        self.per_scope_limit = per_scope_limit
        self.max_facts = max_facts

    async def before(self, state: ReportState) -> ReportState:
        """Load persisted facts into ``state.memory_facts``.

        Global facts are always loaded; when ``state.client_id`` is set, that
        client's facts are loaded too and placed *first*, so they take priority
        when the combined list is truncated to ``max_facts``.

        Returns:
            The same ``state`` object, mutated in place.
        """
        facts = self.store.get_top_facts(
            self.GLOBAL_SCOPE, limit=self.per_scope_limit
        )
        # If the client id *is* the global scope key, the second lookup would
        # return the same facts and inject them twice.
        if state.client_id and state.client_id != self.GLOBAL_SCOPE:
            client_facts = self.store.get_top_facts(
                state.client_id, limit=self.per_scope_limit
            )
            facts = client_facts + facts  # client facts first

        state.memory_facts = facts[: self.max_facts]  # cap total
        if state.memory_facts:
            logger.info("[memory] injected %d facts", len(state.memory_facts))
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Persist facts about a successfully generated report.

        Does nothing unless ``state.draft`` is truthy and ``state.error`` is
        falsy. Otherwise stores, under ``state.client_id`` (or ``"global"``
        when unset):

        * a ``"context"`` fact naming the report title and type, and
        * a ``"knowledge"`` fact listing the decomposition's track titles,
          when at least one track has a non-empty title.

        Returns:
            The same ``state`` object, unchanged.
        """
        if not state.draft or state.error:
            return state

        title = state.draft.get("title", "")
        client_id = state.client_id or self.GLOBAL_SCOPE

        # Save key metadata as facts for future reference
        self.store.add_fact(
            content=f"生成过报告:{title}(类型:{state.report_type})",
            client_id=client_id,
            category="context",
            confidence=0.9,
            source=state.id,
        )

        # Save decomposition tracks as knowledge.
        # NOTE(review): guard in case ReportState.decomposition can be None on
        # some paths -- confirm against the state definition.
        decomposition = state.decomposition or {}
        # Drop untitled tracks: a None title would make str.join raise
        # TypeError, and empty ones only add stray separators to the fact.
        track_titles = [
            track.get("title")
            for track in decomposition.get("tracks") or []
            if track.get("title")
        ]
        if track_titles:
            joined_titles = "、".join(track_titles)
            self.store.add_fact(
                content=f"报告《{title}》的研究轨道:{joined_titles}",
                client_id=client_id,
                category="knowledge",
                confidence=0.6,
                source=state.id,
            )

        return state
|