init repo
This commit is contained in:
0
app/middleware/__init__.py
Normal file
0
app/middleware/__init__.py
Normal file
22
app/middleware/base.py
Normal file
22
app/middleware/base.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Middleware base class — before/after hooks around graph execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.graph.state import ReportState
|
||||
|
||||
|
||||
class Middleware(ABC):
|
||||
"""Base middleware. Subclasses override before() and/or after()."""
|
||||
|
||||
name: str = "base"
|
||||
enabled: bool = True
|
||||
|
||||
async def before(self, state: ReportState) -> ReportState:
|
||||
"""Called before graph execution. Modify state as needed."""
|
||||
return state
|
||||
|
||||
async def after(self, state: ReportState) -> ReportState:
|
||||
"""Called after graph execution. Modify state as needed."""
|
||||
return state
|
||||
35
app/middleware/chain.py
Normal file
35
app/middleware/chain.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Middleware chain — runs ordered middlewares before/after graph execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MiddlewareChain:
|
||||
"""Ordered list of middlewares. before() runs in order, after() runs in reverse."""
|
||||
|
||||
def __init__(self, middlewares: list[Middleware] | None = None):
|
||||
self.middlewares = middlewares or []
|
||||
|
||||
def add(self, mw: Middleware) -> "MiddlewareChain":
|
||||
self.middlewares.append(mw)
|
||||
return self
|
||||
|
||||
async def before(self, state: ReportState) -> ReportState:
|
||||
for mw in self.middlewares:
|
||||
if mw.enabled:
|
||||
logger.debug(f"[middleware:before] {mw.name}")
|
||||
state = await mw.before(state)
|
||||
return state
|
||||
|
||||
async def after(self, state: ReportState) -> ReportState:
|
||||
for mw in reversed(self.middlewares):
|
||||
if mw.enabled:
|
||||
logger.debug(f"[middleware:after] {mw.name}")
|
||||
state = await mw.after(state)
|
||||
return state
|
||||
56
app/middleware/client_context.py
Normal file
56
app/middleware/client_context.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""ClientContext middleware — injects client/project background into state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Client profiles stored as JSON files
|
||||
CLIENTS_DIR = Path(__file__).resolve().parent.parent.parent / "clients"
|
||||
|
||||
|
||||
class ClientContextMiddleware(Middleware):
|
||||
"""Loads client profile and injects background context into state.
|
||||
|
||||
Client profiles are JSON files in the `clients/` directory:
|
||||
clients/<client_id>.json
|
||||
{
|
||||
"name": "某咨询公司",
|
||||
"industry": "金融",
|
||||
"preferences": "偏好简洁的执行摘要,数据驱动",
|
||||
"previous_reports": ["report_001", "report_002"]
|
||||
}
|
||||
"""
|
||||
|
||||
name = "client_context"
|
||||
|
||||
async def before(self, state: ReportState) -> ReportState:
|
||||
if not state.client_id:
|
||||
return state
|
||||
|
||||
profile_path = CLIENTS_DIR / f"{state.client_id}.json"
|
||||
if not profile_path.exists():
|
||||
logger.info(f"[client_context] no profile for client '{state.client_id}'")
|
||||
return state
|
||||
|
||||
try:
|
||||
profile = json.loads(profile_path.read_text(encoding="utf-8"))
|
||||
parts = []
|
||||
if name := profile.get("name"):
|
||||
parts.append(f"客户:{name}")
|
||||
if industry := profile.get("industry"):
|
||||
parts.append(f"行业:{industry}")
|
||||
if prefs := profile.get("preferences"):
|
||||
parts.append(f"偏好:{prefs}")
|
||||
state.client_context = ";".join(parts)
|
||||
logger.info(f"[client_context] loaded profile for '{state.client_id}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[client_context] failed to load profile: {e}")
|
||||
|
||||
return state
|
||||
61
app/middleware/compliance.py
Normal file
61
app/middleware/compliance.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Compliance middleware — checks for sensitive data leakage in output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns that suggest sensitive data
|
||||
SENSITIVE_PATTERNS = [
|
||||
(r"\b\d{15,18}\b", "可能的身份证号"),
|
||||
(r"\b\d{16,19}\b", "可能的银行卡号"),
|
||||
(r"\b1[3-9]\d{9}\b", "可能的手机号"),
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "邮箱地址"),
|
||||
(r"(?:密码|password|secret|token|api.?key)\s*[:=]\s*\S+", "可能的凭证信息"),
|
||||
]
|
||||
|
||||
|
||||
class ComplianceMiddleware(Middleware):
|
||||
"""Scans report output for sensitive data patterns.
|
||||
|
||||
After graph execution, scans the draft for PII / credential patterns.
|
||||
Logs warnings but does not block (can be made blocking later).
|
||||
"""
|
||||
|
||||
name = "compliance"
|
||||
|
||||
def _scan_text(self, text: str) -> list[dict]:
|
||||
findings = []
|
||||
for pattern, desc in SENSITIVE_PATTERNS:
|
||||
matches = re.findall(pattern, text)
|
||||
if matches:
|
||||
findings.append({
|
||||
"type": desc,
|
||||
"count": len(matches),
|
||||
"samples": [m[:8] + "..." for m in matches[:3]],
|
||||
})
|
||||
return findings
|
||||
|
||||
async def after(self, state: ReportState) -> ReportState:
|
||||
# Scan draft content
|
||||
all_text = json.dumps(state.draft, ensure_ascii=False)
|
||||
findings = self._scan_text(all_text)
|
||||
|
||||
if findings:
|
||||
logger.warning(
|
||||
f"[compliance] found {len(findings)} sensitive data patterns:"
|
||||
)
|
||||
for f in findings:
|
||||
logger.warning(f" - {f['type']}: {f['count']} occurrences")
|
||||
# Store in state for API to surface
|
||||
state.review.setdefault("compliance_warnings", findings)
|
||||
else:
|
||||
logger.info("[compliance] no sensitive data patterns detected")
|
||||
|
||||
return state
|
||||
64
app/middleware/memory.py
Normal file
64
app/middleware/memory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Memory middleware — injects persistent facts into state, saves new facts after."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from app.memory.store import MemoryStore
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryMiddleware(Middleware):
|
||||
"""Injects top-N memory facts into state before execution.
|
||||
|
||||
After execution, extracts key facts from the report for future reference.
|
||||
"""
|
||||
|
||||
name = "memory"
|
||||
|
||||
def __init__(self):
|
||||
self.store = MemoryStore()
|
||||
|
||||
async def before(self, state: ReportState) -> ReportState:
|
||||
# Load global facts + client-specific facts
|
||||
facts = self.store.get_top_facts("global", limit=10)
|
||||
if state.client_id:
|
||||
client_facts = self.store.get_top_facts(state.client_id, limit=10)
|
||||
facts = client_facts + facts # client facts first
|
||||
|
||||
state.memory_facts = facts[:15] # cap total
|
||||
if facts:
|
||||
logger.info(f"[memory] injected {len(state.memory_facts)} facts")
|
||||
return state
|
||||
|
||||
async def after(self, state: ReportState) -> ReportState:
|
||||
# Save key metadata as facts for future reference
|
||||
if state.draft and not state.error:
|
||||
title = state.draft.get("title", "")
|
||||
report_type = state.report_type
|
||||
client_id = state.client_id or "global"
|
||||
|
||||
self.store.add_fact(
|
||||
content=f"生成过报告:{title}(类型:{report_type})",
|
||||
client_id=client_id,
|
||||
category="context",
|
||||
confidence=0.9,
|
||||
source=state.id,
|
||||
)
|
||||
|
||||
# Save decomposition tracks as knowledge
|
||||
tracks = state.decomposition.get("tracks", [])
|
||||
if tracks:
|
||||
track_titles = "、".join(t.get("title", "") for t in tracks)
|
||||
self.store.add_fact(
|
||||
content=f"报告《{title}》的研究轨道:{track_titles}",
|
||||
client_id=client_id,
|
||||
category="knowledge",
|
||||
confidence=0.6,
|
||||
source=state.id,
|
||||
)
|
||||
|
||||
return state
|
||||
46
app/middleware/token_budget.py
Normal file
46
app/middleware/token_budget.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""TokenBudget middleware — tracks and limits token usage across the pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenBudgetMiddleware(Middleware):
|
||||
"""Manages token budget to prevent context overflow.
|
||||
|
||||
Before: sets token budget based on model limits.
|
||||
After: logs total estimated token usage.
|
||||
"""
|
||||
|
||||
name = "token_budget"
|
||||
|
||||
# Rough char-to-token ratio for Chinese text
|
||||
CHARS_PER_TOKEN = 1.5
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
return int(len(text) / self.CHARS_PER_TOKEN)
|
||||
|
||||
async def before(self, state: ReportState) -> ReportState:
|
||||
# Default budget is already set in ReportState (120k)
|
||||
# Could adjust based on model selection
|
||||
logger.info(f"[token_budget] budget = {state.token_budget} tokens")
|
||||
return state
|
||||
|
||||
async def after(self, state: ReportState) -> ReportState:
|
||||
# Estimate total tokens used in the draft
|
||||
total_chars = 0
|
||||
for chapter in state.draft.get("chapters", []):
|
||||
total_chars += len(chapter.get("content", ""))
|
||||
total_chars += len(state.draft.get("executive_summary", ""))
|
||||
|
||||
estimated = self._estimate_tokens(str(total_chars))
|
||||
logger.info(
|
||||
f"[token_budget] estimated output tokens: {estimated}, "
|
||||
f"budget: {state.token_budget}"
|
||||
)
|
||||
return state
|
||||
Reference in New Issue
Block a user