init repo

This commit is contained in:
2026-04-25 19:25:22 +08:00
commit c7533eada2
50 changed files with 3732 additions and 0 deletions

View File

22
app/middleware/base.py Normal file
View File

@@ -0,0 +1,22 @@
"""Middleware base class — before/after hooks around graph execution."""
from __future__ import annotations
from abc import ABC, abstractmethod
from app.graph.state import ReportState
class Middleware(ABC):
"""Base middleware. Subclasses override before() and/or after()."""
name: str = "base"
enabled: bool = True
async def before(self, state: ReportState) -> ReportState:
"""Called before graph execution. Modify state as needed."""
return state
async def after(self, state: ReportState) -> ReportState:
"""Called after graph execution. Modify state as needed."""
return state

35
app/middleware/chain.py Normal file
View File

@@ -0,0 +1,35 @@
"""Middleware chain — runs ordered middlewares before/after graph execution."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
class MiddlewareChain:
"""Ordered list of middlewares. before() runs in order, after() runs in reverse."""
def __init__(self, middlewares: list[Middleware] | None = None):
self.middlewares = middlewares or []
def add(self, mw: Middleware) -> "MiddlewareChain":
self.middlewares.append(mw)
return self
async def before(self, state: ReportState) -> ReportState:
for mw in self.middlewares:
if mw.enabled:
logger.debug(f"[middleware:before] {mw.name}")
state = await mw.before(state)
return state
async def after(self, state: ReportState) -> ReportState:
for mw in reversed(self.middlewares):
if mw.enabled:
logger.debug(f"[middleware:after] {mw.name}")
state = await mw.after(state)
return state

View File

@@ -0,0 +1,56 @@
"""ClientContext middleware — injects client/project background into state."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
# Client profiles stored as JSON files
CLIENTS_DIR = Path(__file__).resolve().parent.parent.parent / "clients"
class ClientContextMiddleware(Middleware):
"""Loads client profile and injects background context into state.
Client profiles are JSON files in the `clients/` directory:
clients/<client_id>.json
{
"name": "某咨询公司",
"industry": "金融",
"preferences": "偏好简洁的执行摘要,数据驱动",
"previous_reports": ["report_001", "report_002"]
}
"""
name = "client_context"
async def before(self, state: ReportState) -> ReportState:
if not state.client_id:
return state
profile_path = CLIENTS_DIR / f"{state.client_id}.json"
if not profile_path.exists():
logger.info(f"[client_context] no profile for client '{state.client_id}'")
return state
try:
profile = json.loads(profile_path.read_text(encoding="utf-8"))
parts = []
if name := profile.get("name"):
parts.append(f"客户:{name}")
if industry := profile.get("industry"):
parts.append(f"行业:{industry}")
if prefs := profile.get("preferences"):
parts.append(f"偏好:{prefs}")
state.client_context = "".join(parts)
logger.info(f"[client_context] loaded profile for '{state.client_id}'")
except Exception as e:
logger.warning(f"[client_context] failed to load profile: {e}")
return state

View File

@@ -0,0 +1,61 @@
"""Compliance middleware — checks for sensitive data leakage in output."""
from __future__ import annotations
import json
import logging
import re
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
# Patterns that suggest sensitive data
SENSITIVE_PATTERNS = [
(r"\b\d{15,18}\b", "可能的身份证号"),
(r"\b\d{16,19}\b", "可能的银行卡号"),
(r"\b1[3-9]\d{9}\b", "可能的手机号"),
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "邮箱地址"),
(r"(?:密码|password|secret|token|api.?key)\s*[:=]\s*\S+", "可能的凭证信息"),
]
class ComplianceMiddleware(Middleware):
"""Scans report output for sensitive data patterns.
After graph execution, scans the draft for PII / credential patterns.
Logs warnings but does not block (can be made blocking later).
"""
name = "compliance"
def _scan_text(self, text: str) -> list[dict]:
findings = []
for pattern, desc in SENSITIVE_PATTERNS:
matches = re.findall(pattern, text)
if matches:
findings.append({
"type": desc,
"count": len(matches),
"samples": [m[:8] + "..." for m in matches[:3]],
})
return findings
async def after(self, state: ReportState) -> ReportState:
# Scan draft content
all_text = json.dumps(state.draft, ensure_ascii=False)
findings = self._scan_text(all_text)
if findings:
logger.warning(
f"[compliance] found {len(findings)} sensitive data patterns:"
)
for f in findings:
logger.warning(f" - {f['type']}: {f['count']} occurrences")
# Store in state for API to surface
state.review.setdefault("compliance_warnings", findings)
else:
logger.info("[compliance] no sensitive data patterns detected")
return state

64
app/middleware/memory.py Normal file
View File

@@ -0,0 +1,64 @@
"""Memory middleware — injects persistent facts into state, saves new facts after."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from app.memory.store import MemoryStore
from .base import Middleware
logger = logging.getLogger(__name__)
class MemoryMiddleware(Middleware):
"""Injects top-N memory facts into state before execution.
After execution, extracts key facts from the report for future reference.
"""
name = "memory"
def __init__(self):
self.store = MemoryStore()
async def before(self, state: ReportState) -> ReportState:
# Load global facts + client-specific facts
facts = self.store.get_top_facts("global", limit=10)
if state.client_id:
client_facts = self.store.get_top_facts(state.client_id, limit=10)
facts = client_facts + facts # client facts first
state.memory_facts = facts[:15] # cap total
if facts:
logger.info(f"[memory] injected {len(state.memory_facts)} facts")
return state
async def after(self, state: ReportState) -> ReportState:
# Save key metadata as facts for future reference
if state.draft and not state.error:
title = state.draft.get("title", "")
report_type = state.report_type
client_id = state.client_id or "global"
self.store.add_fact(
content=f"生成过报告:{title}(类型:{report_type}",
client_id=client_id,
category="context",
confidence=0.9,
source=state.id,
)
# Save decomposition tracks as knowledge
tracks = state.decomposition.get("tracks", [])
if tracks:
track_titles = "".join(t.get("title", "") for t in tracks)
self.store.add_fact(
content=f"报告《{title}》的研究轨道:{track_titles}",
client_id=client_id,
category="knowledge",
confidence=0.6,
source=state.id,
)
return state

View File

@@ -0,0 +1,46 @@
"""TokenBudget middleware — tracks and limits token usage across the pipeline."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
class TokenBudgetMiddleware(Middleware):
"""Manages token budget to prevent context overflow.
Before: sets token budget based on model limits.
After: logs total estimated token usage.
"""
name = "token_budget"
# Rough char-to-token ratio for Chinese text
CHARS_PER_TOKEN = 1.5
def _estimate_tokens(self, text: str) -> int:
return int(len(text) / self.CHARS_PER_TOKEN)
async def before(self, state: ReportState) -> ReportState:
# Default budget is already set in ReportState (120k)
# Could adjust based on model selection
logger.info(f"[token_budget] budget = {state.token_budget} tokens")
return state
async def after(self, state: ReportState) -> ReportState:
# Estimate total tokens used in the draft
total_chars = 0
for chapter in state.draft.get("chapters", []):
total_chars += len(chapter.get("content", ""))
total_chars += len(state.draft.get("executive_summary", ""))
estimated = self._estimate_tokens(str(total_chars))
logger.info(
f"[token_budget] estimated output tokens: {estimated}, "
f"budget: {state.token_budget}"
)
return state