init repo

This commit is contained in:
2026-04-25 19:25:22 +08:00
commit c7533eada2
50 changed files with 3732 additions and 0 deletions

0
app/__init__.py Normal file
View File

15
app/agents/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .base import BaseAgent
from .researcher import ResearcherAgent
from .writer import WriterAgent
from .data_agent import DataAgent
from .reviewer import ReviewerAgent
from .formatter import FormatterAgent
# Public API of the agents package: the shared base class plus every
# concrete agent imported above.
__all__ = [
"BaseAgent",
"ResearcherAgent",
"WriterAgent",
"DataAgent",
"ReviewerAgent",
"FormatterAgent",
]

166
app/agents/base.py Normal file
View File

@@ -0,0 +1,166 @@
"""Base agent with LLM calling via litellm."""
from __future__ import annotations
import json
import logging
from typing import Any
import litellm
from app.config import settings
logger = logging.getLogger(__name__)
# Disable litellm telemetry
litellm.telemetry = False
class BaseAgent:
    """Base class for all pipeline agents.

    Subclasses set ``name``, ``description`` and ``system_prompt`` and
    override :meth:`run`. LLM access goes through litellm, using the API key,
    API base and default model read from ``settings``.
    """

    name: str = "base"
    description: str = ""
    system_prompt: str = ""
    model: str = ""  # empty = use default from config

    def __init__(self, model: str | None = None):
        """Optionally pin this agent to ``model`` instead of the configured default."""
        if model:
            self.model = model

    def get_model(self) -> str:
        """Return the agent-specific model, or ``settings.llm_model`` when unset."""
        return self.model or settings.llm_model

    async def call_llm(
        self,
        prompt: str,
        *,
        system: str | None = None,
        temperature: float = 0.3,
        max_tokens: int = 4096,
        response_format: dict | None = None,
    ) -> str:
        """Call the LLM via litellm and return the text of the first choice.

        Args:
            prompt: User message content.
            system: System prompt override; defaults to ``self.system_prompt``.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Completion token limit forwarded to litellm.
            response_format: Optional response-format dict forwarded as-is.

        Returns:
            The message content, or ``""`` when the response carries no
            content (the field is nullable in OpenAI-style responses).
        """
        messages: list[dict[str, str]] = []
        sys_prompt = system or self.system_prompt
        if sys_prompt:
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": prompt})

        model = self.get_model()
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if settings.llm_api_key:
            kwargs["api_key"] = settings.llm_api_key
        if settings.llm_api_base:
            kwargs["api_base"] = settings.llm_api_base
        if response_format:
            kwargs["response_format"] = response_format

        logger.info("[%s] calling %s", self.name, model)
        response = await litellm.acompletion(**kwargs)
        # Guard against a null content field so len() below cannot raise.
        content = response.choices[0].message.content or ""
        logger.info("[%s] got %d chars", self.name, len(content))
        return content

    @staticmethod
    def _strip_code_fence(raw: str | None) -> str:
        """Remove a leading/trailing markdown code fence (``` or ```json)."""
        text = (raw or "").strip()
        if text.startswith("```"):
            first_nl = text.find("\n")
            if first_nl != -1:
                text = text[first_nl + 1:]
            if text.endswith("```"):
                text = text[: text.rfind("```")]
            text = text.strip()
        return text

    @staticmethod
    def _escape_control_chars(s: str) -> str:
        """Escape raw control characters that appear inside JSON string values.

        Best-effort fix for model output that contains literal newlines/tabs
        inside strings, which strict ``json.loads`` rejects.
        """
        short_escapes = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
        out: list[str] = []
        in_string = False
        escaped = False
        for ch in s:
            if escaped:
                # Character following a backslash is copied verbatim.
                out.append(ch)
                escaped = False
            elif ch == "\\":
                out.append(ch)
                escaped = True
            elif ch == '"':
                in_string = not in_string
                out.append(ch)
            elif in_string and ord(ch) < 32:
                out.append(short_escapes.get(ch) or f"\\u{ord(ch):04x}")
            else:
                out.append(ch)
        return "".join(out)

    async def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Call the LLM and parse its response as JSON.

        Parsing strategies, in order: the raw text (code fences stripped),
        the text with control characters escaped, the first JSON object
        embedded in the text (ignoring any commentary after it), and finally
        the optional ``json_repair`` package.

        Args:
            prompt: User message content.
            **kwargs: Forwarded to :meth:`call_llm`. ``response_format``
                defaults to ``{"type": "json_object"}`` but may be overridden.

        Raises:
            json.JSONDecodeError: If no strategy yields valid JSON.
        """
        # setdefault avoids a duplicate-keyword TypeError when the caller
        # supplies its own response_format.
        kwargs.setdefault("response_format", {"type": "json_object"})
        raw = await self.call_llm(prompt, **kwargs)

        text = self._strip_code_fence(raw)
        cleaned = self._escape_control_chars(text)

        candidates = [text] if cleaned == text else [text, cleaned]
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue

        # Last resort: decode the object starting at the first "{".
        # raw_decode is string-aware (a "}" inside a string value does not end
        # the object) and tolerates trailing commentary after the JSON.
        if "{" not in text:
            raise json.JSONDecodeError("No JSON object found", text, 0)
        start = cleaned.find("{")
        if start != -1:
            try:
                obj, _end = json.JSONDecoder().raw_decode(cleaned, start)
                return obj
            except json.JSONDecodeError:
                pass

        # If all else fails, use the json_repair library (optional) or raise.
        try:
            import json_repair

            return json_repair.loads(text)
        except Exception as exc:  # includes ImportError when not installed
            raise json.JSONDecodeError(
                "Failed to parse JSON after multiple attempts", text, 0
            ) from exc

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Execute this agent's task. Override in subclasses.

        Args:
            context: Shared pipeline context (accumulated by previous agents).

        Returns:
            Dict of new keys to merge into context.
        """
        raise NotImplementedError

78
app/agents/data_agent.py Normal file
View File

@@ -0,0 +1,78 @@
"""Data Agent — processes data, generates chart specs and table data."""
from __future__ import annotations
import json
from typing import Any
from .base import BaseAgent
from app.config import settings
class DataAgent(BaseAgent):
    """Turn the chart/table needs listed in the draft into concrete data specs.

    Reads ``context["draft"]`` (and optional ``context["extra_data"]``) and
    returns ``{"data_assets": {"charts": [...], "tables": [...]}}``.
    """

    name = "data"
    description = "处理数据、生成图表规格和表格数据"
    system_prompt = """\
你是一位数据分析专家。你的任务是根据报告草稿中标注的图表和表格需求,
生成具体的数据和图表规格。
输出要求JSON 格式):
{
"charts": [
{
"id": "chart_1",
"title": "图表标题",
"type": "bar|line|pie|area|scatter",
"description": "图表说明",
"data": {
"labels": ["标签1", "标签2"],
"datasets": [
{"label": "数据集名", "data": [100, 200]}
]
}
}
],
"tables": [
{
"id": "table_1",
"title": "表格标题",
"headers": ["列1", "列2", "列3"],
"rows": [["数据1", "数据2", "数据3"]]
}
]
}"""

    def __init__(self):
        """Use the model configured for the "fast" domain."""
        super().__init__(model=settings.model_for_domain("fast"))

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Generate chart and table data for every need declared in the draft.

        Args:
            context: Pipeline context; must contain ``draft``.

        Returns:
            ``{"data_assets": {...}}`` where ``charts`` and ``tables`` are
            always present and always lists. The LLM is not called when the
            draft declares no chart or table needs.
        """
        draft = context["draft"]
        extra_data = context.get("extra_data", "")

        # Collect chart/table needs from draft. `or []` tolerates explicit
        # nulls, which model-generated JSON may contain.
        chart_needs: list[Any] = []
        table_needs: list[Any] = []
        for ch in draft.get("chapters") or []:
            chart_needs.extend(ch.get("charts") or [])
            table_needs.extend(ch.get("tables") or [])

        if not chart_needs and not table_needs:
            return {"data_assets": {"charts": [], "tables": []}}

        prompt = f"""\
## 报告标题
{draft.get("title", "")}
## 需要生成的图表
{json.dumps(chart_needs, ensure_ascii=False)}
## 需要生成的表格
{json.dumps(table_needs, ensure_ascii=False)}
## 补充数据源
{extra_data if extra_data else "(无额外数据,请根据行业常识生成合理的示例数据)"}
请为以上需求生成具体的图表规格和表格数据。输出 JSON。"""

        result = await self.call_llm_json(prompt)

        # Normalise the shape so downstream code can iterate both keys safely
        # even if the model omitted one or returned a non-object.
        if not isinstance(result, dict):
            result = {}
        result["charts"] = result.get("charts") or []
        result["tables"] = result.get("tables") or []
        return {"data_assets": result}

669
app/agents/formatter.py Normal file
View File

@@ -0,0 +1,669 @@
"""Formatter Agent — renders final report using Skills toolkit.
Skills integration:
- docx: python-docx (baseline) + docx-js via Node.js (rich mode) + OOXML template editing
- pptx: html2pptx.js via Node.js (visual slides) + python-pptx fallback
- xlsx: openpyxl + recalc.py (formula recalculation via LibreOffice)
- pdf: reportlab with CJK support + fpdf2 fallback
"""
from __future__ import annotations
import asyncio
import json
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Any
from .base import BaseAgent
logger = logging.getLogger(__name__)
# Skills root
SKILLS_ROOT = Path.home() / "Projects/code/20260119-skills合集/anthropics_skills/skills"
DOCX_SKILLS = SKILLS_ROOT / "docx"
PPTX_SKILLS = SKILLS_ROOT / "pptx"
XLSX_SKILLS = SKILLS_ROOT / "xlsx"
PDF_SKILLS = SKILLS_ROOT / "pdf"
def _skills_available() -> dict[str, bool]:
"""Check which skill toolkits are available."""
return {
"docx_js": (DOCX_SKILLS / "docx-js.md").exists(),
"html2pptx": (PPTX_SKILLS / "scripts" / "html2pptx.js").exists(),
"recalc": (XLSX_SKILLS / "recalc.py").exists(),
"ooxml_docx": (DOCX_SKILLS / "ooxml" / "scripts" / "unpack.py").exists(),
"ooxml_pptx": (PPTX_SKILLS / "ooxml" / "scripts" / "unpack.py").exists(),
"pdf_scripts": (PDF_SKILLS / "scripts").is_dir(),
}
# Agent that renders the draft and data assets into docx/pptx/xlsx/pdf files.
class FormatterAgent(BaseAgent):
# Agent identifier (BaseAgent.call_llm uses `name` as its log prefix).
name = "formatter"
# Human-readable summary of what this agent does.
description = "将报告渲染为 docx/pptx/xlsx/pdf融合 Skills 能力"
def __init__(self):
    """Initialise the base agent and record which Skills toolkits are present."""
    super().__init__()
    detected = _skills_available()
    self.skills = detected
    available = [skill for skill in detected if detected[skill]]
    logger.info(f"[formatter] available skills: {available}")
async def run(self, context: dict[str, Any]) -> dict[str, Any]:
    """Render the draft into every requested output format.

    Reads ``draft``, ``data_assets``, ``output_dir``, ``output_formats`` and
    the optional ``template_path`` from ``context``. A failure in one format
    is logged and does not stop the remaining formats.

    Returns:
        ``{"generated_files": [...]}`` with the path of each rendered file.
    """
    draft = context["draft"]
    data_assets = context.get("data_assets", {})
    output_dir = Path(context.get("output_dir", "output"))
    formats = context.get("output_formats", ["docx"])
    # Optional user-provided template (only the docx renderer consumes it).
    template_path = context.get("template_path")

    output_dir.mkdir(parents=True, exist_ok=True)
    title = draft.get("title", "报告")

    generated_files = []
    for fmt in formats:
        try:
            if fmt == "docx":
                path = await self._render_docx(
                    draft, data_assets, output_dir, title, template_path
                )
            elif fmt == "pptx":
                path = await self._render_pptx(draft, data_assets, output_dir, title)
            elif fmt == "xlsx":
                path = await self._render_xlsx(data_assets, output_dir, title)
            elif fmt == "pdf":
                path = await self._render_pdf(draft, data_assets, output_dir, title)
            else:
                logger.warning(f"Unsupported format: {fmt}")
                continue
            generated_files.append(str(path))
            logger.info(f"[formatter] generated {path}")
        except Exception:
            logger.exception(f"[formatter] failed to render {fmt}")
    return {"generated_files": generated_files}
# -----------------------------------------------------------------------
# DOCX — python-docx baseline + OOXML template editing
# -----------------------------------------------------------------------
async def _render_docx(
self, draft: dict, data_assets: dict, output_dir: Path, title: str,
template_path: str | None = None,
) -> Path:
if template_path and self.skills["ooxml_docx"]:
return await self._render_docx_from_template(
draft, data_assets, output_dir, title, Path(template_path)
)
return await self._render_docx_baseline(draft, data_assets, output_dir, title)
async def _render_docx_baseline(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Build the report as a DOCX from scratch with python-docx.

    Layout: title, optional executive summary, one level-1 heading per
    chapter with its markdown-ish content, then all tables, then each chart
    as a text placeholder followed by its data rendered as tables.

    Returns:
        Path of the saved ``<title>.docx`` inside ``output_dir``.
    """
    from docx import Document
    from docx.enum.text import WD_ALIGN_PARAGRAPH
    from docx.oxml.ns import qn
    from docx.shared import Pt, RGBColor

    doc = Document()

    # -- Styles --
    style = doc.styles["Normal"]
    style.font.name = "微软雅黑"
    style.font.size = Pt(11)
    # font.name only sets the Latin font slots; Word picks the typeface for
    # CJK glyphs from w:eastAsia, so set it explicitly or Chinese text falls
    # back to the document default font.
    style.element.rPr.rFonts.set(qn("w:eastAsia"), "微软雅黑")

    # Title
    heading = doc.add_heading(title, level=0)
    heading.alignment = WD_ALIGN_PARAGRAPH.CENTER

    # Executive summary
    if summary := draft.get("executive_summary"):
        doc.add_heading("执行摘要", level=1)
        p = doc.add_paragraph()
        run = p.add_run(summary)
        run.font.size = Pt(11)
        run.font.color.rgb = RGBColor(0x33, 0x33, 0x33)

    # Chapters
    for chapter in draft.get("chapters") or []:
        doc.add_heading(chapter.get("title", ""), level=1)
        self._docx_render_markdown(doc, chapter.get("content") or "")

    # Tables from data assets
    for table_spec in data_assets.get("tables") or []:
        doc.add_heading(table_spec.get("title", "数据表"), level=2)
        self._docx_add_table(doc, table_spec)

    # Charts are not drawn: each becomes a text placeholder plus data tables.
    for chart_spec in data_assets.get("charts") or []:
        doc.add_heading(chart_spec.get("title", "图表"), level=2)
        desc = chart_spec.get("description", "")
        chart_type = str(chart_spec.get("type") or "")
        doc.add_paragraph(f"[{chart_type.upper()} 图表] {desc}")
        # Render chart data as a table too
        chart_data = chart_spec.get("data") or {}
        if labels := chart_data.get("labels"):
            for ds in chart_data.get("datasets") or []:
                values = ds.get("data") or []
                self._docx_add_table(doc, {
                    "headers": ["项目", ds.get("label", "数据")],
                    "rows": [[str(label), str(value)] for label, value in zip(labels, values)],
                })

    path = output_dir / f"{title}.docx"
    doc.save(str(path))
    return path
async def _render_docx_from_template(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str,
    template_path: Path,
) -> Path:
    """Edit an existing DOCX template using the OOXML unpack/edit/pack workflow.

    The XML editing step is not implemented yet, so the template is currently
    packed back unchanged. If either helper script fails, the baseline
    renderer is used instead.

    Returns:
        Path of the produced ``<title>.docx`` inside ``output_dir``.
    """
    scripts_dir = DOCX_SKILLS / "ooxml" / "scripts"
    unpack_script = scripts_dir / "unpack.py"
    pack_script = scripts_dir / "pack.py"
    output_path = output_dir / f"{title}.docx"

    async def _run_script(script: Path, src: Path, dst: Path) -> tuple[int, str]:
        """Run a helper script and return (exit code, decoded stderr)."""
        proc = await asyncio.create_subprocess_exec(
            "python3", str(script), str(src), str(dst),
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
        # communicate() drains both pipes; wait() alone can deadlock once the
        # child fills an OS pipe buffer.
        _out, err = await proc.communicate()
        return proc.returncode, err.decode(errors="replace")

    with tempfile.TemporaryDirectory() as tmpdir:
        work_dir = Path(tmpdir) / "unpacked"

        # Unpack template
        code, err = await _run_script(unpack_script, template_path, work_dir)
        if code != 0:
            logger.warning(
                "[formatter] OOXML unpack failed, falling back to baseline: %s", err[:200]
            )
            return await self._render_docx_baseline(draft, data_assets, output_dir, title)

        # TODO: edit XML content in work_dir based on draft
        # For now, just pack back as-is (template passthrough)
        code, err = await _run_script(pack_script, work_dir, output_path)
        if code != 0:
            # Previously ignored: a failed pack returned a path to a file
            # that was never written.
            logger.warning(
                "[formatter] OOXML pack failed, falling back to baseline: %s", err[:200]
            )
            return await self._render_docx_baseline(draft, data_assets, output_dir, title)

    return output_path
def _docx_render_markdown(self, doc, content: str):
"""Convert markdown-ish content to docx paragraphs."""
from docx.shared import Pt
for block in content.split("\n\n"):
block = block.strip()
if not block:
continue
if block.startswith("#### "):
doc.add_heading(block[5:], level=4)
elif block.startswith("### "):
doc.add_heading(block[4:], level=3)
elif block.startswith("## "):
doc.add_heading(block[3:], level=2)
elif block.startswith("- ") or block.startswith("* "):
# Bullet list
for line in block.split("\n"):
line = line.lstrip("- *").strip()
if line:
doc.add_paragraph(line, style="List Bullet")
elif block.startswith("1. ") or block.startswith("1"):
# Numbered list
for line in block.split("\n"):
text = line.lstrip("0123456789.) ").strip()
if text:
doc.add_paragraph(text, style="List Number")
else:
p = doc.add_paragraph(block)
for run in p.runs:
run.font.size = Pt(11)
def _docx_add_table(self, doc, table_spec: dict):
"""Add a formatted table to the document."""
from docx.shared import Pt, RGBColor
from docx.oxml.ns import qn
headers = table_spec.get("headers", [])
rows = table_spec.get("rows", [])
if not headers:
return
tbl = doc.add_table(rows=1 + len(rows), cols=len(headers))
tbl.style = "Light Grid Accent 1"
# Header row
for i, h in enumerate(headers):
cell = tbl.rows[0].cells[i]
cell.text = str(h)
for p in cell.paragraphs:
for run in p.runs:
run.font.bold = True
run.font.size = Pt(10)
# Data rows
for r_idx, row in enumerate(rows):
for c_idx, cell_val in enumerate(row):
tbl.rows[r_idx + 1].cells[c_idx].text = str(cell_val)
# -----------------------------------------------------------------------
# PPTX — html2pptx.js (rich) or python-pptx (fallback)
# -----------------------------------------------------------------------
async def _render_pptx(
self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
if self.skills["html2pptx"]:
try:
return await self._render_pptx_html2pptx(draft, data_assets, output_dir, title)
except Exception as e:
logger.warning(f"[formatter] html2pptx failed ({e}), falling back to python-pptx")
return await self._render_pptx_baseline(draft, data_assets, output_dir, title)
async def _render_pptx_html2pptx(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Generate PPTX using the html2pptx.js skill for visual slides.

    Writes one HTML file per slide plus a small Node.js driver script into a
    temporary directory, then runs the driver with ``node``.

    Returns:
        ``output_dir / "<title>.pptx"``.

    Raises:
        RuntimeError: If node exits non-zero or the output file is missing.
    """
    from html import escape

    output_path = output_dir / f"{title}.pptx"
    # node runs with cwd set to the temp dir, so a relative output path would
    # be written inside (and deleted with) that directory. Hand node an
    # absolute path.
    absolute_output = output_path.resolve()

    with tempfile.TemporaryDirectory() as tmpdir:
        work = Path(tmpdir)

        # Generate HTML slides. All model-provided text is HTML-escaped so
        # characters such as "<" or "&" cannot break the markup.
        slides_html = []

        # Title slide
        summary = (draft.get("executive_summary") or "")[:100]
        slides_html.append(f"""<html><body style="width:720pt;height:405pt;display:flex;align-items:center;justify-content:center;flex-direction:column;background:linear-gradient(135deg,#1a1a2e,#16213e);color:white;font-family:sans-serif;">
<h1 style="font-size:36pt;margin:0;">{escape(title)}</h1>
<p style="font-size:18pt;color:#aaa;margin-top:20pt;">{escape(summary)}</p>
</body></html>""")

        # Chapter slides
        for ch in draft.get("chapters") or []:
            content_lines = (ch.get("content") or "")[:400].split("\n")
            bullets = "".join(
                f"<li>{escape(line.strip())}</li>" for line in content_lines if line.strip()
            )
            slides_html.append(f"""<html><body style="width:720pt;height:405pt;padding:40pt;font-family:sans-serif;background:#ffffff;">
<h2 style="font-size:28pt;color:#1a1a2e;border-bottom:2pt solid #e94560;padding-bottom:10pt;">{escape(ch.get('title', ''))}</h2>
<ul style="font-size:14pt;color:#333;line-height:1.8;">{bullets}</ul>
</body></html>""")

        # Write HTML files
        slide_files = []
        for i, slide_markup in enumerate(slides_html):
            file_name = f"slide_{i}.html"
            (work / file_name).write_text(slide_markup, encoding="utf-8")
            slide_files.append(file_name)

        # Write conversion script. Paths are embedded with json.dumps so
        # quotes, backslashes and non-ASCII characters become valid JS
        # string literals.
        script = work / "convert.js"
        html2pptx_path = PPTX_SKILLS / "scripts" / "html2pptx.js"
        script.write_text(f"""\
const pptxgen = require('pptxgenjs');
const {{ html2pptx }} = require({json.dumps(str(html2pptx_path))});
const path = require('path');
async function main() {{
const pptx = new pptxgen();
pptx.layout = 'LAYOUT_16x9';
const files = {json.dumps(slide_files)};
for (const f of files) {{
await html2pptx(path.join({json.dumps(str(work))}, f), pptx);
}}
await pptx.writeFile({{ fileName: {json.dumps(str(absolute_output))} }});
}}
main().catch(e => {{ console.error(e); process.exit(1); }});
""", encoding="utf-8")

        proc = await asyncio.create_subprocess_exec(
            "node", str(script),
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            cwd=str(work),
        )
        _stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"html2pptx failed: {stderr.decode(errors='replace')}")

    if not absolute_output.exists():
        raise RuntimeError("html2pptx failed: no output file was produced")
    return output_path
async def _render_pptx_baseline(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Build a simple PPTX with python-pptx.

    Produces a title slide, one bullet slide per chapter (at most 12 lines)
    and one table slide per table in ``data_assets`` (at most 9 data rows).

    Returns:
        Path of the saved ``<title>.pptx`` inside ``output_dir``.
    """
    import re

    from pptx import Presentation
    from pptx.util import Inches, Pt

    # One leading markdown marker (heading hashes, bullet, or "1." / "1)").
    # Removing only the marker keeps content that begins with digits,
    # e.g. "2024年…", which a character-set lstrip would eat.
    leading_marker = re.compile(r"^\s*(?:#{1,6}\s+|[-*]\s+|\d+[.)]\s+)")

    prs = Presentation()

    # Title slide
    slide = prs.slides.add_slide(prs.slide_layouts[0])
    slide.shapes.title.text = title
    if len(slide.placeholders) > 1:
        slide.placeholders[1].text = (draft.get("executive_summary") or "")[:200]

    # Chapter slides
    for chapter in draft.get("chapters") or []:
        slide = prs.slides.add_slide(prs.slide_layouts[1])
        slide.shapes.title.text = chapter.get("title", "")
        tf = slide.placeholders[1].text_frame
        tf.clear()
        content = chapter.get("content") or ""
        lines = [line.strip() for line in content.split("\n") if line.strip()]
        for idx, line in enumerate(lines[:12]):  # max 12 bullets per slide
            # clear() leaves one empty paragraph behind; reuse it for the
            # first bullet instead of leaving a blank bullet at the top.
            p = tf.paragraphs[0] if idx == 0 else tf.add_paragraph()
            p.text = leading_marker.sub("", line, count=1).strip()
            p.font.size = Pt(14)
            p.space_after = Pt(4)

    # Data table slides
    for table_spec in data_assets.get("tables") or []:
        # Layout 5 is used because the slide title is set directly below.
        slide = prs.slides.add_slide(prs.slide_layouts[5])
        slide.shapes.title.text = table_spec.get("title", "数据表")
        headers = table_spec.get("headers") or []
        rows = table_spec.get("rows") or []
        if headers and rows:
            n_rows = min(len(rows) + 1, 10)  # limit rows per slide
            n_cols = len(headers)
            tbl = slide.shapes.add_table(
                n_rows, n_cols,
                Inches(0.5), Inches(1.5), Inches(9), Inches(4.5)
            ).table
            for i, h in enumerate(headers):
                tbl.cell(0, i).text = str(h)
            for r_idx, row in enumerate(rows[:n_rows - 1], start=1):
                for c_idx, val in enumerate(row[:n_cols]):
                    tbl.cell(r_idx, c_idx).text = str(val)

    path = output_dir / f"{title}.pptx"
    prs.save(str(path))
    return path
# -----------------------------------------------------------------------
# XLSX — openpyxl + recalc.py (formula recalculation)
# -----------------------------------------------------------------------
async def _render_xlsx(
    self, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Write all tables and chart data to an XLSX workbook with openpyxl.

    Tables are stacked on the first sheet ("数据总览"), each with a styled
    header and, when it has at least two data rows, a SUM row for columns
    whose first data cell is numeric. Every chart gets its own sheet holding
    its labels and datasets. If formulas were written and the ``recalc``
    skill is available, the workbook is recalculated afterwards.

    Returns:
        Path of the saved ``<title>.xlsx`` inside ``output_dir``.
    """
    import re

    from openpyxl import Workbook
    from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
    from openpyxl.utils import get_column_letter

    def _safe_sheet_title(raw: Any) -> str:
        """Return a worksheet title without forbidden characters, max 31 chars."""
        # Excel/openpyxl reject \ / * ? : [ ] in sheet names.
        cleaned = re.sub(r"[\\/*?:\[\]]", "_", str(raw or "")).strip()
        return cleaned[:31] or "图表"

    def _cell_value(val: Any) -> Any:
        """Coerce a raw value into something openpyxl can store."""
        if isinstance(val, (list, dict)):
            # openpyxl cannot store containers; keep the data as JSON text.
            return json.dumps(val, ensure_ascii=False)
        if isinstance(val, str):
            # Try to convert numeric strings such as "1,234.5".
            try:
                return float(val.replace(",", ""))
            except ValueError:
                return val
        return val

    wb = Workbook()
    ws = wb.active
    ws.title = "数据总览"

    # Professional styling
    header_font = Font(bold=True, size=11, color="FFFFFF")
    header_fill = PatternFill(start_color="1A1A2E", end_color="1A1A2E", fill_type="solid")
    title_font = Font(bold=True, size=14, color="1A1A2E")
    bold_font = Font(bold=True)
    edge = Side(style="thin", color="CCCCCC")
    thin_border = Border(left=edge, right=edge, top=edge, bottom=edge)

    current_row = 1
    has_formulas = False
    # Widest requirement seen per column, so a later narrow table does not
    # shrink a column that an earlier table needs wide.
    col_widths: dict[int, int] = {}

    for table_spec in data_assets.get("tables") or []:
        # Table title
        ws.cell(row=current_row, column=1, value=table_spec.get("title", "")).font = title_font
        current_row += 1

        headers = table_spec.get("headers") or []
        rows = table_spec.get("rows") or []
        if headers:
            # Header row with styling
            header_row = current_row
            for col_idx, h in enumerate(headers, 1):
                cell = ws.cell(row=current_row, column=col_idx, value=_cell_value(h) if isinstance(h, (list, dict)) else h)
                cell.font = header_font
                cell.fill = header_fill
                cell.alignment = Alignment(horizontal="center")
                cell.border = thin_border
            current_row += 1

            # Data rows
            data_start = current_row
            for row_data in rows:
                for col_idx, val in enumerate(row_data, 1):
                    cell = ws.cell(row=current_row, column=col_idx, value=_cell_value(val))
                    cell.border = thin_border
                current_row += 1

            # Auto-sum row for numeric columns
            data_end = current_row - 1
            if data_end > data_start:
                for col_idx in range(1, len(headers) + 1):
                    col_letter = get_column_letter(col_idx)
                    first_value = ws.cell(row=data_start, column=col_idx).value
                    if isinstance(first_value, (int, float)):
                        cell = ws.cell(
                            row=current_row, column=col_idx,
                            value=f"=SUM({col_letter}{data_start}:{col_letter}{data_end})"
                        )
                        cell.font = bold_font
                        cell.border = thin_border
                        has_formulas = True
                    elif col_idx == 1:
                        cell = ws.cell(row=current_row, column=1, value="合计")
                        cell.font = bold_font
                        cell.border = thin_border
                current_row += 1

            # Auto-fit column widths over this table's header..last row.
            for col_idx in range(1, len(headers) + 1):
                max_len = max(
                    len(str(ws.cell(row=r, column=col_idx).value or ""))
                    for r in range(header_row, current_row)
                )
                wanted = min(max_len + 4, 30)
                col_widths[col_idx] = max(col_widths.get(col_idx, 0), wanted)
                ws.column_dimensions[get_column_letter(col_idx)].width = col_widths[col_idx]

        current_row += 2  # gap between tables

    # Chart data sheets
    for chart_spec in data_assets.get("charts") or []:
        chart_ws = wb.create_sheet(title=_safe_sheet_title(chart_spec.get("title", "图表")))
        chart_ws.cell(row=1, column=1, value=chart_spec.get("title", "")).font = title_font
        chart_data = chart_spec.get("data") or {}
        labels = chart_data.get("labels") or []
        datasets = chart_data.get("datasets") or []

        # Headers: [项目, 数据集1, 数据集2, ...]
        chart_ws.cell(row=2, column=1, value="项目").font = Font(bold=True)
        for ds_col, ds in enumerate(datasets, 2):
            chart_ws.cell(row=2, column=ds_col, value=ds.get("label", "")).font = Font(bold=True)

        for offset, label in enumerate(labels):
            row_no = offset + 3
            chart_ws.cell(row=row_no, column=1, value=_cell_value(label) if isinstance(label, (list, dict)) else label)
            for ds_col, ds in enumerate(datasets, 2):
                series = ds.get("data") or []
                if offset < len(series):
                    point = series[offset]
                    chart_ws.cell(
                        row=row_no, column=ds_col,
                        value=_cell_value(point) if isinstance(point, (list, dict)) else point,
                    )

    path = output_dir / f"{title}.xlsx"
    wb.save(str(path))

    # Run recalc.py if we have formulas and the skill is available
    if has_formulas and self.skills["recalc"]:
        await self._xlsx_recalc(path)
    return path
async def _xlsx_recalc(self, path: Path):
    """Recalculate formulas using Skills recalc.py (requires LibreOffice).

    Best effort: a non-zero exit, unparsable output or any other exception is
    logged as a warning and never propagated to the caller.
    """
    recalc_script = XLSX_SKILLS / "recalc.py"
    logger.info(f"[formatter] running recalc.py on {path}")
    # NOTE(review): the trailing "30" is presumably a timeout in seconds for
    # recalc.py; confirm against the script.
    command = ["python3", str(recalc_script), str(path), "30"]
    try:
        proc = await asyncio.create_subprocess_exec(
            *command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            logger.warning(f"[formatter] recalc.py failed: {stderr.decode()[:200]}")
        else:
            result = json.loads(stdout.decode())
            logger.info(f"[formatter] recalc result: {result.get('status')}")
    except Exception as e:
        logger.warning(f"[formatter] recalc.py error: {e}")
# -----------------------------------------------------------------------
# PDF — reportlab with CJK support + fpdf2 fallback
# -----------------------------------------------------------------------
async def _render_pdf(
self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
try:
return await self._render_pdf_reportlab(draft, data_assets, output_dir, title)
except Exception as e:
logger.warning(f"[formatter] reportlab failed ({e}), falling back to fpdf2")
return await self._render_pdf_fpdf(draft, data_assets, output_dir, title)
async def _render_pdf_reportlab(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Generate PDF with reportlab — better CJK support and table rendering.

    Layout: title, optional executive summary, one page per chapter, then all
    tables. The first loadable font from a fixed list of system CJK fonts is
    registered; Helvetica is used when none loads.

    Returns:
        Path of the built ``<title>.pdf`` inside ``output_dir``.
    """
    from xml.sax.saxutils import escape

    from reportlab.lib import colors
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
    from reportlab.lib.units import mm
    from reportlab.pdfbase import pdfmetrics
    from reportlab.pdfbase.ttfonts import TTFont
    from reportlab.platypus import (
        PageBreak, Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle,
    )

    def _markup(text: Any) -> str:
        """Escape text for Paragraph, which parses its input as XML-like markup."""
        # A bare "<" or "&" in model-written text would otherwise be parsed as
        # a tag/entity and make Paragraph raise.
        return escape(str(text or "")).replace("\n", "<br/>")

    # Try to register a CJK font
    cjk_font = "Helvetica"
    for font_path in [
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
    ]:
        if Path(font_path).exists():
            try:
                pdfmetrics.registerFont(TTFont("CJK", font_path, subfontIndex=0))
                cjk_font = "CJK"
                break
            except Exception:
                continue

    path = output_dir / f"{title}.pdf"
    doc = SimpleDocTemplate(str(path), pagesize=A4,
                            topMargin=25*mm, bottomMargin=25*mm)

    styles = getSampleStyleSheet()
    styles.add(ParagraphStyle(
        name="CJKTitle", fontName=cjk_font, fontSize=22,
        spaceAfter=12, alignment=1,
    ))
    styles.add(ParagraphStyle(
        name="CJKHeading", fontName=cjk_font, fontSize=16,
        spaceAfter=8, spaceBefore=16, textColor=colors.HexColor("#1a1a2e"),
    ))
    styles.add(ParagraphStyle(
        name="CJKBody", fontName=cjk_font, fontSize=11,
        spaceAfter=6, leading=16,
    ))

    elements = []

    # Title
    elements.append(Paragraph(_markup(title), styles["CJKTitle"]))
    elements.append(Spacer(1, 12))

    # Executive summary
    if summary := draft.get("executive_summary"):
        elements.append(Paragraph("执行摘要", styles["CJKHeading"]))
        elements.append(Paragraph(_markup(summary), styles["CJKBody"]))
        elements.append(Spacer(1, 12))

    # Chapters
    for chapter in draft.get("chapters") or []:
        elements.append(PageBreak())
        elements.append(Paragraph(_markup(chapter.get("title", "")), styles["CJKHeading"]))
        content = chapter.get("content") or ""
        for para in content.split("\n\n"):
            para = para.strip()
            if para:
                elements.append(Paragraph(_markup(para), styles["CJKBody"]))

    # Tables (cell strings are drawn as plain text, so they are not escaped)
    for table_spec in data_assets.get("tables") or []:
        elements.append(Spacer(1, 12))
        elements.append(Paragraph(_markup(table_spec.get("title", "")), styles["CJKHeading"]))
        headers = table_spec.get("headers") or []
        rows = table_spec.get("rows") or []
        if headers:
            table_data = [list(headers)] + [list(row) for row in rows]
            t = Table(table_data)
            t.setStyle(TableStyle([
                ("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#1a1a2e")),
                ("TEXTCOLOR", (0, 0), (-1, 0), colors.white),
                ("FONTNAME", (0, 0), (-1, -1), cjk_font),
                ("FONTSIZE", (0, 0), (-1, 0), 10),
                ("FONTSIZE", (0, 1), (-1, -1), 9),
                ("GRID", (0, 0), (-1, -1), 0.5, colors.grey),
                ("ALIGN", (0, 0), (-1, -1), "CENTER"),
                ("ROWBACKGROUNDS", (0, 1), (-1, -1), [colors.white, colors.HexColor("#f5f5f5")]),
            ]))
            elements.append(t)

    doc.build(elements)
    return path
async def _render_pdf_fpdf(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Fallback PDF generation with fpdf2.

    Renders the title, optional executive summary and one page per chapter.
    ``data_assets`` is accepted for signature parity with the reportlab
    renderer but is not rendered here.

    Returns:
        Path of the written ``<title>.pdf`` inside ``output_dir``.
    """
    from fpdf import FPDF

    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)

    # Try CJK fonts. Same candidate list as the reportlab renderer, so both
    # PDF paths look for the same system fonts (including the Linux Noto one).
    font_candidates = (
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
    )
    cjk_loaded = False
    for font_path in font_candidates:
        if not Path(font_path).exists():
            continue
        try:
            pdf.add_font("CJK", "", font_path, uni=True)
            pdf.set_font("CJK", "", 11)
        except Exception:
            # Unloadable font file: try the next candidate.
            continue
        cjk_loaded = True
        break
    if not cjk_loaded:
        logger.warning(
            "[formatter] no CJK font could be loaded for fpdf2; using Helvetica"
        )
        pdf.set_font("Helvetica", "", 11)

    pdf.add_page()
    pdf.set_font_size(24)
    pdf.cell(0, 20, title, new_x="LMARGIN", new_y="NEXT", align="C")
    pdf.set_font_size(11)

    if summary := draft.get("executive_summary"):
        pdf.set_font_size(16)
        pdf.cell(0, 12, "执行摘要", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font_size(11)
        pdf.multi_cell(0, 6, summary)

    for chapter in draft.get("chapters") or []:
        pdf.add_page()
        pdf.set_font_size(16)
        pdf.cell(0, 12, chapter.get("title", ""), new_x="LMARGIN", new_y="NEXT")
        pdf.set_font_size(11)
        pdf.multi_cell(0, 6, chapter.get("content") or "")

    path = output_dir / f"{title}.pdf"
    pdf.output(str(path))
    return path

103
app/agents/researcher.py Normal file
View File

@@ -0,0 +1,103 @@
"""Researcher Agent — domain-aware, bilingual research."""
from __future__ import annotations
from typing import Any
from .base import BaseAgent
from app.config import settings
# English system prompt for the researcher: defines the analyst persona, the
# research requirements and the JSON schema the model must return.
SYSTEM_EN = """\
You are a senior industry analyst at a top-tier consulting firm.
Your task is to produce a thorough research brief based on the given instructions.
Requirements:
1. Be specific — cite concrete data points, market sizes, growth rates, company names
2. Be structured — organize findings with clear headings and logical flow
3. Be analytical — don't just list facts, provide insights and implications
4. Flag data gaps — explicitly note where data is uncertain or unavailable
Output (JSON):
{
"title": "Research brief title",
"executive_summary": "2-3 sentence summary of key findings",
"sections": [
{
"heading": "Section heading",
"content": "Detailed findings (Markdown)",
"data_points": ["key data points extracted"],
"sources_quality": "high|medium|low — how confident are you in the data"
}
],
"data_gaps": ["areas where data is insufficient or uncertain"],
"key_insights": ["top 3-5 non-obvious insights"]
}"""
# Chinese counterpart of SYSTEM_EN with the same requirements and JSON schema;
# ResearcherAgent selects it when its language is "zh".
SYSTEM_ZH = """\
你是一位顶级咨询公司的资深行业分析师。
你的任务是根据给定的指令,输出一份深度研究简报。
要求:
1. 具体——引用具体的数据点、市场规模、增长率、企业名称
2. 结构化——用清晰的标题和逻辑流组织发现
3. 有分析深度——不要只罗列事实,要提供洞察和含义
4. 标注数据缺口——明确指出数据不确定或不可获取的地方
输出JSON
{
"title": "研究简报标题",
"executive_summary": "核心发现的2-3句总结",
"sections": [
{
"heading": "章节标题",
"content": "详细发现Markdown格式",
"data_points": ["提取的关键数据点"],
"sources_quality": "high|medium|low — 对数据的置信度"
}
],
"data_gaps": ["数据不充分或不确定的领域"],
"key_insights": ["3-5条非显而易见的洞察"]
}"""
class ResearcherAgent(BaseAgent):
    """Research agent that prompts in Chinese or English depending on ``language``."""

    name = "researcher"
    description = "域感知研究 — 根据领域选择最优模型和语言"

    def __init__(self, model: str | None = None, language: str = "en"):
        """Store the language and pick the matching system prompt ("zh" → Chinese)."""
        super().__init__(model=model)
        self.language = language
        self.system_prompt = SYSTEM_EN if language != "zh" else SYSTEM_ZH

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Build the research prompt from the context and return the parsed brief.

        Reads ``requirement`` (required), ``report_type`` and ``extra_data``
        from ``context`` and returns ``{"research": <parsed JSON>}``.
        """
        requirement = context["requirement"]
        report_type = context.get("report_type", "")
        extra_data = context.get("extra_data", "")

        if self.language != "zh":
            prompt = f"""\
## Research instructions
{requirement}
## Research focus
{report_type}
## Additional data
{extra_data if extra_data else "(none)"}
Output the research brief as JSON."""
        else:
            prompt = f"""\
## 研究指令
{requirement}
## 研究方向
{report_type}
## 补充数据
{extra_data if extra_data else "(无)"}
请输出研究简报 JSON。"""

        brief = await self.call_llm_json(prompt, max_tokens=6144)
        return {"research": brief}

79
app/agents/reviewer.py Normal file
View File

@@ -0,0 +1,79 @@
"""Reviewer Agent — bilingual quality check with strongest reasoning model."""
from __future__ import annotations
import json
from typing import Any
from .base import BaseAgent
from app.config import settings
class ReviewerAgent(BaseAgent):
    """Quality-review agent that scores the draft against the research."""

    name = "reviewer"
    description = "双语报告质量审查 — 使用最强推理模型"
    system_prompt = """\
You are a senior consulting partner reviewing a report before client delivery.
The report has both Chinese and English versions (or will be translated).
Review dimensions:
1. **Accuracy** — Are data points, percentages, and claims supported by the research?
Cross-check global claims against English research, Chinese claims against Chinese research.
2. **Logical consistency** — Does the narrative flow? Are there contradictions between chapters?
3. **Depth of analysis** — Is it consultancy-grade or just surface-level? Would a C-suite exec find it valuable?
4. **Bilingual quality** — If translated version exists, check for translation artifacts,
mistranslated terminology, or cultural mismatches.
5. **Data gaps honesty** — Are uncertainties acknowledged or are claims fabricated?
6. **Completeness** — Are any critical aspects of the requirement left unaddressed?
Scoring guide:
- 90+: Publication-ready
- 80-89: Minor issues, can pass with notes
- 70-79: Needs revision (verdict: revise)
- <70: Significant problems (verdict: reject)
Output (JSON):
{
"overall_score": 85,
"verdict": "pass|revise|reject",
"issues": [
{
"severity": "high|medium|low",
"chapter": "affected chapter",
"dimension": "accuracy|consistency|depth|bilingual|gaps|completeness",
"description": "issue description",
"suggestion": "specific fix suggestion"
}
],
"strengths": ["what the report does well"],
"summary": "Overall assessment (2-3 sentences)"
}"""

    def __init__(self):
        """Use the model configured for the "reasoning" domain."""
        super().__init__(model=settings.model_for_domain("reasoning"))

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Assemble research, draft and optional translation into a review prompt.

        Reads ``draft`` and ``research`` (both required) and the optional
        ``draft_translated`` from ``context``; returns ``{"review": <parsed JSON>}``.
        """
        def _pretty(payload: Any) -> str:
            return json.dumps(payload, ensure_ascii=False, indent=2)

        draft = context["draft"]
        draft_translated = context.get("draft_translated", {})
        research = context["research"]

        parts = ["## Research Plan (what was asked)", _pretty(research)]
        parts += ["", "## Primary Draft", _pretty(draft)]
        if draft_translated:
            # The translated version is appended only when one was supplied.
            parts += ["", "## Translated Version", _pretty(draft_translated)]

        prompt = "\n".join(parts) + "\n\nReview the report. Output JSON."
        review = await self.call_llm_json(prompt, max_tokens=4096)
        return {"review": review}

86
app/agents/writer.py Normal file
View File

@@ -0,0 +1,86 @@
"""Writer Agent — synthesizes multilingual research tracks into a cohesive report."""
from __future__ import annotations
import json
from typing import Any
from .base import BaseAgent
from app.config import settings
class WriterAgent(BaseAgent):
    """Agent that merges per-track research findings into one report draft.

    The agent is constructed with the model configured for the "reasoning"
    domain and returns the parsed JSON draft under the ``"draft"`` key.
    """

    name = "writer"
    description = "汇聚多语言/多领域研究成果,撰写完整报告"
    system_prompt = """\
You are an expert consulting report writer. Your task is to synthesize research
findings from MULTIPLE parallel tracks (some in English, some in Chinese) into
ONE cohesive, professional consulting report.
CRITICAL RULES:
1. The PRIMARY output language is Chinese (中文) — this is for Chinese clients
2. For global/international sections, the analysis depth must reflect the English research
3. For China-specific sections, preserve the precision of Chinese-native research
4. Maintain professional consulting tone throughout
5. Every claim should trace back to a research track's findings
6. Mark chart/table needs: {{CHART:描述}} and {{TABLE:描述}}
7. If a research track flags "data_gaps", acknowledge uncertainty rather than fabricating
Output (JSON):
{
"title": "报告标题(中文)",
"title_en": "Report Title (English)",
"chapters": [
{
"title": "章节标题",
"content": "章节正文Markdown 格式,中文)",
"source_tracks": ["引用的研究轨道名称"],
"charts": ["图表需求"],
"tables": ["表格需求"]
}
],
"executive_summary": "执行摘要中文300-500字",
"executive_summary_en": "Executive Summary (English, 200-400 words)"
}"""

    def __init__(self):
        reasoning_model = settings.model_for_domain("reasoning")
        super().__init__(model=reasoning_model)

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Assemble the writing prompt from ``context`` and return the draft.

        Args:
            context: Must contain ``"research"`` (a dict that may hold
                ``tracks``, ``synthesis_guide``, ``title_zh``, ``title_en``)
                and ``"requirement"``; may contain ``"revision_feedback"``.

        Returns:
            ``{"draft": <parsed JSON returned by call_llm_json>}``.

        Raises:
            KeyError: If ``"research"`` or ``"requirement"`` is missing.
        """
        research = context["research"]
        requirement = context["requirement"]
        revision_feedback = context.get("revision_feedback", "")

        # One "### [domain] [LANG] title" heading per track, followed by the
        # track's findings dumped as JSON.
        track_blocks = []
        for track in research.get("tracks", []):
            language = track.get('native_language', '?').upper()
            heading = (
                f"\n### [{track.get('domain', '?')}] [{language}] "
                f"{track.get('track', '')}\n"
            )
            findings_json = json.dumps(
                track.get("findings", {}), ensure_ascii=False, indent=2
            )
            track_blocks.append(heading + findings_json)
        tracks_text = "".join(track_blocks)

        synthesis_guide = research.get("synthesis_guide", "")
        title_zh = research.get("title_zh", "")
        title_en = research.get("title_en", "")

        # The feedback section only appears on revision passes.
        if revision_feedback:
            feedback_section = f"## 审稿反馈 / Review Feedback{revision_feedback}"
        else:
            feedback_section = ""

        prompt = f"""\
## 原始需求 / Original Requirement
{requirement}
## 报告标题
中文:{title_zh}
English: {title_en}
## 写作指导 / Synthesis Guide
{synthesis_guide}
## 各研究轨道成果 / Research Track Results
(注意:有些轨道是英文原版 [EN],有些是中文原版 [ZH],请综合使用)
{tracks_text}
{feedback_section}
请汇聚以上研究成果,撰写完整的中文报告。输出 JSON。"""
        draft = await self.call_llm_json(prompt, max_tokens=8192)
        return {"draft": draft}

0
app/api/__init__.py Normal file
View File

110
app/api/routes.py Normal file
View File

@@ -0,0 +1,110 @@
"""API routes for report generation."""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.graph.state import ReportState
from app.pipeline.orchestrator import PipelineOrchestrator
from app.pipeline.task import create_report_state
logger = logging.getLogger(__name__)
# Every route declared on this router is served under the /api prefix.
router = APIRouter(prefix="/api")
# In-memory store (swap for DB later)
# Maps report id -> most recent ReportState for that report.
reports: dict[str, ReportState] = {}
# One orchestrator instance shared by all requests handled in this module.
orchestrator = PipelineOrchestrator()
class CreateReportRequest(BaseModel):
    """Request body accepted by POST /api/reports."""

    # Every field below is forwarded unchanged to create_report_state().
    requirement: str
    report_type: str = "行业分析报告"
    extra_data: str = ""
    output_formats: list[str] = ["docx"]
    client_id: str | None = None
class ReportResponse(BaseModel):
    """Summary view of a ReportState, filled in field by field by _to_response()."""

    id: str
    current_node: str
    # Set when the pipeline recorded a failure; None otherwise.
    error: str | None = None
    generated_files: list[str] = []
    node_history: list[dict] = []
    revision_count: int = 0
def _to_response(state: ReportState) -> ReportResponse:
    """Project a ReportState onto the public ReportResponse summary."""
    summary_fields = (
        "id",
        "current_node",
        "error",
        "generated_files",
        "node_history",
        "revision_count",
    )
    # Each response field is copied from the state attribute of the same name.
    return ReportResponse(**{field: getattr(state, field) for field in summary_fields})
@router.post("/reports", response_model=ReportResponse)
async def create_report(req: CreateReportRequest):
    """Create a report state, run the pipeline on it, and return its summary.

    The state is stored before and after the run. Responds with HTTP 500
    (detail = the recorded error) when the finished state carries an error.
    """
    initial = create_report_state(
        requirement=req.requirement,
        report_type=req.report_type,
        extra_data=req.extra_data,
        output_formats=req.output_formats,
        client_id=req.client_id,
    )
    reports[initial.id] = initial

    # Run the full graph (blocking for now, add task queue later)
    finished = await orchestrator.run(initial)
    reports[finished.id] = finished

    if not finished.error:
        return _to_response(finished)
    raise HTTPException(status_code=500, detail=finished.error)
@router.get("/reports/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str):
    """Return the summary of one stored report, or HTTP 404 if it is unknown."""
    state = reports.get(report_id)
    if state:
        return _to_response(state)
    raise HTTPException(status_code=404, detail="Report not found")
@router.get("/reports")
async def list_reports():
    """Return the summary of every report currently held in the store."""
    return list(map(_to_response, reports.values()))
@router.get("/reports/{report_id}/detail")
async def get_report_detail(report_id: str):
    """Return the full stored detail of one report, or HTTP 404 if unknown.

    Research results are reduced to description, status value and duration.
    """
    state = reports.get(report_id)
    if not state:
        raise HTTPException(status_code=404, detail="Report not found")

    research_summaries = []
    for item in state.research_results:
        research_summaries.append({
            "description": item.description,
            "status": item.status.value,
            "duration_ms": item.duration_ms,
        })

    detail = {
        "id": state.id,
        "requirement": state.requirement,
        "decomposition": state.decomposition,
        "research_results": research_summaries,
    }
    # The remaining keys are copied straight from the state.
    for key in ("draft", "review", "generated_files", "node_history"):
        detail[key] = getattr(state, key)
    return detail

57
app/config.py Normal file
View File

@@ -0,0 +1,57 @@
"""Application configuration — multi-model pool by domain."""
from pathlib import Path
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application settings: per-domain model names, API gateway, server, paths.

    Values are loaded by pydantic-settings, with ``.env`` configured as the
    env file (see ``model_config``).
    """

    # --- Model Pool ---
    # Global analysis (English-native): strongest global reasoning
    llm_global: str = "openai/Claude-3.5-Sonnet"
    # China domestic (Chinese-native): deepest Chinese market knowledge
    llm_china: str = "openai/DeepSeek-R1"
    # Synthesis & review: strongest reasoning for cross-domain work
    llm_reasoning: str = "openai/Claude-3.5-Sonnet"
    # Data & fast tasks: cost-effective structured output
    llm_fast: str = "openai/Gemini-2.0-Flash"
    # Translation: high-quality bidirectional EN↔ZH
    llm_translation: str = "openai/Claude-3.5-Sonnet"
    # Fallback default (if domain not specified)
    llm_model: str = "openai/Claude-3.5-Sonnet"

    # API config (Poe as unified gateway)
    llm_api_key: str = ""
    llm_api_base: str = "https://api.poe.com/bot/"

    # Server
    host: str = "0.0.0.0"
    port: int = 4200

    # Paths
    base_dir: Path = Path(__file__).resolve().parent.parent
    templates_dir: Path = base_dir / "templates"
    output_dir: Path = base_dir / "output"

    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}

    def model_for_domain(self, domain: str) -> str:
        """Return the model name configured for a content domain.

        Recognised domains:
            global      — international markets, global competition, tech trends
            china       — Chinese market, domestic policy, local competition
            reasoning   — synthesis, review, strategic recommendations
            fast        — data processing, chart generation, structured output
            translation — high-quality EN↔ZH translation

        Any other value yields the fallback ``llm_model``.
        """
        pool = {
            "global": self.llm_global,
            "china": self.llm_china,
            "reasoning": self.llm_reasoning,
            "fast": self.llm_fast,
            "translation": self.llm_translation,
        }
        if domain in pool:
            return pool[domain]
        return self.llm_model


settings = Settings()

0
app/data/__init__.py Normal file
View File

52
app/data/factory.py Normal file
View File

@@ -0,0 +1,52 @@
"""Data router factory — assembles the router with all available sources."""
from __future__ import annotations
import logging
from .router import DataRouter
from .sources.akshare_source import AKShareSource
from .sources.worldbank_source import WorldBankSource
from .sources.gpt_researcher_source import GPTResearcherSource
logger = logging.getLogger(__name__)
def create_data_router() -> DataRouter:
    """Build a DataRouter and register every available data source on it.

    Priority order (lower = tried first):
        10  AKShare        — free, fast, covers Chinese macro/industry
        20  World Bank     — free, global macro
        90  GPT Researcher — universal fallback (web research)

    Future additions:
        30  巨潮资讯     — free tier, Chinese public companies
        40  天眼查 API   — paid, company data
        50  Choice API   — paid, financial data
        60  Statista API — paid, global industry stats
        70  FRED API     — free, US macro
        80  UN Comtrade  — free, global trade
    """
    router = DataRouter()
    registrations = (
        # --- Free tier (always available) ---
        (10, AKShareSource),
        (20, WorldBankSource),
        # --- Universal fallback ---
        (90, GPTResearcherSource),
    )
    for priority, source_cls in registrations:
        router.register(source_cls(), priority=priority)

    logger.info(f"[data_factory] created router with {len(router.sources)} sources")
    return router
# Singleton
_router: DataRouter | None = None


def get_data_router() -> DataRouter:
    """Return the module-wide DataRouter, creating it on first use."""
    global _router
    if _router is not None:
        return _router
    _router = create_data_router()
    return _router

82
app/data/router.py Normal file
View File

@@ -0,0 +1,82 @@
"""DataRouter — routes data requests to the optimal source.
Architecture:
Agent needs data → DataRouter → best available source → standardized result
Sources are tried in priority order. If a source fails or is unavailable,
the router falls back to the next source.
"""
from __future__ import annotations
import logging
from typing import Any
from .sources.base import DataSource, DataResult
logger = logging.getLogger(__name__)
class DataRouter:
    """Routes data queries to the best available source.

    Sources are registered with a priority (lower = higher priority).
    For each query, the router tries sources in priority order until one
    returns non-empty data.
    """

    def __init__(self):
        # (priority, source) pairs, kept sorted by ascending priority.
        self.sources: list[tuple[int, DataSource]] = []

    def register(self, source: DataSource, priority: int = 100) -> "DataRouter":
        """Add ``source`` at ``priority`` and return ``self`` for chaining."""
        self.sources.append((priority, source))
        # list.sort is stable, so equal priorities keep registration order.
        self.sources.sort(key=lambda x: x[0])
        logger.info(f"[data_router] registered {source.name} (priority={priority})")
        return self

    async def query(
        self,
        query: str,
        data_type: str = "general",
        country: str | None = None,
        **kwargs,
    ) -> DataResult:
        """Query data sources in priority order.

        Args:
            query: Natural language or structured data query
            data_type: One of: macro, industry, company, trade, patent, general
            country: ISO country code (CN, US, etc.) or None for global
            **kwargs: Source-specific parameters

        Returns:
            The first result whose ``data`` is truthy. Otherwise a
            ``DataResult`` with ``source="none"``, ``data=None`` and an
            ``error`` that starts with "All sources failed: " and lists why
            each attempted source was rejected.
        """
        errors: list[str] = []
        for _priority, source in self.sources:
            if not source.supports(data_type, country):
                continue
            try:
                logger.info(f"[data_router] trying {source.name} for '{query[:50]}...'")
                result = await source.fetch(query, data_type=data_type, country=country, **kwargs)
            except Exception as e:
                errors.append(f"{source.name}: {e}")
                logger.warning(f"[data_router] {source.name} failed: {e}")
                continue

            if result.data:
                logger.info(f"[data_router] {source.name} returned {len(str(result.data))} chars")
                return result

            # Sources may report failure through DataResult.error instead of
            # raising; record that reason so it is not lost from the summary.
            reason = result.error or "no data returned"
            errors.append(f"{source.name}: {reason}")
            logger.info(f"[data_router] {source.name} returned no data: {reason}")

        if errors:
            detail = "; ".join(errors)
        else:
            # Nothing was attempted at all; say so instead of an empty list.
            detail = (
                f"no registered source supports data_type={data_type!r}, "
                f"country={country!r}"
            )
        return DataResult(
            source="none",
            data=None,
            error=f"All sources failed: {detail}",
        )

    async def query_multiple(
        self,
        queries: list[dict[str, Any]],
    ) -> list[DataResult]:
        """Run several queries concurrently and return results in input order.

        Each dict in ``queries`` is expanded as keyword arguments to
        :meth:`query`.
        """
        # Local import: asyncio is not imported at the top of this module.
        import asyncio

        return await asyncio.gather(*[
            self.query(**q) for q in queries
        ])

View File

View File

@@ -0,0 +1,94 @@
"""AKShare data source — Chinese macro/industry data via open-source Python library.
Covers: GDP, CPI, PMI, industrial profit, trade balance, and 30+ data categories.
All data returned as Pandas DataFrames, converted to dicts for standardization.
"""
from __future__ import annotations
import logging
from typing import Any
from .base import DataSource, DataResult
logger = logging.getLogger(__name__)
# Map common data requests to AKShare function names
# AKShareSource.fetch() lower-cases the query and looks for these keys as
# substrings; the mapped value is then looked up as an attribute of the
# akshare module and called with no arguments.
AKSHARE_ENDPOINTS = {
    "gdp": "macro_china_gdp",
    "cpi": "macro_china_cpi_monthly",
    "ppi": "macro_china_ppi",
    "pmi": "macro_china_pmi",
    "industrial_profit": "macro_china_industrial_profit",
    "trade_balance": "macro_china_trade_balance",
    "money_supply": "macro_china_money_supply",
    "fdi": "macro_china_fdi",
    "real_estate": "macro_china_real_estate",
    "retail_sales": "macro_china_consumer_goods_retail",
    "fixed_asset": "macro_china_fai",
    "unemployment": "macro_china_urban_unemployment",
    # US macro
    "us_gdp": "macro_usa_gdp_monthly",
    "us_cpi": "macro_usa_cpi_monthly",
    "us_unemployment": "macro_usa_unemployment_rate",
    # Global
    "global_gdp": "macro_global_gdp",
}
class AKShareSource(DataSource):
    """Data source backed by the optional ``akshare`` package."""

    name = "akshare"
    description = "中国宏观经济/行业数据免费开源封装统计局等30+数据源)"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept macro, industry and general queries; ``country`` is ignored."""
        return data_type in ("macro", "industry", "general")

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Call the AKShare function matching ``query`` and return its last rows.

        Args:
            query: Text searched (case-insensitively) for an
                ``AKSHARE_ENDPOINTS`` key.
            data_type: Unused here; filtering happens in :meth:`supports`.
            country: Unused here.
            **kwargs: ``endpoint`` names the AKShare function directly and
                skips query matching; ``limit`` (default 20) is the number of
                trailing rows returned.

        Returns:
            A ``DataResult`` with tabular ``data`` on success, or with only
            ``error`` set on any failure. This method does not raise.
        """
        # Local import: asyncio is not imported at the top of this module.
        import asyncio

        try:
            import akshare as ak
        except ImportError:
            return DataResult(source=self.name, error="akshare not installed (pip install akshare)")

        # Try to match query to a known endpoint
        endpoint_name = kwargs.get("endpoint")
        if not endpoint_name:
            query_lower = query.lower()
            # Longest keys first: otherwise "gdp" would shadow "us_gdp" and
            # "global_gdp", and "cpi"/"unemployment" their "us_" variants.
            for key in sorted(AKSHARE_ENDPOINTS, key=len, reverse=True):
                if key in query_lower:
                    endpoint_name = AKSHARE_ENDPOINTS[key]
                    break
        if not endpoint_name:
            return DataResult(source=self.name, data=None, error=f"No matching AKShare endpoint for: {query}")

        try:
            func = getattr(ak, endpoint_name, None)
            if not func:
                return DataResult(source=self.name, error=f"AKShare function not found: {endpoint_name}")
            logger.info(f"[akshare] calling ak.{endpoint_name}()")
            # The AKShare call is synchronous; run it in a worker thread so
            # it does not stall the event loop (and concurrent queries).
            df = await asyncio.to_thread(func)

            # Convert to dict for serialization
            # Take last N rows for recent data
            limit = kwargs.get("limit", 20)
            recent = df.tail(limit)
            return DataResult(
                source=self.name,
                data={
                    "columns": list(recent.columns),
                    "records": recent.to_dict(orient="records"),
                    "total_rows": len(df),
                    "returned_rows": len(recent),
                },
                metadata={
                    "endpoint": endpoint_name,
                    "description": f"AKShare {endpoint_name}",
                    "format": "tabular",
                },
            )
        except Exception as e:
            # Best-effort source: report the failure instead of raising.
            return DataResult(source=self.name, error=f"AKShare call failed: {e}")

34
app/data/sources/base.py Normal file
View File

@@ -0,0 +1,34 @@
"""Base class for data sources."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, Field
class DataResult(BaseModel):
    """Standardized result from any data source."""

    # Name of the producing source.
    source: str = ""
    # Payload; left as None when nothing was obtained.
    data: Any = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    # metadata includes: unit, time_range, update_date, confidence, etc.
    # Failure reason; None when no error was recorded.
    error: str | None = None
    # presumably marks results served from a cache — TODO confirm where it is set
    cached: bool = False
class DataSource(ABC):
    """Abstract data source."""

    # Short identifier; subclasses override it.
    name: str = "base"
    description: str = ""

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Return True if this source can handle this data type / country."""
        # Default: accept everything; subclasses narrow this.
        return True

    @abstractmethod
    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch data for ``query`` and return it wrapped in a DataResult.

        ``data_type``, ``country`` and ``**kwargs`` are keyword-only hints
        whose interpretation is left to each concrete source.
        """
        ...

View File

@@ -0,0 +1,61 @@
"""GPT Researcher MCP — deep web research as fallback for any industry.
This is the universal fallback: when structured data sources don't have
data for a niche/cold industry, deep web research fills the gap.
Requires GPT Researcher MCP server to be running (already configured in ~/.claude.json).
For direct API use, we call the MCP tools via the subprocess approach.
"""
from __future__ import annotations
import asyncio
import json
import logging
import subprocess
from typing import Any
from .base import DataSource, DataResult
logger = logging.getLogger(__name__)
class GPTResearcherSource(DataSource):
    """Fallback source that returns a research-request placeholder.

    No research is performed here: ``fetch`` only packages the query so a
    later integration step can hand it to the GPT Researcher MCP tools.
    """

    name = "gpt_researcher"
    description = "Deep web research — universal fallback for any industry/topic"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept every data type and country."""
        # Supports everything — this is the universal fallback
        return True

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Return a placeholder describing the research request.

        ``kwargs["mode"]`` selects "quick" (default) or "deep"; it is echoed
        back in both ``data`` and ``metadata``.
        """
        mode = kwargs.get("mode", "quick")  # "quick" or "deep"

        # GPT Researcher is available as MCP tools in Claude Code.
        # For standalone use, we need to call it via its API.
        # The MCP server runs at a local port — check if available.
        # For now, provide a structured placeholder that agents can use
        # to request deep research. The actual MCP call happens at the
        # agent level when integrated into the pipeline.
        note = (
            "GPT Researcher MCP is available for deep web research. "
            "Call via MCP tools: deep_research() or quick_search(). "
            "This source returns research-ready queries for MCP integration."
        )
        payload = {
            "query": query,
            "mode": mode,
            "status": "ready",
            "note": note,
        }
        request_meta = {
            "type": "mcp_research_request",
            "mode": mode,
            "data_type": data_type,
            "country": country,
        }
        return DataResult(source=self.name, data=payload, metadata=request_meta)

View File

@@ -0,0 +1,104 @@
"""World Bank Open Data — global macro indicators, 217 economies, free API.
API: https://api.worldbank.org/v2/
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from .base import DataSource, DataResult
logger = logging.getLogger(__name__)
# Root of the World Bank v2 REST API; WorldBankSource.fetch() appends
# /country/<code>/indicator/<code> to it.
BASE_URL = "https://api.worldbank.org/v2"
# Common indicators for consulting reports
# WorldBankSource.fetch() lower-cases the query and looks for these keys as
# substrings; values are the indicator codes placed in the request URL.
INDICATORS = {
    "gdp": "NY.GDP.MKTP.CD",  # GDP (current US$)
    "gdp_growth": "NY.GDP.MKTP.KD.ZG",  # GDP growth (annual %)
    "gdp_per_capita": "NY.GDP.PCAP.CD",  # GDP per capita
    "population": "SP.POP.TOTL",  # Total population
    "inflation": "FP.CPI.TOTL.ZG",  # Inflation (CPI %)
    "trade_pct_gdp": "NE.TRD.GNFS.ZS",  # Trade (% of GDP)
    "fdi_net": "BX.KLT.DINV.CD.WD",  # FDI net inflows
    "unemployment": "SL.UEM.TOTL.ZS",  # Unemployment (%)
    "exports": "NE.EXP.GNFS.CD",  # Exports
    "imports": "NE.IMP.GNFS.CD",  # Imports
    "r_and_d": "GB.XPD.RSDV.GD.ZS",  # R&D expenditure (% GDP)
    "high_tech_exports": "TX.VAL.TECH.MF.ZS",  # High-tech exports (% manufactured)
}
class WorldBankSource(DataSource):
    """Data source that reads indicator time series from the World Bank API."""

    name = "worldbank"
    description = "World Bank Open Data — 1600+ indicators, 217 economies, free"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept macro and general queries; ``country`` is ignored."""
        return data_type in ("macro", "general")

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch one indicator series for one country (or the world).

        Args:
            query: Text searched (case-insensitively) for an ``INDICATORS``
                key; GDP is used when nothing matches.
            data_type: Unused here; filtering happens in :meth:`supports`.
            country: Country code placed in the URL; "WLD" when None.
            **kwargs: ``indicator`` gives the indicator code directly and
                skips query matching; ``per_page`` (default 20) is passed to
                the API.

        Returns:
            A ``DataResult`` with timeseries ``data`` on success, or with
            ``data=None`` and ``error`` set when the request fails or yields
            no non-null values. This method does not raise.
        """
        indicator_code = kwargs.get("indicator")
        if not indicator_code:
            query_lower = query.lower()
            # Longest keys first: otherwise "gdp" would shadow "gdp_growth"
            # and "gdp_per_capita", and "exports" would shadow
            # "high_tech_exports".
            for key in sorted(INDICATORS, key=len, reverse=True):
                if key in query_lower:
                    indicator_code = INDICATORS[key]
                    break
        if not indicator_code:
            # Default to GDP
            indicator_code = INDICATORS["gdp"]

        country_code = country or "WLD"  # WLD = World
        per_page = kwargs.get("per_page", 20)
        url = f"{BASE_URL}/country/{country_code}/indicator/{indicator_code}"
        params = {
            "format": "json",
            "per_page": per_page,
        }
        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(url, params=params)
                resp.raise_for_status()
                data = resp.json()

            # A usable payload is a two-element list: [paging info, records].
            if not isinstance(data, list) or len(data) < 2:
                return DataResult(source=self.name, data=None, error="No data returned")
            metadata_raw = data[0] or {}
            # The records element can be null when nothing matches.
            records = data[1] or []

            # Parse into clean format
            clean_records = [
                {
                    "year": r["date"],
                    "value": r["value"],
                    "country": r["country"]["value"],
                    "indicator": r["indicator"]["value"],
                }
                for r in records
                if r.get("value") is not None
            ]
            if not clean_records:
                # Report "no data" explicitly so callers can fall back to
                # another source instead of receiving an empty series.
                return DataResult(
                    source=self.name,
                    data=None,
                    error=f"No non-null values for {indicator_code} / {country_code}",
                )

            return DataResult(
                source=self.name,
                data={
                    "indicator": indicator_code,
                    "country": country_code,
                    "records": clean_records,
                },
                metadata={
                    "total": metadata_raw.get("total", 0),
                    "indicator_name": clean_records[0]["indicator"],
                    "format": "timeseries",
                },
            )
        except Exception as e:
            # Best-effort source: report the failure instead of raising.
            return DataResult(source=self.name, error=f"World Bank API failed: {e}")

0
app/graph/__init__.py Normal file
View File

108
app/graph/builder.py Normal file
View File

@@ -0,0 +1,108 @@
"""Graph builder — assembles nodes into an executable report generation graph.
No LangGraph dependency. Pure asyncio with a simple node-runner pattern.
v2: domain-aware, bilingual (translate node added)
"""
from __future__ import annotations
import logging
from typing import Callable, Awaitable
from app.middleware.chain import MiddlewareChain
from .state import ReportState, NodeStatus
from .nodes import (
DecomposeNode,
ParallelResearchNode,
WriteNode,
TranslateNode,
DataNode,
ReviewNode,
FormatNode,
)
logger = logging.getLogger(__name__)
NodeFn = Callable[[ReportState], Awaitable[ReportState]]
class ReportGraph:
    """Executable graph for report generation.

    Graph structure (v2):
        decompose → parallel_research → write → translate → data → review →
            ├─ pass   → format → END
            └─ revise → write → translate → data → review → ...
    """

    def __init__(self, middleware: MiddlewareChain | None = None):
        self.middleware = middleware
        self.decompose = DecomposeNode()
        self.parallel_research = ParallelResearchNode()
        self.write = WriteNode()
        self.translate = TranslateNode()
        self.data = DataNode()
        self.review = ReviewNode()
        self.format = FormatNode()

    async def _run_node(self, name: str, node: NodeFn, state: ReportState) -> ReportState:
        """Run a single node with error handling.

        On an exception the failure is recorded on ``state`` (``error`` and a
        FAILED history entry), logged, and the exception is re-raised.
        """
        try:
            logger.info(f"[graph] entering node: {name}")
            state = await node(state)
            logger.info(f"[graph] completed node: {name}")
        except Exception as e:
            state.error = f"Node '{name}' failed: {e}"
            state.log_node(name, NodeStatus.FAILED, str(e))
            logger.exception(f"[graph] node '{name}' failed")
            raise
        return state

    async def run(self, state: ReportState) -> ReportState:
        """Execute the full graph and return the final state.

        The pipeline stops at the first node that raises or that newly sets
        ``state.error``. Failures never propagate to the caller: they are
        reported through ``state.error``. ``middleware.after`` runs whether
        or not the pipeline succeeded.
        """
        # --- Middleware: before ---
        if self.middleware:
            state = await self.middleware.before(state)

        async def step(name: str, node: NodeFn) -> None:
            """Run one node and abort the pipeline if it flagged an error."""
            nonlocal state
            had_error = bool(state.error)
            state = await self._run_node(name, node, state)
            # A node may signal failure by setting state.error and returning
            # normally; later nodes must not run on that state.
            if state.error and not had_error:
                raise RuntimeError(state.error)

        try:
            # 1. Decompose requirement into domain-tagged parallel tracks
            await step("decompose", self.decompose)
            # 2. Run parallel research (each track uses domain-optimal model)
            await step("parallel_research", self.parallel_research)
            # 3-6. Write → Translate → Data → Review (with revision loop)
            while True:
                await step("write", self.write)
                await step("translate", self.translate)
                await step("data", self.data)
                await step("review", self.review)
                verdict = state.review.get("verdict", "pass")
                if verdict == "pass" or state.revision_count >= state.max_revisions:
                    if verdict != "pass":
                        logger.warning(
                            f"[graph] forcing pass after {state.revision_count} revisions"
                        )
                    break
                # Revise — loop back to write
                state.revision_count += 1
                logger.info(
                    f"[graph] revision {state.revision_count}/{state.max_revisions}"
                )
            # 7. Format bilingual output files
            await step("format", self.format)
        except Exception as e:
            # Node failures were already recorded by _run_node. Anything
            # raised between nodes would otherwise vanish, so record it too.
            if not state.error:
                state.error = f"Pipeline failed: {e}"
                logger.exception("[graph] pipeline failed outside a node")

        # --- Middleware: after ---
        if self.middleware:
            state = await self.middleware.after(state)
        return state

393
app/graph/nodes.py Normal file
View File

@@ -0,0 +1,393 @@
"""Graph nodes — each node is an async function: ReportState → ReportState.
Node layout (v2 — domain-aware, bilingual):
START
[decompose] — Lead Agent 分解为并行研究轨道,每轨标注 domain + language
[parallel_research] — N 个子 Agent 并行,每个用最适合该领域的模型
│ global tracks → Claude/GPT (English)
│ china tracks → DeepSeek/Qwen (Chinese)
[write] — Writer 汇聚 → 生成主语言版本
[translate] — 高质量翻译 → 生成另一语言版本
[data] — Data Agent 生成图表/表格
[review] — Reviewer 审查(双语)
│ ├─ pass → [format]
│ └─ revise → [write]
[format] — 输出双语版本文件
END
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from typing import Any
from app.agents.base import BaseAgent
from app.agents.researcher import ResearcherAgent
from app.agents.writer import WriterAgent
from app.agents.data_agent import DataAgent
from app.agents.reviewer import ReviewerAgent
from app.agents.formatter import FormatterAgent
from app.config import settings
from .state import ReportState, SubtaskResult, NodeStatus, ContentDomain
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Node: decompose — Lead Agent decomposes into domain-tagged parallel tracks
# ---------------------------------------------------------------------------
class DecomposeNode:
    """Analyzes requirement and decomposes into domain-aware research tracks."""

    # System prompt sent with every decomposition request.
    SYSTEM_PROMPT = """\
You are a senior consulting partner planning a global industry report.
Your job is to decompose the client's requirement into 2-6 parallel research tracks.
CRITICAL: Each track must be tagged with a content domain and native language:
- domain: "global" → international markets, global competition, technology trends, overseas benchmarks
→ native_language: "en" (English sources are 10-100x richer for global analysis)
- domain: "china" → Chinese domestic market, government policy, local competitors, China-specific data
→ native_language: "zh" (Chinese sources are authoritative for domestic analysis)
The PRINCIPLE: whichever language has the richest professional literature for that topic
should be the native language. The other language version will be translated later.
Output (JSON):
{
"title_en": "English report title",
"title_zh": "中文报告标题",
"report_type": "report type",
"tracks": [
{
"title": "track title (in native language)",
"domain": "global|china",
"native_language": "en|zh",
"focus": "research focus description",
"prompt": "detailed research instructions (MUST be in the native_language)",
"data_needs": ["required data/charts"]
}
],
"synthesis_guide": "How to merge all tracks into a coherent report (bilingual structure notes)",
"methodology": "Analysis methodology"
}"""

    def __init__(self):
        lead = BaseAgent()
        lead.name = "lead"
        lead.model = settings.model_for_domain("reasoning")
        self.agent = lead

    async def __call__(self, state: ReportState) -> ReportState:
        """Ask the lead agent for a track plan and store it on the state.

        Writes ``state.decomposition`` and logs RUNNING / COMPLETED history
        entries for the "decompose" node.
        """
        node_name = "decompose"
        state.current_node = node_name
        state.log_node(node_name, NodeStatus.RUNNING)

        # Empty optional inputs are shown to the model as "(none)".
        extra_data = state.extra_data or "(none)"
        client_context = state.client_context or "(none)"
        prompt = f"""\
## Client requirement
{state.requirement}
## Report type
{state.report_type}
## Additional data
{extra_data}
## Client context
{client_context}
Decompose into parallel research tracks with domain and language tags. Output JSON."""

        plan = await self.agent.call_llm_json(prompt, system=self.SYSTEM_PROMPT)
        state.decomposition = plan
        track_count = len(plan.get('tracks', []))
        state.log_node(node_name, NodeStatus.COMPLETED, f"{track_count} tracks")
        return state
# ---------------------------------------------------------------------------
# Node: parallel_research — domain-aware parallel execution
# ---------------------------------------------------------------------------
class ParallelResearchNode:
    """Runs research subtasks in parallel, each using the optimal model for its domain."""

    # Upper bound on research tracks running at the same time.
    MAX_CONCURRENT = 5

    async def _run_one(self, track: dict[str, Any]) -> SubtaskResult:
        """Research one track and return its result; never raises.

        A failure is captured on the returned ``SubtaskResult`` (status
        FAILED, ``error`` set) instead of propagating.
        """
        # Look the domain up by value; unknown or missing values fall back
        # to GLOBAL. This works whether or not ContentDomain mixes in str.
        try:
            domain = ContentDomain(track.get("domain", "global"))
        except (ValueError, TypeError):
            domain = ContentDomain.GLOBAL
        native_lang = track.get("native_language", "en")
        result = SubtaskResult(
            description=track.get("title", ""),
            domain=domain,
            native_language=native_lang,
        )
        result.status = NodeStatus.RUNNING
        result.started_at = datetime.now()
        try:
            # Select model based on domain
            model = settings.model_for_domain(domain.value)
            agent = ResearcherAgent(model=model, language=native_lang)
            logger.info(
                f"[parallel_research] track '{track.get('title')}' "
                f"→ domain={domain.value}, lang={native_lang}, model={model}"
            )
            research = await agent.run({
                "requirement": track["prompt"],
                "report_type": track.get("focus", ""),
                "extra_data": "",
            })
            result.content = research.get("research", {})
            result.status = NodeStatus.COMPLETED
        except Exception as e:
            result.error = str(e)
            result.status = NodeStatus.FAILED
            logger.exception(f"Research track '{track.get('title')}' failed")
        finally:
            result.completed_at = datetime.now()
        return result

    async def __call__(self, state: ReportState) -> ReportState:
        """Run every decomposed track, at most ``MAX_CONCURRENT`` at a time.

        Writes ``state.research_results``. Sets ``state.error`` and logs the
        node as FAILED when there are no tracks or when no track succeeded.
        """
        state.current_node = "parallel_research"
        state.log_node("parallel_research", NodeStatus.RUNNING)
        tracks = state.decomposition.get("tracks", [])
        if not tracks:
            state.log_node("parallel_research", NodeStatus.FAILED, "no tracks")
            state.error = "Decomposition produced no research tracks"
            return state

        semaphore = asyncio.Semaphore(self.MAX_CONCURRENT)

        async def bounded(track):
            async with semaphore:
                return await self._run_one(track)

        logger.info(f"[parallel_research] launching {len(tracks)} tracks concurrently")
        results = await asyncio.gather(*[bounded(t) for t in tracks])
        state.research_results = list(results)

        succeeded = sum(1 for r in results if r.status == NodeStatus.COMPLETED)
        domains: dict[str, list[str]] = {}
        for r in results:
            domains.setdefault(r.domain.value, []).append(r.native_language)

        if succeeded == 0:
            # With no findings at all the writer would have nothing to
            # ground the report on, so treat this as a node failure.
            state.log_node("parallel_research", NodeStatus.FAILED,
                           f"0/{len(tracks)} ok, domains={domains}")
            state.error = f"All {len(tracks)} research tracks failed"
            return state

        state.log_node("parallel_research", NodeStatus.COMPLETED,
                       f"{succeeded}/{len(tracks)} ok, domains={domains}")
        return state
# ---------------------------------------------------------------------------
# Node: write — synthesize research into primary-language draft
# ---------------------------------------------------------------------------
class WriteNode:
    """Synthesizes completed research tracks into the primary-language draft."""

    def __init__(self):
        self.agent = WriterAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the writer on all completed tracks and store the draft.

        On revision passes (``revision_count > 0`` with a review present) the
        reviewer's issues are forwarded to the writer as feedback. Writes
        ``state.draft``.
        """
        state.current_node = "write"
        state.log_node("write", NodeStatus.RUNNING)

        # Only tracks that finished successfully contribute findings.
        research_merged = [
            {
                "track": r.description,
                "domain": r.domain.value,
                "native_language": r.native_language,
                "findings": r.content,
            }
            for r in state.research_results
            if r.status == NodeStatus.COMPLETED
        ]
        synthesis_guide = state.decomposition.get("synthesis_guide", "")

        review_feedback = ""
        if state.revision_count > 0 and state.review:
            lines = [f"\n\n## Review feedback (revision {state.revision_count})\n"]
            for issue in state.review.get("issues", []):
                severity = issue.get("severity") or "?"
                description = issue.get("description") or ""
                suggestion = issue.get("suggestion")
                line = f"- [{severity}] {description}"
                if suggestion:
                    # Keep the problem and its proposed fix visibly apart.
                    line += f" → {suggestion}"
                lines.append(line + "\n")
            review_feedback = "".join(lines)

        result = await self.agent.run({
            "requirement": state.requirement,
            "research": {
                "title_en": state.decomposition.get("title_en", ""),
                "title_zh": state.decomposition.get("title_zh", ""),
                "methodology": state.decomposition.get("methodology", ""),
                "tracks": research_merged,
                "synthesis_guide": synthesis_guide,
            },
            "revision_feedback": review_feedback,
        })
        state.draft = result.get("draft", {})
        state.log_node("write", NodeStatus.COMPLETED)
        return state
# ---------------------------------------------------------------------------
# Node: translate — produce the other language version
# ---------------------------------------------------------------------------
class TranslateNode:
    """Translates the draft into the other language version."""

    def __init__(self):
        self.agent = BaseAgent()
        self.agent.name = "translator"
        self.agent.model = settings.model_for_domain("translation")

    async def __call__(self, state: ReportState) -> ReportState:
        """Translate ``state.draft`` and store it in ``state.draft_translated``.

        Skipped (logged as COMPLETED, "skipped") when there is no draft or
        when "en" is not among ``state.output_languages``. The translation
        direction is chosen from the draft title: a title containing CJK
        ideographs means Chinese → English, otherwise English → Chinese.
        """
        state.current_node = "translate"
        state.log_node("translate", NodeStatus.RUNNING)
        if not state.draft or "en" not in state.output_languages:
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        draft_json = json.dumps(state.draft, ensure_ascii=False, indent=2)

        # Detect primary language of draft
        # `or ""` guards against a draft whose title is present but null.
        title = state.draft.get("title") or ""
        is_chinese_primary = any('\u4e00' <= c <= '\u9fff' for c in title)
        if is_chinese_primary:
            target_lang = "English"
            source_lang = "Chinese"
        else:
            target_lang = "Chinese (Simplified)"
            source_lang = "English"

        # "-to-" keeps the two language names apart in the rendered prompt.
        system = f"""\
You are a world-class {source_lang}-to-{target_lang} translator specializing in
consulting and business reports.
Translation principles:
1. ACCURACY over fluency — every data point, percentage, and proper noun must be correct
2. Professional terminology — use standard {target_lang} business/industry terms
3. Preserve structure — keep the exact same JSON structure, only translate text values
4. Cultural adaptation — adjust phrasing for the target audience (not word-for-word)
5. Keep {{{{CHART:...}}}} and {{{{TABLE:...}}}} markers, translate their descriptions
Output the translated JSON with the exact same structure."""
        prompt = f"""\
Translate this consulting report from {source_lang} to {target_lang}.
{draft_json}
Output the translated JSON."""

        translated = await self.agent.call_llm_json(prompt, system=system, max_tokens=8192)
        state.draft_translated = translated
        state.log_node("translate", NodeStatus.COMPLETED,
                       f"{source_lang} → {target_lang}")
        return state
# ---------------------------------------------------------------------------
# Node: data — generate charts and tables
# ---------------------------------------------------------------------------
class DataNode:
    """Graph node that delegates data-asset generation to a ``DataAgent``."""

    def __init__(self):
        self.agent = DataAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the data agent on the draft and store its ``data_assets``."""
        node = "data"
        state.current_node = node
        state.log_node(node, NodeStatus.RUNNING)

        payload = {"draft": state.draft, "extra_data": state.extra_data}
        outcome = await self.agent.run(payload)

        # A missing key falls back to an empty mapping.
        state.data_assets = outcome.get("data_assets", {})
        state.log_node(node, NodeStatus.COMPLETED)
        return state
# ---------------------------------------------------------------------------
# Node: review — bilingual quality check
# ---------------------------------------------------------------------------
class ReviewNode:
    """Graph node that delegates quality review to a ``ReviewerAgent``."""

    def __init__(self):
        self.agent = ReviewerAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Review both draft versions and record the result on the state."""
        node = "review"
        state.current_node = node
        state.log_node(node, NodeStatus.RUNNING)

        payload = {
            "draft": state.draft,
            "draft_translated": state.draft_translated,
            # The decomposition is what gets passed under the "research" key.
            "research": state.decomposition,
        }
        outcome = await self.agent.run(payload)
        state.review = outcome.get("review", {})

        verdict = state.review.get("verdict", "?")
        state.log_node(node, NodeStatus.COMPLETED, f"verdict={verdict}")
        return state
# ---------------------------------------------------------------------------
# Node: format — render bilingual output files
# ---------------------------------------------------------------------------
class FormatNode:
    """Graph node that renders the draft(s) to files via a ``FormatterAgent``."""

    def __init__(self):
        self.agent = FormatterAgent()

    async def _render(
        self, state: ReportState, draft: dict, subdir: str
    ) -> list:
        """Render one draft into ``<output_dir>/<state.id>/<subdir>``."""
        outcome = await self.agent.run({
            "draft": draft,
            "data_assets": state.data_assets,
            "output_dir": str(settings.output_dir / state.id / subdir),
            "output_formats": state.output_formats,
        })
        return outcome.get("generated_files", [])

    async def __call__(self, state: ReportState) -> ReportState:
        """Render the primary draft, then the translated one if present."""
        node = "format"
        state.current_node = node
        state.log_node(node, NodeStatus.RUNNING)

        produced = []
        # The primary version is always rendered.
        produced.extend(await self._render(state, state.draft, "primary"))
        # The translated version is rendered only when one exists.
        if state.draft_translated:
            produced.extend(
                await self._render(state, state.draft_translated, "translated")
            )

        state.generated_files = produced
        state.log_node(node, NodeStatus.COMPLETED, f"{len(produced)} files")
        return state

104
app/graph/state.py Normal file
View File

@@ -0,0 +1,104 @@
"""Report generation graph state — the shared context that flows through all nodes."""
from __future__ import annotations
import uuid
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class NodeStatus(str, Enum):
    """Lifecycle status values; members are plain strings via the ``str`` mixin."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class ContentDomain(str, Enum):
    """Content domain — determines which model and language to use."""

    # International markets, global trends -> English-native
    GLOBAL = "global"
    # Chinese market, domestic policy -> Chinese-native
    CHINA = "china"
    # Synthesis, review, strategy -> strongest reasoning
    REASONING = "reasoning"
    # Data processing, charts -> cost-effective
    FAST = "fast"
    # EN<->ZH translation
    TRANSLATION = "translation"
class SubtaskResult(BaseModel):
    """Outcome record for one parallel research subtask."""

    # Short random identifier (first 8 hex characters of a UUID4).
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    description: str = ""
    domain: ContentDomain = ContentDomain.GLOBAL
    # "en" or "zh" — the original writing language.
    native_language: str = "en"
    status: NodeStatus = NodeStatus.PENDING
    content: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_ms(self) -> int | None:
        """Elapsed time in whole milliseconds, or ``None`` if either timestamp is unset."""
        if not (self.started_at and self.completed_at):
            return None
        elapsed = self.completed_at - self.started_at
        return int(elapsed.total_seconds() * 1000)
class ReportState(BaseModel):
    """Full state for a report generation run.

    This is the single source of truth that all graph nodes read from and
    write to.
    """

    # --- Identity ---
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
    created_at: datetime = Field(default_factory=datetime.now)

    # --- Input (set once at start) ---
    requirement: str = ""
    report_type: str = "行业分析报告"
    extra_data: str = ""
    output_formats: list[str] = Field(default=["docx"])
    # Produce both language versions by default.
    output_languages: list[str] = Field(default=["zh", "en"])
    template_name: str | None = None
    # For multi-tenant isolation.
    client_id: str | None = None

    # --- Middleware injections ---
    # Injected by ClientContextMiddleware.
    client_context: str = ""
    # Injected by MemoryMiddleware.
    memory_facts: list[str] = Field(default_factory=list)
    # Managed by TokenBudgetMiddleware.
    token_budget: int = 120000

    # --- Lead Agent output ---
    # e.g. {"tracks": [{"title": "<track title>", "prompt": "<research prompt>"}, ...]}
    decomposition: dict[str, Any] = Field(default_factory=dict)

    # --- Parallel research results ---
    research_results: list[SubtaskResult] = Field(default_factory=list)

    # --- Writer output ---
    # Primary language version.
    draft: dict[str, Any] = Field(default_factory=dict)
    # Translated version.
    draft_translated: dict[str, Any] = Field(default_factory=dict)

    # --- Data Agent output ---
    data_assets: dict[str, Any] = Field(default_factory=dict)

    # --- Reviewer output ---
    review: dict[str, Any] = Field(default_factory=dict)
    revision_count: int = 0
    max_revisions: int = 2

    # --- Formatter output ---
    generated_files: list[str] = Field(default_factory=list)

    # --- Execution tracking ---
    current_node: str = ""
    node_history: list[dict[str, Any]] = Field(default_factory=list)
    error: str | None = None

    def log_node(self, node_name: str, status: NodeStatus, detail: str = ""):
        """Append a timestamped status entry for ``node_name`` to ``node_history``."""
        entry = {
            "node": node_name,
            "status": status.value,
            "detail": detail,
            "timestamp": datetime.now().isoformat(),
        }
        self.node_history.append(entry)

48
app/main.py Normal file
View File

@@ -0,0 +1,48 @@
"""FastAPI application entry point."""
import logging
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from app.api.routes import router
from app.config import settings
# Root logger configuration for the whole service.
logging.basicConfig(
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    level=logging.INFO,
)

app = FastAPI(
    title="咨询报告 AI 生成系统",
    description="Multi-agent pipeline for generating consulting reports",
    version="0.1.0",
)
app.include_router(router)

# Serve generated files: the output directory must exist before it is mounted.
settings.output_dir.mkdir(parents=True, exist_ok=True)
app.mount(
    "/files",
    StaticFiles(directory=str(settings.output_dir)),
    name="files",
)
@app.get("/")
async def root():
    """Return static service metadata for the index route."""
    service_info = {
        "name": "咨询报告 AI 生成系统",
        "version": "0.1.0",
        "status": "running",
        "docs": "/docs",
    }
    return service_info
if __name__ == "__main__":
    # Script entry point: launch uvicorn with auto-reload on the configured
    # host and port. uvicorn is imported lazily so importing this module does
    # not require it.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        reload=True,
        host=settings.host,
        port=settings.port,
    )

0
app/memory/__init__.py Normal file
View File

114
app/memory/store.py Normal file
View File

@@ -0,0 +1,114 @@
"""File-based memory store — persistent facts with confidence ranking."""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# Directory holding the memory JSON files: a "memory" folder three parent
# levels above this module's resolved location.
MEMORY_DIR = Path(__file__).resolve().parent.parent.parent / "memory"
class Fact(BaseModel):
    """A single remembered statement with its category and confidence."""

    # Short random identifier (first 8 hex characters of a UUID4).
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    content: str
    # One of: preference | knowledge | context | behavior | goal
    category: str = "context"
    confidence: float = 0.7
    # Which report/session created this fact.
    source: str = ""
    # ISO-8601 timestamp taken from the local clock at creation time.
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
class MemoryFile(BaseModel):
    """Contents of one memory file: one file per client, or the global one."""

    # "global" identifies the system-wide memory file.
    client_id: str = "global"
    preferences: dict[str, str] = Field(default_factory=dict)
    facts: list[Fact] = Field(default_factory=list)
class MemoryStore:
    """Read/write persistent memory as JSON files.

    Storage layout (one file per client id, named after the sanitized id):
        memory/
        ├── global.json       — system-wide facts
        └── <client_id>.json  — per-client facts
    """

    def __init__(self):
        MEMORY_DIR.mkdir(parents=True, exist_ok=True)

    def _path(self, client_id: str = "global") -> Path:
        """Return the JSON file path for ``client_id`` inside ``MEMORY_DIR``.

        Path separators (both "/" and "\\"), NUL bytes and ".." sequences are
        replaced with "_" so a client id can never address a file outside
        the memory directory.
        """
        safe_name = client_id
        for forbidden in ("/", "\\", "\0"):
            safe_name = safe_name.replace(forbidden, "_")
        safe_name = safe_name.replace("..", "_")
        return MEMORY_DIR / f"{safe_name}.json"

    def load(self, client_id: str = "global") -> MemoryFile:
        """Load the memory file for ``client_id``.

        Returns an empty ``MemoryFile`` when the file does not exist or
        cannot be read/parsed (the failure is logged, not raised).
        """
        path = self._path(client_id)
        if not path.exists():
            return MemoryFile(client_id=client_id)
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return MemoryFile(**data)
        except Exception as e:
            # Deliberate best-effort: unreadable memory must not block a run.
            logger.warning("[memory] failed to load %s: %s", path, e)
            return MemoryFile(client_id=client_id)

    def save(self, mem: MemoryFile):
        """Write ``mem`` as UTF-8 JSON to the file derived from its client id."""
        path = self._path(mem.client_id)
        path.write_text(
            json.dumps(mem.model_dump(), ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        logger.info("[memory] saved %d facts to %s", len(mem.facts), path)

    def add_fact(
        self,
        content: str,
        client_id: str = "global",
        category: str = "context",
        confidence: float = 0.7,
        source: str = "",
    ) -> Fact:
        """Add a fact, or return the existing one with the same content.

        Content is compared after stripping and lower-casing. When a
        duplicate exists, its confidence is raised (and persisted) only if
        the new confidence is higher; its other fields are left untouched.
        """
        mem = self.load(client_id)

        # Deduplicate by normalized content.
        normalized = content.strip().lower()
        for existing in mem.facts:
            if existing.content.strip().lower() == normalized:
                if confidence > existing.confidence:
                    existing.confidence = confidence
                    self.save(mem)
                return existing

        fact = Fact(
            content=content,
            category=category,
            confidence=confidence,
            source=source,
        )
        mem.facts.append(fact)
        self.save(mem)
        return fact

    def get_top_facts(
        self, client_id: str = "global", limit: int = 15
    ) -> list[str]:
        """Get top N facts sorted by confidence, formatted for prompt injection."""
        mem = self.load(client_id)
        ranked = sorted(mem.facts, key=lambda f: f.confidence, reverse=True)
        return [f.content for f in ranked[:limit]]

    def set_preference(self, key: str, value: str, client_id: str = "global"):
        """Set one preference for ``client_id`` and persist the file."""
        mem = self.load(client_id)
        mem.preferences[key] = value
        self.save(mem)

    def get_preferences(self, client_id: str = "global") -> dict[str, str]:
        """Return the stored preferences for ``client_id``."""
        return self.load(client_id).preferences

View File

22
app/middleware/base.py Normal file
View File

@@ -0,0 +1,22 @@
"""Middleware base class — before/after hooks around graph execution."""
from __future__ import annotations
from abc import ABC, abstractmethod
from app.graph.state import ReportState
class Middleware(ABC):
    """Base middleware with pass-through hooks.

    Subclasses override ``before()`` and/or ``after()``; both defaults
    return the state they were given.
    """

    name: str = "base"
    enabled: bool = True

    async def before(self, state: ReportState) -> ReportState:
        """Hook called before graph execution; the default is a no-op."""
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Hook called after graph execution; the default is a no-op."""
        return state

35
app/middleware/chain.py Normal file
View File

@@ -0,0 +1,35 @@
"""Middleware chain — runs ordered middlewares before/after graph execution."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
class MiddlewareChain:
    """Ordered list of middlewares. before() runs in order, after() runs in reverse."""

    def __init__(self, middlewares: list[Middleware] | None = None):
        # Copy the incoming list so add() never mutates the caller's object.
        self.middlewares = list(middlewares) if middlewares else []

    def add(self, mw: Middleware) -> "MiddlewareChain":
        """Append ``mw`` to the chain and return the chain for fluent use."""
        self.middlewares.append(mw)
        return self

    async def before(self, state: ReportState) -> ReportState:
        """Run each enabled middleware's ``before()`` in registration order."""
        for mw in self.middlewares:
            if mw.enabled:
                logger.debug("[middleware:before] %s", mw.name)
                state = await mw.before(state)
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Run each enabled middleware's ``after()`` in reverse registration order."""
        for mw in reversed(self.middlewares):
            if mw.enabled:
                logger.debug("[middleware:after] %s", mw.name)
                state = await mw.after(state)
        return state

View File

@@ -0,0 +1,56 @@
"""ClientContext middleware — injects client/project background into state."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
# Client profiles are stored as JSON files in a "clients" folder three parent
# levels above this module's resolved location.
CLIENTS_DIR = Path(__file__).resolve().parent.parent.parent / "clients"
class ClientContextMiddleware(Middleware):
    """Loads client profile and injects background context into state.

    Client profiles are JSON files in the `clients/` directory:
        clients/<client_id>.json
        {
            "name": "<client name>",
            "industry": "<industry>",
            "preferences": "<free-text report preferences>",
            "previous_reports": ["report_001", "report_002"]
        }

    Only ``name``, ``industry`` and ``preferences`` are read here.
    """

    name = "client_context"

    async def before(self, state: ReportState) -> ReportState:
        """Populate ``state.client_context`` from the client's profile file.

        The state is returned unchanged when no client id is set, when the
        id resolves outside the clients directory, when no profile exists,
        or when the profile cannot be read/parsed (logged, not raised).
        """
        if not state.client_id:
            return state

        # client_id is caller-supplied: resolve the candidate path and refuse
        # anything that escapes the clients directory (e.g. "../../secret").
        clients_root = CLIENTS_DIR.resolve()
        profile_path = (clients_root / f"{state.client_id}.json").resolve()
        if not profile_path.is_relative_to(clients_root):
            logger.warning(
                "[client_context] rejected client id %r: outside clients dir",
                state.client_id,
            )
            return state

        if not profile_path.exists():
            logger.info(f"[client_context] no profile for client '{state.client_id}'")
            return state

        try:
            profile = json.loads(profile_path.read_text(encoding="utf-8"))
            parts = []
            if name := profile.get("name"):
                parts.append(f"客户:{name}")
            if industry := profile.get("industry"):
                parts.append(f"行业:{industry}")
            if prefs := profile.get("preferences"):
                parts.append(f"偏好:{prefs}")
            # One field per line; joining with "" ran the fields together.
            state.client_context = "\n".join(parts)
            logger.info(f"[client_context] loaded profile for '{state.client_id}'")
        except Exception as e:
            # Best-effort: a broken profile must not block report generation.
            logger.warning(f"[client_context] failed to load profile: {e}")
        return state

View File

@@ -0,0 +1,61 @@
"""Compliance middleware — checks for sensitive data leakage in output."""
from __future__ import annotations
import json
import logging
import re
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
# Patterns that suggest sensitive data, as (regex, label) pairs.
#
# Digit runs are delimited with digit lookarounds instead of ``\b``: Python's
# ``\b`` is Unicode-aware and CJK characters count as word characters, so a
# number written directly next to Chinese text has no word boundary and
# would never match.
SENSITIVE_PATTERNS = [
    (r"(?<!\d)\d{15,18}(?!\d)", "可能的身份证号"),
    (r"(?<!\d)\d{16,19}(?!\d)", "可能的银行卡号"),
    (r"(?<!\d)1[3-9]\d{9}(?!\d)", "可能的手机号"),
    (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "邮箱地址"),
    (r"(?:密码|password|secret|token|api.?key)\s*[:=]\s*\S+", "可能的凭证信息"),
]
class ComplianceMiddleware(Middleware):
    """Scans report output for sensitive data patterns.

    After graph execution the draft is scanned for PII / credential
    patterns. Findings are logged as warnings and recorded on the state;
    nothing is blocked.
    """

    name = "compliance"

    def _scan_text(self, text: str) -> list[dict]:
        """Return one summary dict per pattern that matched ``text``.

        Each summary holds the pattern label, the match count and up to
        three samples truncated to eight characters.
        """
        report = []
        for regex, label in SENSITIVE_PATTERNS:
            hits = re.findall(regex, text)
            if not hits:
                continue
            preview = [hit[:8] + "..." for hit in hits[:3]]
            report.append({
                "type": label,
                "count": len(hits),
                "samples": preview,
            })
        return report

    async def after(self, state: ReportState) -> ReportState:
        """Scan the serialized draft and attach any findings to ``state.review``."""
        serialized = json.dumps(state.draft, ensure_ascii=False)
        findings = self._scan_text(serialized)

        if not findings:
            logger.info("[compliance] no sensitive data patterns detected")
            return state

        logger.warning(
            f"[compliance] found {len(findings)} sensitive data patterns:"
        )
        for item in findings:
            logger.warning(f" - {item['type']}: {item['count']} occurrences")
        # Store in state for the API to surface; an existing entry is kept.
        state.review.setdefault("compliance_warnings", findings)
        return state

64
app/middleware/memory.py Normal file
View File

@@ -0,0 +1,64 @@
"""Memory middleware — injects persistent facts into state, saves new facts after."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from app.memory.store import MemoryStore
from .base import Middleware
logger = logging.getLogger(__name__)
class MemoryMiddleware(Middleware):
    """Injects top-N memory facts into state before execution.

    After execution, records report metadata as facts for future reference.
    """

    name = "memory"

    def __init__(self):
        self.store = MemoryStore()

    async def before(self, state: ReportState) -> ReportState:
        """Fill ``state.memory_facts`` with up to 15 facts, client facts first."""
        # Load global facts + client-specific facts.
        facts = self.store.get_top_facts("global", limit=10)
        if state.client_id:
            client_facts = self.store.get_top_facts(state.client_id, limit=10)
            facts = client_facts + facts  # client facts take priority
        state.memory_facts = facts[:15]  # cap the total
        if facts:
            logger.info(f"[memory] injected {len(state.memory_facts)} facts")
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Persist the report title/type and research track titles as facts.

        Nothing is saved when the draft is empty or the run recorded an error.
        """
        if state.draft and not state.error:
            title = state.draft.get("title", "")
            report_type = state.report_type
            client_id = state.client_id or "global"

            # The closing parenthesis was missing from the stored sentence.
            self.store.add_fact(
                content=f"生成过报告:{title}(类型:{report_type})",
                client_id=client_id,
                category="context",
                confidence=0.9,
                source=state.id,
            )

            # Save decomposition tracks as knowledge.
            tracks = state.decomposition.get("tracks", [])
            if tracks:
                # Separate the titles; joining with "" fused them into one word.
                track_titles = "、".join(t.get("title", "") for t in tracks)
                self.store.add_fact(
                    content=f"报告《{title}》的研究轨道:{track_titles}",
                    client_id=client_id,
                    category="knowledge",
                    confidence=0.6,
                    source=state.id,
                )
        return state

View File

@@ -0,0 +1,46 @@
"""TokenBudget middleware — tracks and limits token usage across the pipeline."""
from __future__ import annotations
import logging
from app.graph.state import ReportState
from .base import Middleware
logger = logging.getLogger(__name__)
class TokenBudgetMiddleware(Middleware):
    """Logs the token budget before the run and an output estimate after it.

    Before: logs the budget already present on the state.
    After: logs an estimate of the tokens contained in the draft.
    """

    name = "token_budget"

    # Rough char-to-token ratio for Chinese text.
    CHARS_PER_TOKEN = 1.5

    def _tokens_for_chars(self, char_count: int) -> int:
        """Convert a character count into an estimated token count."""
        return int(char_count / self.CHARS_PER_TOKEN)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text`` from its length."""
        return self._tokens_for_chars(len(text))

    async def before(self, state: ReportState) -> ReportState:
        """Log the configured budget; the state is not modified."""
        # The default budget is already set on ReportState; it could be
        # adjusted here based on model selection.
        logger.info(f"[token_budget] budget = {state.token_budget} tokens")
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Log the estimated token size of the draft's chapters and summary."""
        # ``or ""`` tolerates keys that are present but set to None.
        total_chars = sum(
            len(chapter.get("content") or "")
            for chapter in state.draft.get("chapters", [])
        )
        total_chars += len(state.draft.get("executive_summary") or "")

        # Estimate from the character count itself. Passing str(total_chars)
        # to _estimate_tokens measured the number of digits instead.
        estimated = self._tokens_for_chars(total_chars)
        logger.info(
            f"[token_budget] estimated output tokens: {estimated}, "
            f"budget: {state.token_budget}"
        )
        return state

0
app/output/__init__.py Normal file
View File

0
app/pipeline/__init__.py Normal file
View File

View File

@@ -0,0 +1,43 @@
"""Pipeline orchestrator — builds the graph + middleware chain and executes."""
from __future__ import annotations
import logging
from app.graph.builder import ReportGraph
from app.graph.state import ReportState
from app.middleware.chain import MiddlewareChain
from app.middleware.client_context import ClientContextMiddleware
from app.middleware.token_budget import TokenBudgetMiddleware
from app.middleware.compliance import ComplianceMiddleware
from app.middleware.memory import MemoryMiddleware
logger = logging.getLogger(__name__)
def build_middleware() -> MiddlewareChain:
    """Assemble the middleware chain in execution order."""
    return (
        MiddlewareChain()
        .add(ClientContextMiddleware())  # 1. load client profile
        .add(MemoryMiddleware())         # 2. inject memory facts
        .add(TokenBudgetMiddleware())    # 3. set token budget
        .add(ComplianceMiddleware())     # 4. scan output for PII (after only)
    )
class PipelineOrchestrator:
    """Top-level entry: creates graph + middleware, runs end-to-end."""

    def __init__(self):
        self.middleware = build_middleware()
        self.graph = ReportGraph(middleware=self.middleware)

    async def run(self, state: ReportState) -> ReportState:
        """Run the report graph on ``state`` and return the resulting state."""
        logger.info(f"[orchestrator] starting report {state.id}")
        state = await self.graph.run(state)
        # A separator after the id keeps it from fusing with "status=".
        logger.info(
            f"[orchestrator] finished report {state.id}: "
            f"status={'OK' if not state.error else 'FAILED'}, "
            f"files={len(state.generated_files)}"
        )
        return state

22
app/pipeline/task.py Normal file
View File

@@ -0,0 +1,22 @@
"""Pipeline task model — thin wrapper, main state is now ReportState."""
from __future__ import annotations
from app.graph.state import ReportState
def create_report_state(
    requirement: str,
    report_type: str = "行业分析报告",
    extra_data: str = "",
    output_formats: list[str] | None = None,
    client_id: str | None = None,
) -> ReportState:
    """Build a fresh ``ReportState`` from user input.

    A missing or empty ``output_formats`` falls back to ``["docx"]``.
    """
    formats = output_formats if output_formats else ["docx"]
    fields = {
        "requirement": requirement,
        "report_type": report_type,
        "extra_data": extra_data,
        "output_formats": formats,
        "client_id": client_id,
    }
    return ReportState(**fields)

View File