init repo
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
15
app/agents/__init__.py
Normal file
15
app/agents/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .base import BaseAgent
|
||||
from .researcher import ResearcherAgent
|
||||
from .writer import WriterAgent
|
||||
from .data_agent import DataAgent
|
||||
from .reviewer import ReviewerAgent
|
||||
from .formatter import FormatterAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"ResearcherAgent",
|
||||
"WriterAgent",
|
||||
"DataAgent",
|
||||
"ReviewerAgent",
|
||||
"FormatterAgent",
|
||||
]
|
||||
166
app/agents/base.py
Normal file
166
app/agents/base.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Base agent with LLM calling via litellm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Disable litellm telemetry
|
||||
litellm.telemetry = False
|
||||
|
||||
|
||||
class BaseAgent:
    """Base class for all pipeline agents.

    Subclasses override the class attributes below and implement
    :meth:`run`.  LLM access goes through ``litellm.acompletion``.
    """

    name: str = "base"
    description: str = ""
    system_prompt: str = ""
    model: str = ""  # empty = use default from config

    def __init__(self, model: str | None = None):
        """Optionally override the class-level ``model`` for this instance."""
        if model:
            self.model = model

    def get_model(self) -> str:
        """Return this agent's model, or ``settings.llm_model`` when unset."""
        return self.model or settings.llm_model

    async def call_llm(
        self,
        prompt: str,
        *,
        system: str | None = None,
        temperature: float = 0.3,
        max_tokens: int = 4096,
        response_format: dict | None = None,
    ) -> str:
        """Call the LLM via litellm and return the text of the first choice.

        Args:
            prompt: User message content.
            system: System prompt; falls back to ``self.system_prompt``.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Completion token limit forwarded to litellm.
            response_format: Forwarded to litellm only when truthy.

        Returns:
            The message content of the first choice.  An empty string is
            returned when the provider sends no content, so callers never
            receive ``None``.
        """
        messages = []
        sys_prompt = system or self.system_prompt
        if sys_prompt:
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": prompt})

        model = self.get_model()
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Optional settings are only forwarded when configured.
        if settings.llm_api_key:
            kwargs["api_key"] = settings.llm_api_key
        if settings.llm_api_base:
            kwargs["api_base"] = settings.llm_api_base
        if response_format:
            kwargs["response_format"] = response_format

        logger.info("[%s] calling %s", self.name, model)
        response = await litellm.acompletion(**kwargs)
        # ``content`` can be None; the original then crashed on len(None).
        content = response.choices[0].message.content or ""
        logger.info("[%s] got %d chars", self.name, len(content))
        return content

    @staticmethod
    def _strip_code_fence(raw: str) -> str:
        """Remove a surrounding markdown code fence (``` ... ```) if present."""
        text = raw.strip()
        if text.startswith("```"):
            first_nl = text.find("\n")
            if first_nl != -1:
                # Drop the whole opening line, including a language tag.
                text = text[first_nl + 1:]
        if text.endswith("```"):
            text = text[: text.rfind("```")]
        return text.strip()

    async def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Call the LLM and parse its response as JSON.

        Parsing is tolerant of three common model quirks: a markdown code
        fence around the payload, literal control characters (newlines,
        tabs) inside string values, and prose before or after the JSON
        object.  If the standard library cannot parse the text, the optional
        ``json_repair`` package is tried when it is installed.

        Args:
            prompt: User message content.
            **kwargs: Forwarded to :meth:`call_llm`.  ``response_format``
                defaults to ``{"type": "json_object"}``.

        Returns:
            The decoded JSON value.

        Raises:
            json.JSONDecodeError: If no JSON could be extracted.
        """
        # setdefault avoids a duplicate-keyword TypeError when the caller
        # passes its own response_format.
        kwargs.setdefault("response_format", {"type": "json_object"})
        raw = await self.call_llm(prompt, **kwargs)
        text = self._strip_code_fence(raw or "")

        # strict=False makes the stdlib decoder accept unescaped control
        # characters inside strings, so no manual escaping pass is needed.
        decoder = json.JSONDecoder(strict=False)
        try:
            return decoder.decode(text)
        except json.JSONDecodeError:
            pass

        start = text.find("{")
        if start == -1:
            raise json.JSONDecodeError("No JSON object found", text, 0)

        # raw_decode parses one complete value starting at ``start`` and
        # ignores what follows; unlike brace counting it is not confused by
        # braces that appear inside string values.
        try:
            value, _end = decoder.raw_decode(text, start)
            return value
        except json.JSONDecodeError:
            pass

        # Last resort: the optional json_repair library.
        try:
            import json_repair
        except ImportError as exc:
            raise json.JSONDecodeError(
                "Failed to parse JSON after multiple attempts", text, 0
            ) from exc
        try:
            return json_repair.loads(text)
        except Exception as exc:
            raise json.JSONDecodeError(
                "Failed to parse JSON after multiple attempts", text, 0
            ) from exc

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Execute this agent's task. Override in subclasses.

        Args:
            context: Shared pipeline context (accumulated by previous agents).

        Returns:
            Dict of new keys to merge into context.
        """
        raise NotImplementedError
|
||||
78
app/agents/data_agent.py
Normal file
78
app/agents/data_agent.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Data Agent — processes data, generates chart specs and table data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseAgent
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class DataAgent(BaseAgent):
    """Turn the chart/table requests found in a draft into concrete data.

    The draft's chapters may each carry ``charts`` and ``tables`` request
    lists; these are sent to the LLM, which answers with a JSON object
    holding ``charts`` and ``tables`` specs.
    """

    name = "data"
    description = "处理数据、生成图表规格和表格数据"
    system_prompt = """\
你是一位数据分析专家。你的任务是根据报告草稿中标注的图表和表格需求,
生成具体的数据和图表规格。

输出要求(JSON 格式):
{
"charts": [
{
"id": "chart_1",
"title": "图表标题",
"type": "bar|line|pie|area|scatter",
"description": "图表说明",
"data": {
"labels": ["标签1", "标签2"],
"datasets": [
{"label": "数据集名", "data": [100, 200]}
]
}
}
],
"tables": [
{
"id": "table_1",
"title": "表格标题",
"headers": ["列1", "列2", "列3"],
"rows": [["数据1", "数据2", "数据3"]]
}
]
}"""

    def __init__(self):
        """Use the model that ``settings`` maps to the ``"fast"`` domain."""
        super().__init__(model=settings.model_for_domain("fast"))

    @staticmethod
    def _normalize_assets(result: Any) -> dict[str, Any]:
        """Coerce the LLM answer into a dict with list-valued charts/tables.

        The model output is not guaranteed to follow the requested schema;
        consumers call ``.get("charts")`` / ``.get("tables")`` on it and
        iterate, so anything else is replaced by an empty list.
        """
        assets: dict[str, Any] = dict(result) if isinstance(result, dict) else {}
        for key in ("charts", "tables"):
            if not isinstance(assets.get(key), list):
                assets[key] = []
        return assets

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Generate chart and table data for the draft in ``context``.

        Args:
            context: Pipeline context.  Requires ``draft``; reads the
                optional ``extra_data`` text.

        Returns:
            ``{"data_assets": {"charts": [...], "tables": [...], ...}}``.
            No LLM call is made when the draft requests nothing.
        """
        draft = context["draft"]
        extra_data = context.get("extra_data", "")

        # Collect chart/table needs from draft.  ``or []`` tolerates an
        # explicit null in the draft JSON.
        chart_needs = []
        table_needs = []
        for ch in draft.get("chapters", []):
            chart_needs.extend(ch.get("charts") or [])
            table_needs.extend(ch.get("tables") or [])

        if not chart_needs and not table_needs:
            return {"data_assets": {"charts": [], "tables": []}}

        prompt = f"""\
## 报告标题
{draft.get("title", "")}

## 需要生成的图表
{json.dumps(chart_needs, ensure_ascii=False)}

## 需要生成的表格
{json.dumps(table_needs, ensure_ascii=False)}

## 补充数据源
{extra_data if extra_data else "(无额外数据,请根据行业常识生成合理的示例数据)"}

请为以上需求生成具体的图表规格和表格数据。输出 JSON。"""

        result = await self.call_llm_json(prompt)
        return {"data_assets": self._normalize_assets(result)}
|
||||
669
app/agents/formatter.py
Normal file
669
app/agents/formatter.py
Normal file
@@ -0,0 +1,669 @@
|
||||
"""Formatter Agent — renders final report using Skills toolkit.
|
||||
|
||||
Skills integration:
|
||||
- docx: python-docx (baseline) + docx-js via Node.js (rich mode) + OOXML template editing
|
||||
- pptx: html2pptx.js via Node.js (visual slides) + python-pptx fallback
|
||||
- xlsx: openpyxl + recalc.py (formula recalculation via LibreOffice)
|
||||
- pdf: reportlab with CJK support + fpdf2 fallback
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Skills root
|
||||
SKILLS_ROOT = Path.home() / "Projects/code/20260119-skills合集/anthropics_skills/skills"
|
||||
DOCX_SKILLS = SKILLS_ROOT / "docx"
|
||||
PPTX_SKILLS = SKILLS_ROOT / "pptx"
|
||||
XLSX_SKILLS = SKILLS_ROOT / "xlsx"
|
||||
PDF_SKILLS = SKILLS_ROOT / "pdf"
|
||||
|
||||
|
||||
def _skills_available() -> dict[str, bool]:
    """Report, per skill toolkit, whether its marker path exists on disk."""
    # Each of these skills is detected by the presence of one marker file.
    marker_files = {
        "docx_js": DOCX_SKILLS / "docx-js.md",
        "html2pptx": PPTX_SKILLS / "scripts" / "html2pptx.js",
        "recalc": XLSX_SKILLS / "recalc.py",
        "ooxml_docx": DOCX_SKILLS / "ooxml" / "scripts" / "unpack.py",
        "ooxml_pptx": PPTX_SKILLS / "ooxml" / "scripts" / "unpack.py",
    }
    report = {skill: marker.exists() for skill, marker in marker_files.items()}
    # The PDF skill is detected by its scripts directory instead of a file.
    report["pdf_scripts"] = (PDF_SKILLS / "scripts").is_dir()
    return report
|
||||
|
||||
|
||||
class FormatterAgent(BaseAgent):
|
||||
name = "formatter"
|
||||
description = "将报告渲染为 docx/pptx/xlsx/pdf,融合 Skills 能力"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.skills = _skills_available()
|
||||
available = [k for k, v in self.skills.items() if v]
|
||||
logger.info(f"[formatter] available skills: {available}")
|
||||
|
||||
async def run(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
draft = context["draft"]
|
||||
data_assets = context.get("data_assets", {})
|
||||
output_dir = Path(context.get("output_dir", "output"))
|
||||
formats = context.get("output_formats", ["docx"])
|
||||
template_path = context.get("template_path") # optional: user-provided template
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
title = draft.get("title", "报告")
|
||||
generated_files = []
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
match fmt:
|
||||
case "docx":
|
||||
path = await self._render_docx(draft, data_assets, output_dir, title, template_path)
|
||||
case "pptx":
|
||||
path = await self._render_pptx(draft, data_assets, output_dir, title)
|
||||
case "xlsx":
|
||||
path = await self._render_xlsx(data_assets, output_dir, title)
|
||||
case "pdf":
|
||||
path = await self._render_pdf(draft, data_assets, output_dir, title)
|
||||
case _:
|
||||
logger.warning(f"Unsupported format: {fmt}")
|
||||
continue
|
||||
generated_files.append(str(path))
|
||||
logger.info(f"[formatter] generated {path}")
|
||||
except Exception as e:
|
||||
logger.exception(f"[formatter] failed to render {fmt}")
|
||||
|
||||
return {"generated_files": generated_files}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# DOCX — python-docx baseline + OOXML template editing
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _render_docx(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str,
|
||||
template_path: str | None = None,
|
||||
) -> Path:
|
||||
if template_path and self.skills["ooxml_docx"]:
|
||||
return await self._render_docx_from_template(
|
||||
draft, data_assets, output_dir, title, Path(template_path)
|
||||
)
|
||||
return await self._render_docx_baseline(draft, data_assets, output_dir, title)
|
||||
|
||||
async def _render_docx_baseline(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
from docx import Document
|
||||
from docx.shared import Pt, RGBColor
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
doc = Document()
|
||||
|
||||
# -- Styles --
|
||||
style = doc.styles["Normal"]
|
||||
style.font.name = "微软雅黑"
|
||||
style.font.size = Pt(11)
|
||||
|
||||
# Title
|
||||
t = doc.add_heading(title, level=0)
|
||||
t.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
|
||||
# Executive summary
|
||||
if summary := draft.get("executive_summary"):
|
||||
doc.add_heading("执行摘要", level=1)
|
||||
# Add summary with highlight styling
|
||||
p = doc.add_paragraph()
|
||||
run = p.add_run(summary)
|
||||
run.font.size = Pt(11)
|
||||
run.font.color.rgb = RGBColor(0x33, 0x33, 0x33)
|
||||
|
||||
# Chapters
|
||||
for chapter in draft.get("chapters", []):
|
||||
doc.add_heading(chapter["title"], level=1)
|
||||
content = chapter.get("content", "")
|
||||
self._docx_render_markdown(doc, content)
|
||||
|
||||
# Tables from data assets
|
||||
for table_spec in data_assets.get("tables", []):
|
||||
doc.add_heading(table_spec.get("title", "数据表"), level=2)
|
||||
self._docx_add_table(doc, table_spec)
|
||||
|
||||
# Page break + chart descriptions as placeholders
|
||||
for chart_spec in data_assets.get("charts", []):
|
||||
doc.add_heading(chart_spec.get("title", "图表"), level=2)
|
||||
desc = chart_spec.get("description", "")
|
||||
chart_type = chart_spec.get("type", "")
|
||||
doc.add_paragraph(f"[{chart_type.upper()} 图表] {desc}")
|
||||
# Render chart data as a table too
|
||||
chart_data = chart_spec.get("data", {})
|
||||
if labels := chart_data.get("labels"):
|
||||
for ds in chart_data.get("datasets", []):
|
||||
self._docx_add_table(doc, {
|
||||
"headers": ["项目", ds.get("label", "数据")],
|
||||
"rows": [[str(l), str(v)] for l, v in zip(labels, ds.get("data", []))],
|
||||
})
|
||||
|
||||
path = output_dir / f"{title}.docx"
|
||||
doc.save(str(path))
|
||||
return path
|
||||
|
||||
async def _render_docx_from_template(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str,
|
||||
template_path: Path,
|
||||
) -> Path:
|
||||
"""Edit an existing DOCX template using OOXML unpack/edit/pack workflow."""
|
||||
unpack_script = DOCX_SKILLS / "ooxml" / "scripts" / "unpack.py"
|
||||
pack_script = DOCX_SKILLS / "ooxml" / "scripts" / "pack.py"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
work_dir = Path(tmpdir) / "unpacked"
|
||||
|
||||
# Unpack template
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3", str(unpack_script), str(template_path), str(work_dir),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
)
|
||||
await proc.wait()
|
||||
|
||||
if proc.returncode != 0:
|
||||
logger.warning("[formatter] OOXML unpack failed, falling back to baseline")
|
||||
return await self._render_docx_baseline(draft, data_assets, output_dir, title)
|
||||
|
||||
# TODO: edit XML content in work_dir based on draft
|
||||
# For now, just pack back as-is (template passthrough)
|
||||
output_path = output_dir / f"{title}.docx"
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3", str(pack_script), str(work_dir), str(output_path),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
)
|
||||
await proc.wait()
|
||||
return output_path
|
||||
|
||||
def _docx_render_markdown(self, doc, content: str):
|
||||
"""Convert markdown-ish content to docx paragraphs."""
|
||||
from docx.shared import Pt
|
||||
|
||||
for block in content.split("\n\n"):
|
||||
block = block.strip()
|
||||
if not block:
|
||||
continue
|
||||
if block.startswith("#### "):
|
||||
doc.add_heading(block[5:], level=4)
|
||||
elif block.startswith("### "):
|
||||
doc.add_heading(block[4:], level=3)
|
||||
elif block.startswith("## "):
|
||||
doc.add_heading(block[3:], level=2)
|
||||
elif block.startswith("- ") or block.startswith("* "):
|
||||
# Bullet list
|
||||
for line in block.split("\n"):
|
||||
line = line.lstrip("- *").strip()
|
||||
if line:
|
||||
doc.add_paragraph(line, style="List Bullet")
|
||||
elif block.startswith("1. ") or block.startswith("1)"):
|
||||
# Numbered list
|
||||
for line in block.split("\n"):
|
||||
text = line.lstrip("0123456789.)) ").strip()
|
||||
if text:
|
||||
doc.add_paragraph(text, style="List Number")
|
||||
else:
|
||||
p = doc.add_paragraph(block)
|
||||
for run in p.runs:
|
||||
run.font.size = Pt(11)
|
||||
|
||||
def _docx_add_table(self, doc, table_spec: dict):
|
||||
"""Add a formatted table to the document."""
|
||||
from docx.shared import Pt, RGBColor
|
||||
from docx.oxml.ns import qn
|
||||
|
||||
headers = table_spec.get("headers", [])
|
||||
rows = table_spec.get("rows", [])
|
||||
if not headers:
|
||||
return
|
||||
|
||||
tbl = doc.add_table(rows=1 + len(rows), cols=len(headers))
|
||||
tbl.style = "Light Grid Accent 1"
|
||||
|
||||
# Header row
|
||||
for i, h in enumerate(headers):
|
||||
cell = tbl.rows[0].cells[i]
|
||||
cell.text = str(h)
|
||||
for p in cell.paragraphs:
|
||||
for run in p.runs:
|
||||
run.font.bold = True
|
||||
run.font.size = Pt(10)
|
||||
|
||||
# Data rows
|
||||
for r_idx, row in enumerate(rows):
|
||||
for c_idx, cell_val in enumerate(row):
|
||||
tbl.rows[r_idx + 1].cells[c_idx].text = str(cell_val)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# PPTX — html2pptx.js (rich) or python-pptx (fallback)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _render_pptx(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
if self.skills["html2pptx"]:
|
||||
try:
|
||||
return await self._render_pptx_html2pptx(draft, data_assets, output_dir, title)
|
||||
except Exception as e:
|
||||
logger.warning(f"[formatter] html2pptx failed ({e}), falling back to python-pptx")
|
||||
return await self._render_pptx_baseline(draft, data_assets, output_dir, title)
|
||||
|
||||
async def _render_pptx_html2pptx(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
"""Generate PPTX using html2pptx.js skill for visual slides."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
work = Path(tmpdir)
|
||||
|
||||
# Generate HTML slides
|
||||
slides_html = []
|
||||
# Title slide
|
||||
slides_html.append(f"""<html><body style="width:720pt;height:405pt;display:flex;align-items:center;justify-content:center;flex-direction:column;background:linear-gradient(135deg,#1a1a2e,#16213e);color:white;font-family:sans-serif;">
|
||||
<h1 style="font-size:36pt;margin:0;">{title}</h1>
|
||||
<p style="font-size:18pt;color:#aaa;margin-top:20pt;">{draft.get('executive_summary', '')[:100]}</p>
|
||||
</body></html>""")
|
||||
|
||||
# Chapter slides
|
||||
for ch in draft.get("chapters", []):
|
||||
content_lines = ch.get("content", "")[:400].split("\n")
|
||||
bullets = "".join(f"<li>{l.strip()}</li>" for l in content_lines if l.strip())
|
||||
slides_html.append(f"""<html><body style="width:720pt;height:405pt;padding:40pt;font-family:sans-serif;background:#ffffff;">
|
||||
<h2 style="font-size:28pt;color:#1a1a2e;border-bottom:2pt solid #e94560;padding-bottom:10pt;">{ch['title']}</h2>
|
||||
<ul style="font-size:14pt;color:#333;line-height:1.8;">{bullets}</ul>
|
||||
</body></html>""")
|
||||
|
||||
# Write HTML files
|
||||
for i, html in enumerate(slides_html):
|
||||
(work / f"slide_{i}.html").write_text(html, encoding="utf-8")
|
||||
|
||||
# Write conversion script
|
||||
script = work / "convert.js"
|
||||
html2pptx_path = PPTX_SKILLS / "scripts" / "html2pptx.js"
|
||||
slide_files = [f"slide_{i}.html" for i in range(len(slides_html))]
|
||||
|
||||
script.write_text(f"""\
|
||||
const pptxgen = require('pptxgenjs');
|
||||
const {{ html2pptx }} = require('{html2pptx_path}');
|
||||
const path = require('path');
|
||||
|
||||
async function main() {{
|
||||
const pptx = new pptxgen();
|
||||
pptx.layout = 'LAYOUT_16x9';
|
||||
const files = {json.dumps(slide_files)};
|
||||
for (const f of files) {{
|
||||
await html2pptx(path.join('{work}', f), pptx);
|
||||
}}
|
||||
await pptx.writeFile({{ fileName: '{output_dir / f"{title}.pptx"}' }});
|
||||
}}
|
||||
main().catch(e => {{ console.error(e); process.exit(1); }});
|
||||
""", encoding="utf-8")
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"node", str(script),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
cwd=str(work),
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"html2pptx failed: {stderr.decode()}")
|
||||
|
||||
return output_dir / f"{title}.pptx"
|
||||
|
||||
async def _render_pptx_baseline(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
from pptx import Presentation
|
||||
from pptx.util import Inches, Pt
|
||||
from pptx.dml.color import RGBColor
|
||||
|
||||
prs = Presentation()
|
||||
|
||||
# Title slide
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[0])
|
||||
slide.shapes.title.text = title
|
||||
if len(slide.placeholders) > 1:
|
||||
slide.placeholders[1].text = draft.get("executive_summary", "")[:200]
|
||||
|
||||
# Chapter slides
|
||||
for chapter in draft.get("chapters", []):
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide.shapes.title.text = chapter["title"]
|
||||
body = slide.placeholders[1]
|
||||
tf = body.text_frame
|
||||
tf.clear()
|
||||
|
||||
content = chapter.get("content", "")
|
||||
lines = [l.strip() for l in content.split("\n") if l.strip()]
|
||||
for line in lines[:12]: # max 12 bullets per slide
|
||||
p = tf.add_paragraph()
|
||||
# Strip markdown markers
|
||||
clean = line.lstrip("#-*0123456789.) ").strip()
|
||||
p.text = clean
|
||||
p.font.size = Pt(14)
|
||||
p.space_after = Pt(4)
|
||||
|
||||
# Data table slides
|
||||
for table_spec in data_assets.get("tables", []):
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[5]) # blank layout
|
||||
slide.shapes.title.text = table_spec.get("title", "数据表")
|
||||
|
||||
headers = table_spec.get("headers", [])
|
||||
rows = table_spec.get("rows", [])
|
||||
if headers and rows:
|
||||
n_rows = min(len(rows) + 1, 10) # limit rows per slide
|
||||
n_cols = len(headers)
|
||||
tbl = slide.shapes.add_table(
|
||||
n_rows, n_cols,
|
||||
Inches(0.5), Inches(1.5), Inches(9), Inches(4.5)
|
||||
).table
|
||||
|
||||
for i, h in enumerate(headers):
|
||||
tbl.cell(0, i).text = str(h)
|
||||
for r_idx, row in enumerate(rows[:n_rows - 1]):
|
||||
for c_idx, val in enumerate(row[:n_cols]):
|
||||
tbl.cell(r_idx + 1, c_idx).text = str(val)
|
||||
|
||||
path = output_dir / f"{title}.pptx"
|
||||
prs.save(str(path))
|
||||
return path
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# XLSX — openpyxl + recalc.py (formula recalculation)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _render_xlsx(
|
||||
self, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
|
||||
from openpyxl.utils import get_column_letter
|
||||
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "数据总览"
|
||||
|
||||
# Professional styling
|
||||
header_font = Font(bold=True, size=11, color="FFFFFF")
|
||||
header_fill = PatternFill(start_color="1A1A2E", end_color="1A1A2E", fill_type="solid")
|
||||
title_font = Font(bold=True, size=14, color="1A1A2E")
|
||||
thin_border = Border(
|
||||
left=Side(style="thin", color="CCCCCC"),
|
||||
right=Side(style="thin", color="CCCCCC"),
|
||||
top=Side(style="thin", color="CCCCCC"),
|
||||
bottom=Side(style="thin", color="CCCCCC"),
|
||||
)
|
||||
|
||||
current_row = 1
|
||||
has_formulas = False
|
||||
|
||||
for table_spec in data_assets.get("tables", []):
|
||||
# Table title
|
||||
ws.cell(row=current_row, column=1, value=table_spec.get("title", "")).font = title_font
|
||||
current_row += 1
|
||||
|
||||
headers = table_spec.get("headers", [])
|
||||
rows = table_spec.get("rows", [])
|
||||
|
||||
if headers:
|
||||
# Header row with styling
|
||||
for col_idx, h in enumerate(headers, 1):
|
||||
cell = ws.cell(row=current_row, column=col_idx, value=h)
|
||||
cell.font = header_font
|
||||
cell.fill = header_fill
|
||||
cell.alignment = Alignment(horizontal="center")
|
||||
cell.border = thin_border
|
||||
current_row += 1
|
||||
|
||||
# Data rows
|
||||
data_start = current_row
|
||||
for row_data in rows:
|
||||
for col_idx, val in enumerate(row_data, 1):
|
||||
cell = ws.cell(row=current_row, column=col_idx, value=val)
|
||||
cell.border = thin_border
|
||||
# Try to convert numeric strings
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
cell.value = float(val.replace(",", ""))
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
current_row += 1
|
||||
|
||||
# Auto-sum row for numeric columns
|
||||
data_end = current_row - 1
|
||||
if data_end > data_start:
|
||||
for col_idx in range(1, len(headers) + 1):
|
||||
col_letter = get_column_letter(col_idx)
|
||||
test_cell = ws.cell(row=data_start, column=col_idx)
|
||||
if isinstance(test_cell.value, (int, float)):
|
||||
cell = ws.cell(
|
||||
row=current_row, column=col_idx,
|
||||
value=f"=SUM({col_letter}{data_start}:{col_letter}{data_end})"
|
||||
)
|
||||
cell.font = Font(bold=True)
|
||||
cell.border = thin_border
|
||||
has_formulas = True
|
||||
elif col_idx == 1:
|
||||
cell = ws.cell(row=current_row, column=1, value="合计")
|
||||
cell.font = Font(bold=True)
|
||||
cell.border = thin_border
|
||||
current_row += 1
|
||||
|
||||
# Auto-fit column widths
|
||||
for col_idx in range(1, len(headers) + 1):
|
||||
max_len = max(
|
||||
len(str(ws.cell(row=r, column=col_idx).value or ""))
|
||||
for r in range(current_row - len(rows) - 2, current_row)
|
||||
)
|
||||
ws.column_dimensions[get_column_letter(col_idx)].width = min(max_len + 4, 30)
|
||||
|
||||
current_row += 2 # gap between tables
|
||||
|
||||
# Chart data sheets
|
||||
for chart_spec in data_assets.get("charts", []):
|
||||
chart_ws = wb.create_sheet(title=chart_spec.get("title", "图表")[:31])
|
||||
chart_ws.cell(row=1, column=1, value=chart_spec.get("title", "")).font = title_font
|
||||
chart_data = chart_spec.get("data", {})
|
||||
labels = chart_data.get("labels", [])
|
||||
datasets = chart_data.get("datasets", [])
|
||||
|
||||
# Headers: [项目, 数据集1, 数据集2, ...]
|
||||
chart_ws.cell(row=2, column=1, value="项目").font = Font(bold=True)
|
||||
for ds_idx, ds in enumerate(datasets, 2):
|
||||
chart_ws.cell(row=2, column=ds_idx, value=ds.get("label", "")).font = Font(bold=True)
|
||||
|
||||
for r_idx, label in enumerate(labels, 3):
|
||||
chart_ws.cell(row=r_idx, column=1, value=label)
|
||||
for ds_idx, ds in enumerate(datasets, 2):
|
||||
data = ds.get("data", [])
|
||||
if r_idx - 3 < len(data):
|
||||
chart_ws.cell(row=r_idx, column=ds_idx, value=data[r_idx - 3])
|
||||
|
||||
path = output_dir / f"{title}.xlsx"
|
||||
wb.save(str(path))
|
||||
|
||||
# Run recalc.py if we have formulas and the skill is available
|
||||
if has_formulas and self.skills["recalc"]:
|
||||
await self._xlsx_recalc(path)
|
||||
|
||||
return path
|
||||
|
||||
async def _xlsx_recalc(self, path: Path):
|
||||
"""Recalculate formulas using Skills recalc.py (requires LibreOffice)."""
|
||||
recalc_script = XLSX_SKILLS / "recalc.py"
|
||||
logger.info(f"[formatter] running recalc.py on {path}")
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3", str(recalc_script), str(path), "30",
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
if proc.returncode == 0:
|
||||
result = json.loads(stdout.decode())
|
||||
logger.info(f"[formatter] recalc result: {result.get('status')}")
|
||||
else:
|
||||
logger.warning(f"[formatter] recalc.py failed: {stderr.decode()[:200]}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[formatter] recalc.py error: {e}")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# PDF — reportlab with CJK support + fpdf2 fallback
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _render_pdf(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
try:
|
||||
return await self._render_pdf_reportlab(draft, data_assets, output_dir, title)
|
||||
except Exception as e:
|
||||
logger.warning(f"[formatter] reportlab failed ({e}), falling back to fpdf2")
|
||||
return await self._render_pdf_fpdf(draft, data_assets, output_dir, title)
|
||||
|
||||
async def _render_pdf_reportlab(
|
||||
self, draft: dict, data_assets: dict, output_dir: Path, title: str
|
||||
) -> Path:
|
||||
"""Generate PDF with reportlab — better CJK support and table rendering."""
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.units import mm
|
||||
from reportlab.lib import colors
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak,
|
||||
)
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
|
||||
# Try to register a CJK font
|
||||
cjk_font = "Helvetica"
|
||||
for font_path in [
|
||||
"/System/Library/Fonts/STHeiti Medium.ttc",
|
||||
"/System/Library/Fonts/PingFang.ttc",
|
||||
"/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
|
||||
]:
|
||||
if Path(font_path).exists():
|
||||
try:
|
||||
pdfmetrics.registerFont(TTFont("CJK", font_path, subfontIndex=0))
|
||||
cjk_font = "CJK"
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
path = output_dir / f"{title}.pdf"
|
||||
doc = SimpleDocTemplate(str(path), pagesize=A4,
|
||||
topMargin=25*mm, bottomMargin=25*mm)
|
||||
|
||||
styles = getSampleStyleSheet()
|
||||
styles.add(ParagraphStyle(
|
||||
name="CJKTitle", fontName=cjk_font, fontSize=22,
|
||||
spaceAfter=12, alignment=1,
|
||||
))
|
||||
styles.add(ParagraphStyle(
|
||||
name="CJKHeading", fontName=cjk_font, fontSize=16,
|
||||
spaceAfter=8, spaceBefore=16, textColor=colors.HexColor("#1a1a2e"),
|
||||
))
|
||||
styles.add(ParagraphStyle(
|
||||
name="CJKBody", fontName=cjk_font, fontSize=11,
|
||||
spaceAfter=6, leading=16,
|
||||
))
|
||||
|
||||
elements = []
|
||||
|
||||
# Title
|
||||
elements.append(Paragraph(title, styles["CJKTitle"]))
|
||||
elements.append(Spacer(1, 12))
|
||||
|
||||
# Executive summary
|
||||
if summary := draft.get("executive_summary"):
|
||||
elements.append(Paragraph("执行摘要", styles["CJKHeading"]))
|
||||
elements.append(Paragraph(summary, styles["CJKBody"]))
|
||||
elements.append(Spacer(1, 12))
|
||||
|
||||
# Chapters
|
||||
for chapter in draft.get("chapters", []):
|
||||
elements.append(PageBreak())
|
||||
elements.append(Paragraph(chapter["title"], styles["CJKHeading"]))
|
||||
content = chapter.get("content", "")
|
||||
for para in content.split("\n\n"):
|
||||
para = para.strip()
|
||||
if para:
|
||||
elements.append(Paragraph(para, styles["CJKBody"]))
|
||||
|
||||
# Tables
|
||||
for table_spec in data_assets.get("tables", []):
|
||||
elements.append(Spacer(1, 12))
|
||||
elements.append(Paragraph(table_spec.get("title", ""), styles["CJKHeading"]))
|
||||
headers = table_spec.get("headers", [])
|
||||
rows = table_spec.get("rows", [])
|
||||
if headers:
|
||||
table_data = [headers] + rows
|
||||
t = Table(table_data)
|
||||
t.setStyle(TableStyle([
|
||||
("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#1a1a2e")),
|
||||
("TEXTCOLOR", (0, 0), (-1, 0), colors.white),
|
||||
("FONTNAME", (0, 0), (-1, -1), cjk_font),
|
||||
("FONTSIZE", (0, 0), (-1, 0), 10),
|
||||
("FONTSIZE", (0, 1), (-1, -1), 9),
|
||||
("GRID", (0, 0), (-1, -1), 0.5, colors.grey),
|
||||
("ALIGN", (0, 0), (-1, -1), "CENTER"),
|
||||
("ROWBACKGROUNDS", (0, 1), (-1, -1), [colors.white, colors.HexColor("#f5f5f5")]),
|
||||
]))
|
||||
elements.append(t)
|
||||
|
||||
doc.build(elements)
|
||||
return path
|
||||
|
||||
async def _render_pdf_fpdf(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Fallback PDF generation with fpdf2.

    Writes a centered title, the optional executive summary, and one page
    per chapter (heading followed by the chapter content).

    Args:
        draft: Report draft; reads ``executive_summary`` and ``chapters``
            (each chapter is a dict with ``title`` and ``content``).
        data_assets: Not used by this renderer.
        output_dir: Directory the PDF is written into.
        title: Report title; also used verbatim as the file stem.

    Returns:
        Path of the written ``<title>.pdf`` file.
    """
    from fpdf import FPDF

    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)

    # Try the known CJK fonts in order; the first one that loads wins.
    for font_path in [
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/System/Library/Fonts/PingFang.ttc",
    ]:
        if Path(font_path).exists():
            try:
                # `uni=True` is deprecated in fpdf2 and no longer needed for
                # TrueType fonts, so it is intentionally not passed.
                pdf.add_font("CJK", "", font_path)
                pdf.set_font("CJK", "", 11)
                break
            except Exception:
                # Best effort: keep a usable font selected and try the next path.
                pdf.set_font("Helvetica", "", 11)
    else:
        # No candidate font could be loaded.
        # NOTE(review): the built-in Helvetica font presumably cannot encode
        # the CJK strings written below — confirm behaviour on hosts that
        # lack the fonts listed above.
        pdf.set_font("Helvetica", "", 11)

    pdf.add_page()
    pdf.set_font_size(24)
    pdf.cell(0, 20, title, new_x="LMARGIN", new_y="NEXT", align="C")

    pdf.set_font_size(11)
    if summary := draft.get("executive_summary"):
        pdf.set_font_size(16)
        pdf.cell(0, 12, "执行摘要", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font_size(11)
        pdf.multi_cell(0, 6, summary)

    for chapter in draft.get("chapters", []):
        pdf.add_page()
        pdf.set_font_size(16)
        # A chapter without a title gets an empty heading instead of
        # aborting the whole render with a KeyError.
        pdf.cell(0, 12, chapter.get("title", ""), new_x="LMARGIN", new_y="NEXT")
        pdf.set_font_size(11)
        pdf.multi_cell(0, 6, chapter.get("content", ""))

    # NOTE(review): `title` is used unsanitized as a file name; a title
    # containing a path separator would change the target location.
    path = output_dir / f"{title}.pdf"
    pdf.output(str(path))
    return path
|
||||
103
app/agents/researcher.py
Normal file
103
app/agents/researcher.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Researcher Agent — domain-aware, bilingual research."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseAgent
|
||||
from app.config import settings
|
||||
|
||||
# System prompt for English-language research. It asks the model for a single
# JSON object with title, executive_summary, sections, data_gaps and
# key_insights.
SYSTEM_EN = """\
You are a senior industry analyst at a top-tier consulting firm.
Your task is to produce a thorough research brief based on the given instructions.

Requirements:
1. Be specific — cite concrete data points, market sizes, growth rates, company names
2. Be structured — organize findings with clear headings and logical flow
3. Be analytical — don't just list facts, provide insights and implications
4. Flag data gaps — explicitly note where data is uncertain or unavailable

Output (JSON):
{
"title": "Research brief title",
"executive_summary": "2-3 sentence summary of key findings",
"sections": [
{
"heading": "Section heading",
"content": "Detailed findings (Markdown)",
"data_points": ["key data points extracted"],
"sources_quality": "high|medium|low — how confident are you in the data"
}
],
"data_gaps": ["areas where data is insufficient or uncertain"],
"key_insights": ["top 3-5 non-obvious insights"]
}"""

# Chinese-language counterpart of SYSTEM_EN; it requests the same JSON keys.
SYSTEM_ZH = """\
你是一位顶级咨询公司的资深行业分析师。
你的任务是根据给定的指令,输出一份深度研究简报。

要求:
1. 具体——引用具体的数据点、市场规模、增长率、企业名称
2. 结构化——用清晰的标题和逻辑流组织发现
3. 有分析深度——不要只罗列事实,要提供洞察和含义
4. 标注数据缺口——明确指出数据不确定或不可获取的地方

输出(JSON):
{
"title": "研究简报标题",
"executive_summary": "核心发现的2-3句总结",
"sections": [
{
"heading": "章节标题",
"content": "详细发现(Markdown格式)",
"data_points": ["提取的关键数据点"],
"sources_quality": "high|medium|low — 对数据的置信度"
}
],
"data_gaps": ["数据不充分或不确定的领域"],
"key_insights": ["3-5条非显而易见的洞察"]
}"""
|
||||
|
||||
|
||||
class ResearcherAgent(BaseAgent):
    """Agent that turns a requirement into a structured research brief.

    The ``language`` given at construction selects both the system prompt
    and the user-prompt template: ``"zh"`` uses the Chinese variants, any
    other value uses the English ones.
    """

    name = "researcher"
    description = "域感知研究 — 根据领域选择最优模型和语言"

    def __init__(self, model: str | None = None, language: str = "en"):
        super().__init__(model=model)
        self.language = language
        if language == "zh":
            self.system_prompt = SYSTEM_ZH
        else:
            self.system_prompt = SYSTEM_EN

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Build the research prompt from ``context`` and query the LLM.

        Args:
            context: Must contain ``requirement``; ``report_type`` and
                ``extra_data`` are optional and default to empty strings.

        Returns:
            ``{"research": <value returned by call_llm_json>}``.

        Raises:
            KeyError: If ``context`` has no ``requirement`` entry.
        """
        requirement = context["requirement"]
        report_type = context.get("report_type", "")
        extra_data = context.get("extra_data", "")

        if self.language == "zh":
            supplement = extra_data or "(无)"
            prompt = f"""\
## 研究指令
{requirement}

## 研究方向
{report_type}

## 补充数据
{supplement}

请输出研究简报 JSON。"""
        else:
            supplement = extra_data or "(none)"
            prompt = f"""\
## Research instructions
{requirement}

## Research focus
{report_type}

## Additional data
{supplement}

Output the research brief as JSON."""

        brief = await self.call_llm_json(prompt, max_tokens=6144)
        return {"research": brief}
|
||||
79
app/agents/reviewer.py
Normal file
79
app/agents/reviewer.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Reviewer Agent — bilingual quality check with strongest reasoning model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseAgent
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class ReviewerAgent(BaseAgent):
    """Agent that scores a draft against its research and returns a verdict.

    It always runs on the model configured for the "reasoning" domain.
    """

    name = "reviewer"
    description = "双语报告质量审查 — 使用最强推理模型"

    system_prompt = """\
You are a senior consulting partner reviewing a report before client delivery.
The report has both Chinese and English versions (or will be translated).

Review dimensions:
1. **Accuracy** — Are data points, percentages, and claims supported by the research?
Cross-check global claims against English research, Chinese claims against Chinese research.
2. **Logical consistency** — Does the narrative flow? Are there contradictions between chapters?
3. **Depth of analysis** — Is it consultancy-grade or just surface-level? Would a C-suite exec find it valuable?
4. **Bilingual quality** — If translated version exists, check for translation artifacts,
mistranslated terminology, or cultural mismatches.
5. **Data gaps honesty** — Are uncertainties acknowledged or are claims fabricated?
6. **Completeness** — Are any critical aspects of the requirement left unaddressed?

Scoring guide:
- 90+: Publication-ready
- 80-89: Minor issues, can pass with notes
- 70-79: Needs revision (verdict: revise)
- <70: Significant problems (verdict: reject)

Output (JSON):
{
"overall_score": 85,
"verdict": "pass|revise|reject",
"issues": [
{
"severity": "high|medium|low",
"chapter": "affected chapter",
"dimension": "accuracy|consistency|depth|bilingual|gaps|completeness",
"description": "issue description",
"suggestion": "specific fix suggestion"
}
],
"strengths": ["what the report does well"],
"summary": "Overall assessment (2-3 sentences)"
}"""

    def __init__(self):
        super().__init__(model=settings.model_for_domain("reasoning"))

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Assemble the review prompt and ask the LLM for a JSON review.

        Args:
            context: Must contain ``draft`` and ``research``; an optional,
                non-empty ``draft_translated`` adds a third prompt section.

        Returns:
            ``{"review": <value returned by call_llm_json>}``.

        Raises:
            KeyError: If ``draft`` or ``research`` is missing from ``context``.
        """
        draft = context["draft"]
        draft_translated = context.get("draft_translated", {})
        research = context["research"]

        def as_json(payload: Any) -> str:
            # Keep non-ASCII text readable in the prompt.
            return json.dumps(payload, ensure_ascii=False, indent=2)

        parts = [
            "## Research Plan (what was asked)",
            as_json(research),
            "",
            "## Primary Draft",
            as_json(draft),
        ]
        if draft_translated:
            parts += ["", "## Translated Version", as_json(draft_translated)]

        prompt = "\n".join(parts) + "\n\nReview the report. Output JSON."

        verdict = await self.call_llm_json(prompt, max_tokens=4096)
        return {"review": verdict}
|
||||
86
app/agents/writer.py
Normal file
86
app/agents/writer.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Writer Agent — synthesizes multilingual research tracks into a cohesive report."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseAgent
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class WriterAgent(BaseAgent):
    """Agent that merges all research tracks into one Chinese-language draft.

    It always runs on the model configured for the "reasoning" domain.
    """

    name = "writer"
    description = "汇聚多语言/多领域研究成果,撰写完整报告"

    # NOTE(review): this is a plain (non-f) string, so the doubled braces in
    # rule 6 reach the model literally as "{{CHART:...}}" — confirm that the
    # double-brace marker is intended.
    system_prompt = """\
You are an expert consulting report writer. Your task is to synthesize research
findings from MULTIPLE parallel tracks (some in English, some in Chinese) into
ONE cohesive, professional consulting report.

CRITICAL RULES:
1. The PRIMARY output language is Chinese (中文) — this is for Chinese clients
2. For global/international sections, the analysis depth must reflect the English research
3. For China-specific sections, preserve the precision of Chinese-native research
4. Maintain professional consulting tone throughout
5. Every claim should trace back to a research track's findings
6. Mark chart/table needs: {{CHART:描述}} and {{TABLE:描述}}
7. If a research track flags "data_gaps", acknowledge uncertainty rather than fabricating

Output (JSON):
{
"title": "报告标题(中文)",
"title_en": "Report Title (English)",
"chapters": [
{
"title": "章节标题",
"content": "章节正文(Markdown 格式,中文)",
"source_tracks": ["引用的研究轨道名称"],
"charts": ["图表需求"],
"tables": ["表格需求"]
}
],
"executive_summary": "执行摘要(中文,300-500字)",
"executive_summary_en": "Executive Summary (English, 200-400 words)"
}"""

    def __init__(self):
        super().__init__(model=settings.model_for_domain("reasoning"))

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Build the synthesis prompt from the research tracks and call the LLM.

        Args:
            context: Must contain ``research`` (a dict that may hold
                ``tracks``, ``synthesis_guide``, ``title_zh`` and
                ``title_en``) and ``requirement``. An optional
                ``revision_feedback`` string adds a review-feedback section.

        Returns:
            ``{"draft": <value returned by call_llm_json>}``.

        Raises:
            KeyError: If ``research`` or ``requirement`` is missing.
        """
        research = context["research"]
        requirement = context["requirement"]
        revision_feedback = context.get("revision_feedback", "")

        # One "### [domain] [LANG] track-name" header per track, followed by
        # that track's findings as JSON.
        track_blocks = []
        for track in research.get("tracks", []):
            # `or "?"` also covers an explicit None, which would otherwise
            # break the .upper() call.
            lang_tag = f"[{str(track.get('native_language') or '?').upper()}]"
            domain_tag = f"[{track.get('domain', '?')}]"
            findings = track.get("findings", {})
            track_blocks.append(
                f"\n### {domain_tag} {lang_tag} {track.get('track', '')}\n"
                + json.dumps(findings, ensure_ascii=False, indent=2)
            )
        tracks_text = "".join(track_blocks)

        synthesis_guide = research.get("synthesis_guide", "")

        # Built outside the prompt template so the heading can be followed by
        # a newline; previously the feedback text was glued to the heading.
        feedback_section = (
            f"## 审稿反馈 / Review Feedback\n{revision_feedback}"
            if revision_feedback
            else ""
        )

        prompt = f"""\
## 原始需求 / Original Requirement
{requirement}

## 报告标题
中文:{research.get("title_zh", "")}
English: {research.get("title_en", "")}

## 写作指导 / Synthesis Guide
{synthesis_guide}

## 各研究轨道成果 / Research Track Results
(注意:有些轨道是英文原版 [EN],有些是中文原版 [ZH],请综合使用)
{tracks_text}

{feedback_section}

请汇聚以上研究成果,撰写完整的中文报告。输出 JSON。"""

        result = await self.call_llm_json(prompt, max_tokens=8192)
        return {"draft": result}
|
||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
110
app/api/routes.py
Normal file
110
app/api/routes.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""API routes for report generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from app.pipeline.orchestrator import PipelineOrchestrator
|
||||
from app.pipeline.task import create_report_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
# In-memory store (swap for DB later)
|
||||
reports: dict[str, ReportState] = {}
|
||||
orchestrator = PipelineOrchestrator()
|
||||
|
||||
|
||||
class CreateReportRequest(BaseModel):
    # Request body of POST /api/reports. Every field is forwarded unchanged
    # to create_report_state() by the create_report handler.
    requirement: str  # the only mandatory field
    report_type: str = "行业分析报告"  # default label means "industry analysis report"
    extra_data: str = ""
    output_formats: list[str] = ["docx"]
    client_id: str | None = None
|
||||
|
||||
|
||||
class ReportResponse(BaseModel):
    # Public summary of a report's pipeline state. Each field is copied from
    # the ReportState attribute of the same name by _to_response().
    id: str
    current_node: str
    error: str | None = None  # set when the pipeline recorded a failure
    generated_files: list[str] = []
    node_history: list[dict] = []
    revision_count: int = 0
|
||||
|
||||
|
||||
def _to_response(state: ReportState) -> ReportResponse:
    """Project a pipeline state onto the public ``ReportResponse`` model."""
    # Attributes exposed through the API, in response-field order.
    exposed = (
        "id",
        "current_node",
        "error",
        "generated_files",
        "node_history",
        "revision_count",
    )
    return ReportResponse(**{attr: getattr(state, attr) for attr in exposed})
|
||||
|
||||
|
||||
@router.post("/reports", response_model=ReportResponse)
async def create_report(req: CreateReportRequest):
    """Create a report state, run the whole pipeline, and return its summary.

    The request blocks until the orchestrator finishes. A state that ends
    with ``error`` set is reported as HTTP 500 with that error as detail.
    """
    pending = create_report_state(
        requirement=req.requirement,
        report_type=req.report_type,
        extra_data=req.extra_data,
        output_formats=req.output_formats,
        client_id=req.client_id,
    )
    # Stored before the run so the report is already listed while it executes.
    reports[pending.id] = pending

    # Blocking for now; a task queue is planned.
    finished = await orchestrator.run(pending)
    reports[finished.id] = finished

    if not finished.error:
        return _to_response(finished)
    raise HTTPException(status_code=500, detail=finished.error)
|
||||
|
||||
|
||||
@router.get("/reports/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str):
    """Return the summary of one stored report, or HTTP 404 if it is unknown."""
    if found := reports.get(report_id):
        return _to_response(found)
    raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
||||
|
||||
@router.get("/reports")
async def list_reports():
    """Return the summaries of every report held in the in-memory store."""
    return list(map(_to_response, reports.values()))
|
||||
|
||||
|
||||
@router.get("/reports/{report_id}/detail")
async def get_report_detail(report_id: str):
    """Return the full detail of one report, including draft and review.

    Research results are reduced to description, status value and duration.
    Responds with HTTP 404 when the id is unknown.
    """
    state = reports.get(report_id)
    if not state:
        raise HTTPException(status_code=404, detail="Report not found")

    detail = {
        "id": state.id,
        "requirement": state.requirement,
        "decomposition": state.decomposition,
    }
    detail["research_results"] = [
        {
            "description": item.description,
            "status": item.status.value,
            "duration_ms": item.duration_ms,
        }
        for item in state.research_results
    ]
    detail["draft"] = state.draft
    detail["review"] = state.review
    detail["generated_files"] = state.generated_files
    detail["node_history"] = state.node_history
    return detail
|
||||
57
app/config.py
Normal file
57
app/config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Application configuration — multi-model pool by domain."""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings: a per-domain LLM pool, gateway, server and paths.

    Values are read from the environment and from a UTF-8 ``.env`` file.
    """

    # --- Model pool: one model id per content domain ---
    # English-native global analysis
    llm_global: str = "openai/Claude-3.5-Sonnet"
    # Chinese-native domestic analysis
    llm_china: str = "openai/DeepSeek-R1"
    # Synthesis, review and other cross-domain reasoning
    llm_reasoning: str = "openai/Claude-3.5-Sonnet"
    # Data handling and other fast, structured-output tasks
    llm_fast: str = "openai/Gemini-2.0-Flash"
    # Bidirectional EN<->ZH translation
    llm_translation: str = "openai/Claude-3.5-Sonnet"

    # Used whenever a domain has no dedicated entry above
    llm_model: str = "openai/Claude-3.5-Sonnet"

    # --- LLM gateway (Poe) ---
    llm_api_key: str = ""
    llm_api_base: str = "https://api.poe.com/bot/"

    # --- HTTP server ---
    host: str = "0.0.0.0"
    port: int = 4200

    # --- Filesystem layout, rooted two levels above this file ---
    base_dir: Path = Path(__file__).resolve().parent.parent
    templates_dir: Path = base_dir / "templates"
    output_dir: Path = base_dir / "output"

    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}

    def model_for_domain(self, domain: str) -> str:
        """Return the model id configured for a content domain.

        Known domains:
            global      — international markets, global competition, tech trends
            china       — Chinese market, domestic policy, local competition
            reasoning   — synthesis, review, strategic recommendations
            fast        — data processing, chart generation, structured output
            translation — high-quality EN<->ZH translation

        Any other value falls back to ``llm_model``.
        """
        attribute_by_domain = {
            "global": "llm_global",
            "china": "llm_china",
            "reasoning": "llm_reasoning",
            "fast": "llm_fast",
            "translation": "llm_translation",
        }
        return getattr(self, attribute_by_domain.get(domain, "llm_model"))
|
||||
|
||||
|
||||
settings = Settings()
|
||||
0
app/data/__init__.py
Normal file
0
app/data/__init__.py
Normal file
52
app/data/factory.py
Normal file
52
app/data/factory.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Data router factory — assembles the router with all available sources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from .router import DataRouter
|
||||
from .sources.akshare_source import AKShareSource
|
||||
from .sources.worldbank_source import WorldBankSource
|
||||
from .sources.gpt_researcher_source import GPTResearcherSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_data_router() -> DataRouter:
    """Build a DataRouter with every currently available source registered.

    Priority order (lower = tried first):
        10  AKShare        — free, fast, covers Chinese macro/industry
        20  World Bank     — free, global macro
        90  GPT Researcher — universal fallback (web research)

    Future additions:
        30  CNINFO (巨潮资讯)    — free tier, Chinese public companies
        40  Tianyancha (天眼查) API — paid, company data
        50  Choice API          — paid, financial data
        60  Statista API        — paid, global industry stats
        70  FRED API            — free, US macro
        80  UN Comtrade         — free, global trade
    """
    router = DataRouter()

    # Free sources first, the universal web-research fallback last.
    registrations = (
        (AKShareSource, 10),
        (WorldBankSource, 20),
        (GPTResearcherSource, 90),
    )
    for source_cls, priority in registrations:
        router.register(source_cls(), priority=priority)

    logger.info(f"[data_factory] created router with {len(router.sources)} sources")
    return router
|
||||
|
||||
|
||||
# Singleton
|
||||
_router: DataRouter | None = None
|
||||
|
||||
|
||||
def get_data_router() -> DataRouter:
    """Return the module-level DataRouter, creating it on the first call."""
    global _router
    if _router is not None:
        return _router
    _router = create_data_router()
    return _router
|
||||
82
app/data/router.py
Normal file
82
app/data/router.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""DataRouter — routes data requests to the optimal source.
|
||||
|
||||
Architecture:
|
||||
Agent needs data → DataRouter → best available source → standardized result
|
||||
|
||||
Sources are tried in priority order. If a source fails or is unavailable,
|
||||
the router falls back to the next source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .sources.base import DataSource, DataResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataRouter:
    """Routes data queries to the best available source.

    Sources are registered with a priority (lower = higher priority).
    For each query, the router tries sources in priority order until one
    returns a result whose ``data`` is truthy.
    """

    def __init__(self):
        # (priority, source) pairs, kept sorted by ascending priority.
        self.sources: list[tuple[int, DataSource]] = []

    def register(self, source: DataSource, priority: int = 100) -> "DataRouter":
        """Add a source and keep the list ordered by priority.

        Returns:
            The router itself, so registrations can be chained.
        """
        self.sources.append((priority, source))
        # The sort is stable, so sources with equal priority keep their
        # registration order.
        self.sources.sort(key=lambda x: x[0])
        logger.info(f"[data_router] registered {source.name} (priority={priority})")
        return self

    async def query(
        self,
        query: str,
        data_type: str = "general",
        country: str | None = None,
        **kwargs,
    ) -> DataResult:
        """Query data sources in priority order.

        Args:
            query: Natural language or structured data query
            data_type: One of: macro, industry, company, trade, patent, general
            country: ISO country code (CN, US, etc.) or None for global
            **kwargs: Source-specific parameters

        Returns:
            The first result with truthy ``data``. If no source delivers,
            a ``DataResult`` with ``source="none"``, ``data=None`` and an
            ``error`` that lists why each attempted source was rejected.
        """
        errors = []

        for _priority, source in self.sources:
            if not source.supports(data_type, country):
                continue

            try:
                logger.info(f"[data_router] trying {source.name} for '{query[:50]}...'")
                result = await source.fetch(query, data_type=data_type, country=country, **kwargs)
                if result.data:
                    logger.info(f"[data_router] {source.name} returned {len(str(result.data))} chars")
                    return result
                # The source answered without data. Keep its reason so the
                # final error explains every miss, not only raised exceptions.
                errors.append(f"{source.name}: {result.error or 'no data returned'}")
            except Exception as e:
                errors.append(f"{source.name}: {e}")
                logger.warning(f"[data_router] {source.name} failed: {e}")

        # Nothing delivered: either every attempt failed or no source applied.
        detail = "; ".join(errors) or (
            f"no registered source supports data_type={data_type!r}, country={country!r}"
        )
        return DataResult(
            source="none",
            data=None,
            error=f"All sources failed: {detail}",
        )

    async def query_multiple(
        self,
        queries: list[dict[str, Any]],
    ) -> list[DataResult]:
        """Run several queries concurrently and return results in input order.

        Args:
            queries: Keyword-argument dicts, each passed to ``query``.
        """
        import asyncio

        return await asyncio.gather(*[
            self.query(**q) for q in queries
        ])
|
||||
0
app/data/sources/__init__.py
Normal file
0
app/data/sources/__init__.py
Normal file
94
app/data/sources/akshare_source.py
Normal file
94
app/data/sources/akshare_source.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""AKShare data source — Chinese macro/industry data via open-source Python library.
|
||||
|
||||
Covers: GDP, CPI, PMI, industrial profit, trade balance, and 30+ data categories.
|
||||
All data returned as Pandas DataFrames, converted to dicts for standardization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .base import DataSource, DataResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map common data requests to AKShare function names
|
||||
# Keyword -> AKShare function name. AKShareSource.fetch looks these keywords
# up as substrings of the lower-cased query. Several keys contain other keys
# (e.g. "gdp" is part of "us_gdp" and "global_gdp"), so matching must take
# that overlap into account.
AKSHARE_ENDPOINTS = {
    "gdp": "macro_china_gdp",
    "cpi": "macro_china_cpi_monthly",
    "ppi": "macro_china_ppi",
    "pmi": "macro_china_pmi",
    "industrial_profit": "macro_china_industrial_profit",
    "trade_balance": "macro_china_trade_balance",
    "money_supply": "macro_china_money_supply",
    "fdi": "macro_china_fdi",
    "real_estate": "macro_china_real_estate",
    "retail_sales": "macro_china_consumer_goods_retail",
    "fixed_asset": "macro_china_fai",
    "unemployment": "macro_china_urban_unemployment",
    # US macro
    "us_gdp": "macro_usa_gdp_monthly",
    "us_cpi": "macro_usa_cpi_monthly",
    "us_unemployment": "macro_usa_unemployment_rate",
    # Global
    "global_gdp": "macro_global_gdp",
}
|
||||
|
||||
|
||||
class AKShareSource(DataSource):
    """Data source backed by the optional ``akshare`` package.

    A query is mapped to one AKShare function, which is called without
    arguments; the last rows of the returned frame are packaged as records.
    """

    name = "akshare"
    description = "中国宏观经济/行业数据(免费开源,封装统计局等30+数据源)"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept macro, industry and general queries; ``country`` is ignored."""
        return data_type in ("macro", "industry", "general")

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch tabular data for ``query`` from AKShare.

        Args:
            query: Text scanned for a keyword from ``AKSHARE_ENDPOINTS``.
            data_type: Not used by this source.
            country: Not used by this source.
            **kwargs: ``endpoint`` names the AKShare function explicitly and
                skips keyword matching; ``limit`` (default 20) is the number
                of trailing rows returned.

        Returns:
            A ``DataResult`` whose ``data`` holds ``columns``, ``records``,
            ``total_rows`` and ``returned_rows``. When akshare is missing,
            no endpoint matches, or the call fails, ``error`` is set
            instead. No exception is raised for those cases.
        """
        try:
            import akshare as ak
        except ImportError:
            return DataResult(source=self.name, error="akshare not installed (pip install akshare)")

        # Try to match query to a known endpoint
        endpoint_name = kwargs.get("endpoint")
        if not endpoint_name:
            query_lower = query.lower()
            # Longest keyword first, so a specific key such as "us_gdp" or
            # "global_gdp" is not shadowed by the shorter "gdp" it contains.
            # sorted() is stable, so equal-length keys keep their table order.
            for key in sorted(AKSHARE_ENDPOINTS, key=len, reverse=True):
                if key in query_lower:
                    endpoint_name = AKSHARE_ENDPOINTS[key]
                    break

        if not endpoint_name:
            return DataResult(source=self.name, data=None, error=f"No matching AKShare endpoint for: {query}")

        try:
            func = getattr(ak, endpoint_name, None)
            if not func:
                return DataResult(source=self.name, error=f"AKShare function not found: {endpoint_name}")

            logger.info(f"[akshare] calling ak.{endpoint_name}()")
            # NOTE(review): this is a synchronous call inside a coroutine, so
            # it runs on the event-loop thread until it returns — consider
            # offloading it if queries are meant to overlap.
            df = func()

            # Convert to dict for serialization
            # Take last N rows for recent data
            limit = kwargs.get("limit", 20)
            recent = df.tail(limit)

            return DataResult(
                source=self.name,
                data={
                    "columns": list(recent.columns),
                    "records": recent.to_dict(orient="records"),
                    "total_rows": len(df),
                    "returned_rows": len(recent),
                },
                metadata={
                    "endpoint": endpoint_name,
                    "description": f"AKShare {endpoint_name}",
                    "format": "tabular",
                },
            )
        except Exception as e:
            # Deliberate boundary: any failure is reported through the result.
            return DataResult(source=self.name, error=f"AKShare call failed: {e}")
|
||||
34
app/data/sources/base.py
Normal file
34
app/data/sources/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Base class for data sources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DataResult(BaseModel):
    """Standardized result from any data source."""
    # Name of the producing source; the router uses "none" when nothing delivered.
    source: str = ""
    # Payload; the router treats a falsy value as "no result" and moves on.
    data: Any = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    # metadata includes: unit, time_range, update_date, confidence, etc.
    # Human-readable failure reason, or None.
    error: str | None = None
    cached: bool = False
|
||||
|
||||
|
||||
class DataSource(ABC):
    """Abstract data source.

    Subclasses set ``name`` and ``description`` and implement ``fetch``;
    ``supports`` may be overridden to restrict the handled data types.
    """
    name: str = "base"
    description: str = ""

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Return True if this source can handle this data type / country."""
        # The base implementation accepts every request.
        return True

    @abstractmethod
    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch data for ``query`` and return it as a ``DataResult``.

        ``data_type`` and ``country`` are keyword-only; ``**kwargs`` carries
        source-specific parameters.
        """
        ...
|
||||
61
app/data/sources/gpt_researcher_source.py
Normal file
61
app/data/sources/gpt_researcher_source.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""GPT Researcher MCP — deep web research as fallback for any industry.
|
||||
|
||||
This is the universal fallback: when structured data sources don't have
|
||||
data for a niche/cold industry, deep web research fills the gap.
|
||||
|
||||
Requires GPT Researcher MCP server to be running (already configured in ~/.claude.json).
|
||||
For direct API use, we call the MCP tools via the subprocess approach.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
from .base import DataSource, DataResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPTResearcherSource(DataSource):
    """Catch-all source that answers with a research-request placeholder.

    No web research is performed here: ``fetch`` only echoes the query and
    the requested mode in a structured payload.
    """

    name = "gpt_researcher"
    description = "Deep web research — universal fallback for any industry/topic"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept every data type and country (universal fallback)."""
        return True

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Return a placeholder describing the research that should be run.

        Args:
            query: The research question, echoed back in the payload.
            data_type: Recorded in the result metadata.
            country: Recorded in the result metadata.
            **kwargs: ``mode`` is ``"quick"`` (default) or ``"deep"``.
        """
        mode = kwargs.get("mode", "quick")

        # The real MCP call is expected to happen at the agent level once it
        # is integrated into the pipeline; until then this is a stand-in.
        note = (
            "GPT Researcher MCP is available for deep web research. "
            "Call via MCP tools: deep_research() or quick_search(). "
            "This source returns research-ready queries for MCP integration."
        )
        payload = {
            "query": query,
            "mode": mode,
            "status": "ready",
            "note": note,
        }
        request_info = {
            "type": "mcp_research_request",
            "mode": mode,
            "data_type": data_type,
            "country": country,
        }
        return DataResult(source=self.name, data=payload, metadata=request_info)
|
||||
104
app/data/sources/worldbank_source.py
Normal file
104
app/data/sources/worldbank_source.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""World Bank Open Data — global macro indicators, 217 economies, free API.
|
||||
|
||||
API: https://api.worldbank.org/v2/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import DataSource, DataResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = "https://api.worldbank.org/v2"
|
||||
|
||||
# Common indicators for consulting reports: keyword -> World Bank indicator
# code. WorldBankSource.fetch looks these keywords up as substrings of the
# lower-cased query. "gdp" is contained in "gdp_growth" and "gdp_per_capita",
# so matching must take that overlap into account.
INDICATORS = {
    "gdp": "NY.GDP.MKTP.CD",  # GDP (current US$)
    "gdp_growth": "NY.GDP.MKTP.KD.ZG",  # GDP growth (annual %)
    "gdp_per_capita": "NY.GDP.PCAP.CD",  # GDP per capita
    "population": "SP.POP.TOTL",  # Total population
    "inflation": "FP.CPI.TOTL.ZG",  # Inflation (CPI %)
    "trade_pct_gdp": "NE.TRD.GNFS.ZS",  # Trade (% of GDP)
    "fdi_net": "BX.KLT.DINV.CD.WD",  # FDI net inflows
    "unemployment": "SL.UEM.TOTL.ZS",  # Unemployment (%)
    "exports": "NE.EXP.GNFS.CD",  # Exports
    "imports": "NE.IMP.GNFS.CD",  # Imports
    "r_and_d": "GB.XPD.RSDV.GD.ZS",  # R&D expenditure (% GDP)
    "high_tech_exports": "TX.VAL.TECH.MF.ZS",  # High-tech exports (% manufactured)
}
|
||||
|
||||
|
||||
class WorldBankSource(DataSource):
    """Data source for the World Bank indicator API (``BASE_URL``)."""

    name = "worldbank"
    description = "World Bank Open Data — 1600+ indicators, 217 economies, free"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Accept macro and general queries; ``country`` is ignored."""
        return data_type in ("macro", "general")

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch one indicator time series for a country.

        Args:
            query: Text scanned for a keyword from ``INDICATORS``.
            data_type: Not used by this source.
            country: Country code placed in the URL; ``None`` means "WLD"
                (world aggregate).
            **kwargs: ``indicator`` gives the indicator code explicitly and
                skips keyword matching; ``per_page`` (default 20) is passed
                to the API.

        Returns:
            A ``DataResult`` whose ``data`` holds ``indicator``, ``country``
            and the non-null ``records``. When the API returns nothing
            usable or the request fails, ``data`` is ``None`` and ``error``
            is set. No exception is raised for those cases.
        """
        indicator_code = kwargs.get("indicator")
        if not indicator_code:
            query_lower = query.lower()
            # Longest keyword first, so "gdp_growth" / "gdp_per_capita" are
            # not shadowed by the shorter "gdp" they contain.
            for key in sorted(INDICATORS, key=len, reverse=True):
                if key in query_lower:
                    indicator_code = INDICATORS[key]
                    break

        if not indicator_code:
            # Default to GDP
            indicator_code = INDICATORS["gdp"]

        country_code = country or "WLD"  # WLD = World
        per_page = kwargs.get("per_page", 20)

        url = f"{BASE_URL}/country/{country_code}/indicator/{indicator_code}"
        params = {
            "format": "json",
            "per_page": per_page,
        }

        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(url, params=params)
                resp.raise_for_status()
                data = resp.json()

            # A usable response is a two-element list: [paging info, records].
            if not data or len(data) < 2:
                return DataResult(source=self.name, data=None, error="No data returned")

            metadata_raw = data[0]
            # Guard against a null record list so it is reported as "no data"
            # rather than as a TypeError from the loop below.
            records = data[1] or []

            # Parse into clean format, dropping observations without a value.
            clean_records = [
                {
                    "year": r["date"],
                    "value": r["value"],
                    "country": r["country"]["value"],
                    "indicator": r["indicator"]["value"],
                }
                for r in records
                if r.get("value") is not None
            ]

            if not clean_records:
                # An empty series must not count as a hit: with `data` left
                # empty the router can fall through to the next source.
                return DataResult(source=self.name, data=None, error="No data returned")

            return DataResult(
                source=self.name,
                data={
                    "indicator": indicator_code,
                    "country": country_code,
                    "records": clean_records,
                },
                metadata={
                    "total": metadata_raw.get("total", 0),
                    "indicator_name": clean_records[0]["indicator"],
                    "format": "timeseries",
                },
            )
        except Exception as e:
            # Deliberate boundary: any failure is reported through the result.
            return DataResult(source=self.name, error=f"World Bank API failed: {e}")
|
||||
0
app/graph/__init__.py
Normal file
0
app/graph/__init__.py
Normal file
108
app/graph/builder.py
Normal file
108
app/graph/builder.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Graph builder — assembles nodes into an executable report generation graph.
|
||||
|
||||
No LangGraph dependency. Pure asyncio with a simple node-runner pattern.
|
||||
|
||||
v2: domain-aware, bilingual (translate node added)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
from app.middleware.chain import MiddlewareChain
|
||||
from .state import ReportState, NodeStatus
|
||||
from .nodes import (
|
||||
DecomposeNode,
|
||||
ParallelResearchNode,
|
||||
WriteNode,
|
||||
TranslateNode,
|
||||
DataNode,
|
||||
ReviewNode,
|
||||
FormatNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NodeFn = Callable[[ReportState], Awaitable[ReportState]]
|
||||
|
||||
|
||||
class ReportGraph:
    """Executable graph for report generation.

    Graph structure (v2):
        decompose → parallel_research → write → translate → data → review →
            ├─ pass   → format → END
            └─ revise → write → translate → data → review → ...

    Every node is an async callable ``ReportState -> ReportState``. A run stops
    early when a node raises, or when a node returns after newly setting
    ``state.error``. In both cases the ``after`` middleware still runs and the
    state (with ``error`` populated) is returned instead of an exception.
    """

    class _Halted(Exception):
        """Internal signal: a node returned normally but set ``state.error``."""

        def __init__(self, state: ReportState):
            super().__init__(state.error)
            # Carry the state the node returned so ``run`` keeps the latest one.
            self.state = state

    def __init__(self, middleware: MiddlewareChain | None = None):
        """Create one instance of every node; ``middleware`` is optional."""
        self.middleware = middleware
        self.decompose = DecomposeNode()
        self.parallel_research = ParallelResearchNode()
        self.write = WriteNode()
        self.translate = TranslateNode()
        self.data = DataNode()
        self.review = ReviewNode()
        self.format = FormatNode()

    async def _run_node(self, name: str, node: NodeFn, state: ReportState) -> ReportState:
        """Run a single node with error handling.

        On failure the error is written to ``state.error``, a FAILED entry is
        appended to the node history, the traceback is logged, and the
        exception is re-raised so the caller can stop the pipeline.
        """
        try:
            logger.info(f"[graph] entering node: {name}")
            state = await node(state)
            logger.info(f"[graph] completed node: {name}")
        except Exception as e:
            state.error = f"Node '{name}' failed: {e}"
            state.log_node(name, NodeStatus.FAILED, str(e))
            logger.exception(f"[graph] node '{name}' failed")
            raise
        return state

    async def _step(self, name: str, node: NodeFn, state: ReportState) -> ReportState:
        """Run a node and halt the pipeline if it newly flagged ``state.error``."""
        previous_error = state.error
        state = await self._run_node(name, node, state)
        # Only react to an error this node introduced, not one already present.
        if state.error and state.error != previous_error:
            raise self._Halted(state)
        return state

    async def run(self, state: ReportState) -> ReportState:
        """Execute the full graph and return the final state.

        Never raises for pipeline failures: inspect ``state.error`` on the
        returned state instead.
        """

        # --- Middleware: before ---
        if self.middleware:
            state = await self.middleware.before(state)

        try:
            # 1. Decompose requirement into domain-tagged parallel tracks
            state = await self._step("decompose", self.decompose, state)

            # 2. Run parallel research (each track uses domain-optimal model)
            state = await self._step("parallel_research", self.parallel_research, state)

            # 3-6. Write → Translate → Data → Review (with revision loop)
            while True:
                state = await self._step("write", self.write, state)
                state = await self._step("translate", self.translate, state)
                state = await self._step("data", self.data, state)
                state = await self._step("review", self.review, state)

                verdict = state.review.get("verdict", "pass")
                if verdict == "pass" or state.revision_count >= state.max_revisions:
                    if verdict != "pass":
                        logger.warning(
                            f"[graph] forcing pass after {state.revision_count} revisions"
                        )
                    break

                # Revise — loop back to write
                state.revision_count += 1
                logger.info(
                    f"[graph] revision {state.revision_count}/{state.max_revisions}"
                )

            # 7. Format bilingual output files
            state = await self._step("format", self.format, state)

        except self._Halted as halt:
            # The node already recorded its own failure; just stop the pipeline.
            state = halt.state
            logger.error(f"[graph] halted: {state.error}")
        except Exception as exc:
            # Node failures are recorded and logged in _run_node. Anything else
            # (e.g. a malformed review payload) must not vanish silently.
            if not state.error:
                state.error = f"Graph execution failed: {exc}"
                logger.exception("[graph] unexpected failure between nodes")

        # --- Middleware: after ---
        if self.middleware:
            state = await self.middleware.after(state)

        return state
|
||||
393
app/graph/nodes.py
Normal file
393
app/graph/nodes.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Graph nodes — each node is an async function: ReportState → ReportState.
|
||||
|
||||
Node layout (v2 — domain-aware, bilingual):
|
||||
|
||||
START
|
||||
│
|
||||
▼
|
||||
[decompose] — Lead Agent 分解为并行研究轨道,每轨标注 domain + language
|
||||
│
|
||||
▼
|
||||
[parallel_research] — N 个子 Agent 并行,每个用最适合该领域的模型
|
||||
│ global tracks → Claude/GPT (English)
|
||||
│ china tracks → DeepSeek/Qwen (Chinese)
|
||||
▼
|
||||
[write] — Writer 汇聚 → 生成主语言版本
|
||||
│
|
||||
▼
|
||||
[translate] — 高质量翻译 → 生成另一语言版本
|
||||
│
|
||||
▼
|
||||
[data] — Data Agent 生成图表/表格
|
||||
│
|
||||
▼
|
||||
[review] — Reviewer 审查(双语)
|
||||
│ ├─ pass → [format]
|
||||
│ └─ revise → [write]
|
||||
▼
|
||||
[format] — 输出双语版本文件
|
||||
│
|
||||
▼
|
||||
END
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.agents.base import BaseAgent
|
||||
from app.agents.researcher import ResearcherAgent
|
||||
from app.agents.writer import WriterAgent
|
||||
from app.agents.data_agent import DataAgent
|
||||
from app.agents.reviewer import ReviewerAgent
|
||||
from app.agents.formatter import FormatterAgent
|
||||
from app.config import settings
|
||||
|
||||
from .state import ReportState, SubtaskResult, NodeStatus, ContentDomain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: decompose — Lead Agent decomposes into domain-tagged parallel tracks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DecomposeNode:
    """Lead-agent node: turns the client requirement into research tracks.

    The LLM is asked for a JSON plan whose ``tracks`` each carry a content
    domain and a native language; the parsed plan is stored verbatim in
    ``state.decomposition``.
    """

    # Static instructions for the lead agent (no interpolation needed).
    _SYSTEM_PROMPT = """\
You are a senior consulting partner planning a global industry report.

Your job is to decompose the client's requirement into 2-6 parallel research tracks.

CRITICAL: Each track must be tagged with a content domain and native language:

- domain: "global" → international markets, global competition, technology trends, overseas benchmarks
→ native_language: "en" (English sources are 10-100x richer for global analysis)

- domain: "china" → Chinese domestic market, government policy, local competitors, China-specific data
→ native_language: "zh" (Chinese sources are authoritative for domestic analysis)

The PRINCIPLE: whichever language has the richest professional literature for that topic
should be the native language. The other language version will be translated later.

Output (JSON):
{
"title_en": "English report title",
"title_zh": "中文报告标题",
"report_type": "report type",
"tracks": [
{
"title": "track title (in native language)",
"domain": "global|china",
"native_language": "en|zh",
"focus": "research focus description",
"prompt": "detailed research instructions (MUST be in the native_language)",
"data_needs": ["required data/charts"]
}
],
"synthesis_guide": "How to merge all tracks into a coherent report (bilingual structure notes)",
"methodology": "Analysis methodology"
}"""

    def __init__(self):
        lead = BaseAgent()
        lead.name = "lead"
        lead.model = settings.model_for_domain("reasoning")
        self.agent = lead

    @staticmethod
    def _build_prompt(state: ReportState) -> str:
        """Render the user prompt from the request fields held in *state*."""
        return f"""\
## Client requirement
{state.requirement}

## Report type
{state.report_type}

## Additional data
{state.extra_data or "(none)"}

## Client context
{state.client_context or "(none)"}

Decompose into parallel research tracks with domain and language tags. Output JSON."""

    async def __call__(self, state: ReportState) -> ReportState:
        """Ask the lead agent for a plan and record it on *state*."""
        state.current_node = "decompose"
        state.log_node("decompose", NodeStatus.RUNNING)

        plan = await self.agent.call_llm_json(
            self._build_prompt(state), system=self._SYSTEM_PROMPT
        )
        state.decomposition = plan

        track_count = len(plan.get("tracks", []))
        state.log_node("decompose", NodeStatus.COMPLETED, f"{track_count} tracks")
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: parallel_research — domain-aware parallel execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ParallelResearchNode:
    """Runs research subtasks in parallel, each using the optimal model for its domain.

    Tracks come from ``state.decomposition["tracks"]`` (LLM output, so not
    trusted to be well-formed). Each track is isolated: a malformed or failing
    track yields a FAILED ``SubtaskResult`` and never aborts its siblings.
    """

    MAX_CONCURRENT = 5  # upper bound on simultaneously running tracks

    @staticmethod
    def _parse_domain(raw: Any) -> ContentDomain:
        """Map an LLM-provided domain tag to ``ContentDomain`` (GLOBAL if unknown)."""
        try:
            return ContentDomain(raw)
        except ValueError:
            return ContentDomain.GLOBAL

    async def _run_one(self, track: dict[str, Any]) -> SubtaskResult:
        """Research a single track; always returns a result, never raises."""
        started = datetime.now()

        # A non-object entry cannot be researched; report it as a failed track
        # instead of letting the exception escape through asyncio.gather().
        if not isinstance(track, dict):
            logger.error(f"[parallel_research] malformed track skipped: {track!r}")
            return SubtaskResult(
                description=str(track)[:80],
                status=NodeStatus.FAILED,
                error=f"Malformed track (expected object, got {type(track).__name__})",
                started_at=started,
                completed_at=datetime.now(),
            )

        # Coerce to str so a null/number from the LLM cannot fail validation.
        title = str(track.get("title") or "")
        native_lang = str(track.get("native_language") or "en")
        domain = self._parse_domain(track.get("domain", "global"))

        result = SubtaskResult(
            description=title,
            domain=domain,
            native_language=native_lang,
        )
        result.status = NodeStatus.RUNNING
        result.started_at = started

        try:
            # Select model based on domain
            model = settings.model_for_domain(domain.value)
            agent = ResearcherAgent(model=model, language=native_lang)

            logger.info(
                f"[parallel_research] track '{title}' "
                f"→ domain={domain.value}, lang={native_lang}, model={model}"
            )

            research = await agent.run({
                "requirement": track["prompt"],
                "report_type": track.get("focus", ""),
                "extra_data": "",
            })
            result.content = research.get("research", {})
            result.status = NodeStatus.COMPLETED
        except Exception as e:
            # Best-effort by design: record the failure and let other tracks finish.
            result.error = str(e)
            result.status = NodeStatus.FAILED
            logger.exception(f"Research track '{title}' failed")
        finally:
            result.completed_at = datetime.now()

        return result

    async def __call__(self, state: ReportState) -> ReportState:
        """Run all tracks (bounded by ``MAX_CONCURRENT``) and store their results."""
        state.current_node = "parallel_research"
        state.log_node("parallel_research", NodeStatus.RUNNING)

        tracks = state.decomposition.get("tracks", [])
        if not tracks:
            state.log_node("parallel_research", NodeStatus.FAILED, "no tracks")
            state.error = "Decomposition produced no research tracks"
            return state

        semaphore = asyncio.Semaphore(self.MAX_CONCURRENT)

        async def bounded(track):
            async with semaphore:
                return await self._run_one(track)

        logger.info(f"[parallel_research] launching {len(tracks)} tracks concurrently")
        results = await asyncio.gather(*[bounded(t) for t in tracks])
        state.research_results = list(results)

        succeeded = sum(1 for r in results if r.status == NodeStatus.COMPLETED)
        if succeeded == 0:
            logger.warning("[parallel_research] every research track failed")

        # Summarise which languages were used per domain for the node history.
        domains = {}
        for r in results:
            domains.setdefault(r.domain.value, []).append(r.native_language)
        state.log_node("parallel_research", NodeStatus.COMPLETED,
                       f"{succeeded}/{len(tracks)} ok, domains={domains}")
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: write — synthesize research into primary-language draft
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WriteNode:
    """Synthesizes completed research tracks into the primary-language draft."""

    def __init__(self):
        self.agent = WriterAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Build the writer payload from *state* and store the resulting draft."""
        state.current_node = "write"
        state.log_node("write", NodeStatus.RUNNING)

        # Only tracks that finished successfully feed the writer.
        completed_tracks = [
            {
                "track": item.description,
                "domain": item.domain.value,
                "native_language": item.native_language,
                "findings": item.content,
            }
            for item in state.research_results
            if item.status == NodeStatus.COMPLETED
        ]

        plan = state.decomposition
        guide = plan.get("synthesis_guide", "")

        # On a revision pass, hand the reviewer's issues back to the writer.
        feedback = ""
        if state.revision_count > 0 and state.review:
            header = f"\n\n## Review feedback (revision {state.revision_count})\n"
            bullets = [
                f"- [{issue.get('severity')}] {issue.get('description')} → {issue.get('suggestion')}\n"
                for issue in state.review.get("issues", [])
            ]
            feedback = header + "".join(bullets)

        payload = {
            "requirement": state.requirement,
            "research": {
                "title_en": plan.get("title_en", ""),
                "title_zh": plan.get("title_zh", ""),
                "methodology": plan.get("methodology", ""),
                "tracks": completed_tracks,
                "synthesis_guide": guide,
            },
            "revision_feedback": feedback,
        }
        outcome = await self.agent.run(payload)

        state.draft = outcome.get("draft", {})
        state.log_node("write", NodeStatus.COMPLETED)
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: translate — produce the other language version
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TranslateNode:
    """Translates the draft into the other language version.

    The draft's primary language is detected from its title (falling back to
    the whole draft when the title is missing). The translation target is the
    *other* language, and the node is skipped when that target language is not
    listed in ``state.output_languages``.
    """

    def __init__(self):
        self.agent = BaseAgent()
        self.agent.name = "translator"
        self.agent.model = settings.model_for_domain("translation")

    @staticmethod
    def _contains_cjk(text: str) -> bool:
        """Return True if *text* has a character in the range U+4E00..U+9FFF."""
        return any('\u4e00' <= c <= '\u9fff' for c in text)

    async def __call__(self, state: ReportState) -> ReportState:
        """Translate ``state.draft`` and store the result in ``state.draft_translated``."""
        state.current_node = "translate"
        state.log_node("translate", NodeStatus.RUNNING)

        if not state.draft:
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        draft_json = json.dumps(state.draft, ensure_ascii=False, indent=2)

        # Detect primary language of draft. The title may be absent or not a
        # string (LLM output), so coerce it and fall back to the full draft.
        title = state.draft.get("title") or ""
        if not isinstance(title, str):
            title = str(title)
        sample = title if title.strip() else draft_json
        is_chinese_primary = self._contains_cjk(sample)

        if is_chinese_primary:
            source_lang, target_lang, target_code = "Chinese", "English", "en"
        else:
            source_lang, target_lang, target_code = "English", "Chinese (Simplified)", "zh"

        # Translate only when the caller asked for the target language. The
        # previous check was hard-coded to "en" even for English-primary drafts.
        if target_code not in state.output_languages:
            logger.info(
                f"[translate] target language '{target_code}' not requested; skipping"
            )
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        system = f"""\
You are a world-class {source_lang} → {target_lang} translator specializing in
consulting and business reports.

Translation principles:
1. ACCURACY over fluency — every data point, percentage, and proper noun must be correct
2. Professional terminology — use standard {target_lang} business/industry terms
3. Preserve structure — keep the exact same JSON structure, only translate text values
4. Cultural adaptation — adjust phrasing for the target audience (not word-for-word)
5. Keep {{{{CHART:...}}}} and {{{{TABLE:...}}}} markers, translate their descriptions

Output the translated JSON with the exact same structure."""

        prompt = f"""\
Translate this consulting report from {source_lang} to {target_lang}.

{draft_json}

Output the translated JSON."""

        translated = await self.agent.call_llm_json(prompt, system=system, max_tokens=8192)
        state.draft_translated = translated
        state.log_node("translate", NodeStatus.COMPLETED,
                       f"{source_lang} → {target_lang}")
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: data — generate charts and tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DataNode:
    """Asks the data agent for assets derived from the current draft."""

    def __init__(self):
        self.agent = DataAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the data agent and store its ``data_assets`` on *state*."""
        state.current_node = "data"
        state.log_node("data", NodeStatus.RUNNING)

        payload = {
            "draft": state.draft,
            "extra_data": state.extra_data,
        }
        outcome = await self.agent.run(payload)

        state.data_assets = outcome.get("data_assets", {})
        state.log_node("data", NodeStatus.COMPLETED)
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: review — bilingual quality check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReviewNode:
    """Submits both draft versions to the reviewer agent and records its review."""

    def __init__(self):
        self.agent = ReviewerAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the reviewer and store its ``review`` dict on *state*."""
        state.current_node = "review"
        state.log_node("review", NodeStatus.RUNNING)

        payload = {
            "draft": state.draft,
            "draft_translated": state.draft_translated,
            "research": state.decomposition,
        }
        outcome = await self.agent.run(payload)

        state.review = outcome.get("review", {})
        verdict = state.review.get('verdict', '?')
        state.log_node("review", NodeStatus.COMPLETED, f"verdict={verdict}")
        return state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node: format — render bilingual output files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FormatNode:
    """Renders the primary draft, and the translated one if present, to files."""

    def __init__(self):
        self.agent = FormatterAgent()

    async def _render(self, state: ReportState, draft: dict, folder: str) -> list:
        """Format one draft into ``<output_dir>/<state.id>/<folder>``; return its files."""
        outcome = await self.agent.run({
            "draft": draft,
            "data_assets": state.data_assets,
            "output_dir": str(settings.output_dir / state.id / folder),
            "output_formats": state.output_formats,
        })
        return outcome.get("generated_files", [])

    async def __call__(self, state: ReportState) -> ReportState:
        """Produce all output files and list them in ``state.generated_files``."""
        state.current_node = "format"
        state.log_node("format", NodeStatus.RUNNING)

        produced = []

        # Primary version is always rendered.
        produced.extend(await self._render(state, state.draft, "primary"))

        # Translated version only when a translation exists.
        if state.draft_translated:
            produced.extend(
                await self._render(state, state.draft_translated, "translated")
            )

        state.generated_files = produced
        state.log_node("format", NodeStatus.COMPLETED, f"{len(produced)} files")
        return state
|
||||
104
app/graph/state.py
Normal file
104
app/graph/state.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Report generation graph state — the shared context that flows through all nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NodeStatus(str, Enum):
    """Lifecycle state of a graph node or research subtask.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class ContentDomain(str, Enum):
    """Content domain tag; used to pick the model and language for a task."""

    # International markets, global trends: English-native
    GLOBAL = "global"
    # Chinese market, domestic policy: Chinese-native
    CHINA = "china"
    # Synthesis, review, strategy: strongest reasoning
    REASONING = "reasoning"
    # Data processing, charts: cost-effective
    FAST = "fast"
    # EN<->ZH translation
    TRANSLATION = "translation"
|
||||
|
||||
|
||||
class SubtaskResult(BaseModel):
    """Outcome of one parallel research track."""

    # Short random identifier (first 8 hex chars of a UUID4).
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    description: str = ""
    domain: ContentDomain = ContentDomain.GLOBAL
    # "en" or "zh": the language the track was originally written in.
    native_language: str = "en"
    status: NodeStatus = NodeStatus.PENDING
    content: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_ms(self) -> int | None:
        """Elapsed milliseconds between start and completion, or None if either is unset."""
        if not self.started_at or not self.completed_at:
            return None
        elapsed = self.completed_at - self.started_at
        return int(elapsed.total_seconds() * 1000)
|
||||
|
||||
|
||||
class ReportState(BaseModel):
    """Shared state for one report generation run.

    Every graph node and middleware reads from and writes to this object, so
    it is the single source of truth for the run.
    """

    # --- Identity ---
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
    created_at: datetime = Field(default_factory=datetime.now)

    # --- Input (set once at start) ---
    requirement: str = ""
    report_type: str = "行业分析报告"
    extra_data: str = ""
    output_formats: list[str] = Field(default=["docx"])
    # Both language versions are produced by default.
    output_languages: list[str] = Field(default=["zh", "en"])
    template_name: str | None = None
    # Used for multi-tenant isolation.
    client_id: str | None = None

    # --- Middleware injections ---
    # Injected by ClientContextMiddleware.
    client_context: str = ""
    # Injected by MemoryMiddleware.
    memory_facts: list[str] = Field(default_factory=list)
    # Managed by TokenBudgetMiddleware.
    token_budget: int = 120000

    # --- Lead Agent output ---
    # e.g. {"tracks": [{"title": "Policy environment", "prompt": "Research ..."}, ...]}
    decomposition: dict[str, Any] = Field(default_factory=dict)

    # --- Parallel research results ---
    research_results: list[SubtaskResult] = Field(default_factory=list)

    # --- Writer output ---
    # Primary-language version.
    draft: dict[str, Any] = Field(default_factory=dict)
    # Translated version.
    draft_translated: dict[str, Any] = Field(default_factory=dict)

    # --- Data Agent output ---
    data_assets: dict[str, Any] = Field(default_factory=dict)

    # --- Reviewer output ---
    review: dict[str, Any] = Field(default_factory=dict)
    revision_count: int = 0
    max_revisions: int = 2

    # --- Formatter output ---
    generated_files: list[str] = Field(default_factory=list)

    # --- Execution tracking ---
    current_node: str = ""
    node_history: list[dict[str, Any]] = Field(default_factory=list)
    error: str | None = None

    def log_node(self, node_name: str, status: NodeStatus, detail: str = ""):
        """Append a timestamped status entry for *node_name* to ``node_history``."""
        entry = {
            "node": node_name,
            "status": status.value,
            "detail": detail,
            "timestamp": datetime.now().isoformat(),
        }
        self.node_history.append(entry)
|
||||
48
app/main.py
Normal file
48
app/main.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""FastAPI application entry point."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.api.routes import router
|
||||
from app.config import settings
|
||||
|
||||
# Configure root logging at import time; every module logger created with
# logging.getLogger(__name__) inherits this level and format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)

# The ASGI application object (imported by uvicorn as "app.main:app").
app = FastAPI(
    title="咨询报告 AI 生成系统",
    version="0.1.0",
    description="Multi-agent pipeline for generating consulting reports",
)

# Register the API endpoints declared in app.api.routes.
app.include_router(router)

# Serve generated files
# The output directory is created before mounting; presumably StaticFiles
# validates that the directory exists when it is constructed (TODO confirm).
settings.output_dir.mkdir(parents=True, exist_ok=True)
app.mount("/files", StaticFiles(directory=str(settings.output_dir)), name="files")
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Return a small service descriptor: name, version, status and docs path."""
    service_info = {
        "name": "咨询报告 AI 生成系统",
        "version": "0.1.0",
        "status": "running",
        "docs": "/docs",
    }
    return service_info
|
||||
|
||||
|
||||
# Script entry point: start a uvicorn server for this application.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run directly.
    import uvicorn

    uvicorn.run(
        # Import string rather than the app object.
        "app.main:app",
        host=settings.host,
        port=settings.port,
        # NOTE(review): reload=True looks like a development setting; confirm
        # this entry point is not used for production deployments.
        reload=True,
    )
|
||||
0
app/memory/__init__.py
Normal file
0
app/memory/__init__.py
Normal file
114
app/memory/store.py
Normal file
114
app/memory/store.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""File-based memory store — persistent facts with confidence ranking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_DIR = Path(__file__).resolve().parent.parent.parent / "memory"
|
||||
|
||||
|
||||
class Fact(BaseModel):
    """One remembered statement, with a confidence score used for ranking."""

    # Short random identifier (first 8 hex chars of a UUID4).
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    content: str
    # One of: preference | knowledge | context | behavior | goal
    category: str = "context"
    confidence: float = 0.7
    # Which report/session created this fact.
    source: str = ""
    # ISO-8601 timestamp taken when the fact object is created.
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
class MemoryFile(BaseModel):
    """Contents of one memory file: a single client's (or the global) memory."""

    # "global" denotes the system-wide memory.
    client_id: str = "global"
    preferences: dict[str, str] = Field(default_factory=dict)
    facts: list[Fact] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryStore:
    """Read/write persistent memory as JSON files.

    Storage layout (one file per ``client_id``):
        memory/
        ├── global.json       — system-wide facts (client_id "global")
        └── <client_id>.json  — per-client facts
    """

    def __init__(self):
        MEMORY_DIR.mkdir(parents=True, exist_ok=True)

    def _path(self, client_id: str = "global") -> Path:
        """Return the JSON file path for *client_id*.

        Path separators (both styles) and ``..`` are replaced so an id can
        never address a file outside ``MEMORY_DIR``.
        """
        safe_name = client_id.replace("/", "_").replace("\\", "_").replace("..", "_")
        return MEMORY_DIR / f"{safe_name}.json"

    def load(self, client_id: str = "global") -> MemoryFile:
        """Load the memory for *client_id*.

        Returns an empty ``MemoryFile`` when the file is missing or cannot be
        read or parsed (the failure is logged as a warning).
        """
        path = self._path(client_id)
        if not path.exists():
            return MemoryFile(client_id=client_id)
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return MemoryFile(**data)
        except Exception as e:
            logger.warning(f"[memory] failed to load {path}: {e}")
            return MemoryFile(client_id=client_id)

    def save(self, mem: MemoryFile):
        """Persist *mem* atomically.

        The payload is written to a sibling temp file and then renamed over
        the target. A crash mid-write therefore cannot leave a truncated file
        that ``load`` would treat as empty memory and later overwrite.
        """
        path = self._path(mem.client_id)
        payload = json.dumps(mem.model_dump(), ensure_ascii=False, indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)
        logger.info(f"[memory] saved {len(mem.facts)} facts to {path}")

    def add_fact(
        self,
        content: str,
        client_id: str = "global",
        category: str = "context",
        confidence: float = 0.7,
        source: str = "",
    ) -> Fact:
        """Store a fact, or return the existing one with the same content.

        Content is compared case-insensitively after stripping whitespace. A
        duplicate only has its confidence raised, never lowered.
        """
        mem = self.load(client_id)

        # Deduplicate by content (normalized)
        normalized = content.strip().lower()
        for existing in mem.facts:
            if existing.content.strip().lower() == normalized:
                # Update confidence if higher
                if confidence > existing.confidence:
                    existing.confidence = confidence
                    self.save(mem)
                return existing

        fact = Fact(
            content=content,
            category=category,
            confidence=confidence,
            source=source,
        )
        mem.facts.append(fact)
        self.save(mem)
        return fact

    def get_top_facts(
        self, client_id: str = "global", limit: int = 15
    ) -> list[str]:
        """Get top N facts sorted by confidence, formatted for prompt injection."""
        mem = self.load(client_id)
        sorted_facts = sorted(mem.facts, key=lambda f: f.confidence, reverse=True)
        # Clamp so a negative limit yields nothing instead of "all but the last N".
        return [f.content for f in sorted_facts[:max(limit, 0)]]

    def set_preference(self, key: str, value: str, client_id: str = "global"):
        """Set one preference for *client_id* and persist the change."""
        mem = self.load(client_id)
        mem.preferences[key] = value
        self.save(mem)

    def get_preferences(self, client_id: str = "global") -> dict[str, str]:
        """Return the stored preferences for *client_id* (empty dict if none)."""
        return self.load(client_id).preferences
|
||||
0
app/middleware/__init__.py
Normal file
0
app/middleware/__init__.py
Normal file
22
app/middleware/base.py
Normal file
22
app/middleware/base.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Middleware base class — before/after hooks around graph execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.graph.state import ReportState
|
||||
|
||||
|
||||
class Middleware(ABC):
    """Base class for hooks that wrap graph execution.

    Both hooks default to a pass-through, so a subclass overrides only the
    one(s) it needs.
    """

    name: str = "base"
    enabled: bool = True

    async def before(self, state: ReportState) -> ReportState:
        """Hook invoked ahead of graph execution; the default returns *state* as is."""
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Hook invoked once graph execution is over; the default returns *state* as is."""
        return state
|
||||
35
app/middleware/chain.py
Normal file
35
app/middleware/chain.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Middleware chain — runs ordered middlewares before/after graph execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MiddlewareChain:
    """Ordered middleware list: ``before`` hooks run first-to-last, ``after`` hooks last-to-first."""

    def __init__(self, middlewares: list[Middleware] | None = None):
        # A falsy argument (None or an empty list) starts a fresh list.
        self.middlewares = middlewares if middlewares else []

    def add(self, mw: Middleware) -> "MiddlewareChain":
        """Append *mw* to the chain and return the chain for fluent chaining."""
        self.middlewares.append(mw)
        return self

    async def before(self, state: ReportState) -> ReportState:
        """Thread *state* through every enabled middleware's ``before`` hook."""
        for item in self.middlewares:
            if not item.enabled:
                continue
            logger.debug(f"[middleware:before] {item.name}")
            state = await item.before(state)
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Thread *state* through every enabled middleware's ``after`` hook, in reverse order."""
        for item in reversed(self.middlewares):
            if not item.enabled:
                continue
            logger.debug(f"[middleware:after] {item.name}")
            state = await item.after(state)
        return state
|
||||
56
app/middleware/client_context.py
Normal file
56
app/middleware/client_context.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""ClientContext middleware — injects client/project background into state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Client profiles stored as JSON files
|
||||
CLIENTS_DIR = Path(__file__).resolve().parent.parent.parent / "clients"
|
||||
|
||||
|
||||
class ClientContextMiddleware(Middleware):
    """Loads client profile and injects background context into state.

    Client profiles are JSON files in the `clients/` directory:
        clients/<client_id>.json
        {
            "name": "Some consulting firm",
            "industry": "Finance",
            "preferences": "Prefers concise, data-driven executive summaries",
            "previous_reports": ["report_001", "report_002"]
        }

    Only ``name``, ``industry`` and ``preferences`` are read. Loading is
    best-effort: a missing or unreadable profile leaves the state unchanged.
    """

    name = "client_context"

    async def before(self, state: ReportState) -> ReportState:
        """Populate ``state.client_context`` from the client's profile, if any."""
        if not state.client_id:
            return state

        # client_id arrives with the request, so strip path separators and
        # ".." before building a file path; otherwise an id such as
        # "../../secrets" would read a file outside CLIENTS_DIR.
        safe_id = (
            str(state.client_id)
            .replace("/", "_")
            .replace("\\", "_")
            .replace("..", "_")
        )
        profile_path = CLIENTS_DIR / f"{safe_id}.json"
        if not profile_path.exists():
            logger.info(f"[client_context] no profile for client '{state.client_id}'")
            return state

        try:
            profile = json.loads(profile_path.read_text(encoding="utf-8"))
            parts = []
            # `client_name` avoids shadowing the class attribute `name`.
            if client_name := profile.get("name"):
                parts.append(f"客户:{client_name}")
            if industry := profile.get("industry"):
                parts.append(f"行业:{industry}")
            if prefs := profile.get("preferences"):
                parts.append(f"偏好:{prefs}")
            state.client_context = ";".join(parts)
            logger.info(f"[client_context] loaded profile for '{state.client_id}'")
        except Exception as e:
            # Deliberately non-fatal: the report can be generated without context.
            logger.warning(f"[client_context] failed to load profile: {e}")

        return state
|
||||
61
app/middleware/compliance.py
Normal file
61
app/middleware/compliance.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Compliance middleware — checks for sensitive data leakage in output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns that suggest sensitive data, as (regex, description) pairs.
#
# Numeric patterns are delimited with digit lookarounds instead of `\b`:
# in Python 3 both CJK characters and ASCII letters count as word characters,
# so `\b` never fires between e.g. "电话" and "138..." (or after the "n" of a
# JSON-escaped "\n"), and such numbers would go undetected.
#
# Keep these patterns free of capturing groups so that every match is the
# full matched text.
SENSITIVE_PATTERNS = [
    # 15-18 digits, or 17 digits followed by the "X" check character.
    (r"(?<!\d)(?:\d{17}[\dXx]|\d{15,17})(?!\d)", "可能的身份证号"),
    (r"(?<!\d)\d{16,19}(?!\d)", "可能的银行卡号"),
    (r"(?<!\d)1[3-9]\d{9}(?!\d)", "可能的手机号"),
    (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "邮箱地址"),
    # Case-insensitive so "Password:" / "API_KEY=" are caught as well.
    (r"(?i)(?:密码|password|secret|token|api.?key)\s*[:=]\s*\S+", "可能的凭证信息"),
]
|
||||
|
||||
|
||||
class ComplianceMiddleware(Middleware):
    """Scan report output for sensitive data patterns.

    After graph execution the draft is scanned for PII / credential
    patterns. Findings are logged as warnings and stored on the state; the
    pipeline is not blocked (this can be made blocking later).
    """

    name = "compliance"

    def _scan_text(self, text: str) -> list[dict]:
        """Return one finding per entry of SENSITIVE_PATTERNS that matches.

        Each finding is a dict with:
            "type":    the pattern's description,
            "count":   number of matches in ``text``,
            "samples": up to three matches, each cut to 8 characters + "...".
        """
        findings = []
        for pattern, desc in SENSITIVE_PATTERNS:
            # finditer + group(0) always yields the full matched text, whereas
            # re.findall returns groups/tuples as soon as a pattern contains a
            # capturing group (which would break the slicing below).
            matches = [m.group(0) for m in re.finditer(pattern, text)]
            if matches:
                findings.append({
                    "type": desc,
                    "count": len(matches),
                    # NOTE(review): 8 leading characters still reveal most of
                    # a phone number — consider a stricter mask.
                    "samples": [m[:8] + "..." for m in matches[:3]],
                })
        return findings

    async def after(self, state: ReportState) -> ReportState:
        """Scan ``state.draft`` and record findings in ``state.review``.

        Returns:
            The same ``state``; when anything was found,
            ``state.review["compliance_warnings"]`` holds the findings of
            this scan.
        """
        # Serialise the whole draft so every nested string is scanned.
        # default=str keeps the scan from raising on values json cannot
        # encode natively.
        all_text = json.dumps(state.draft, ensure_ascii=False, default=str)
        findings = self._scan_text(all_text)

        if findings:
            logger.warning(
                "[compliance] found %d sensitive data patterns:", len(findings)
            )
            for finding in findings:
                logger.warning(
                    " - %s: %d occurrences", finding["type"], finding["count"]
                )
            # Store in state for the API to surface. Assign rather than
            # setdefault so a pre-existing entry cannot hide the result of
            # the current scan.
            state.review["compliance_warnings"] = findings
        else:
            logger.info("[compliance] no sensitive data patterns detected")

        return state
|
||||
64
app/middleware/memory.py
Normal file
64
app/middleware/memory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Memory middleware — injects persistent facts into state, saves new facts after."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from app.memory.store import MemoryStore
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryMiddleware(Middleware):
    """Feed stored memory facts into a run and record new ones afterwards.

    ``before`` copies the top facts from the store onto the state;
    ``after`` writes facts describing the finished report back to the store.
    """

    name = "memory"

    def __init__(self):
        self.store = MemoryStore()

    async def before(self, state: ReportState) -> ReportState:
        """Set ``state.memory_facts`` to at most 15 facts, client facts first."""
        global_facts = self.store.get_top_facts("global", limit=10)
        if state.client_id:
            # Client-specific facts are placed ahead of the global ones.
            combined = (
                self.store.get_top_facts(state.client_id, limit=10) + global_facts
            )
        else:
            combined = global_facts

        state.memory_facts = combined[:15]  # cap total
        if combined:
            logger.info(f"[memory] injected {len(state.memory_facts)} facts")
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Persist facts about the generated report; no-op without a draft or on error."""
        if not state.draft or state.error:
            return state

        report_title = state.draft.get("title", "")
        kind = state.report_type
        owner = state.client_id or "global"

        # Record that this report was produced.
        self.store.add_fact(
            content=f"生成过报告:{report_title}(类型:{kind})",
            client_id=owner,
            category="context",
            confidence=0.9,
            source=state.id,
        )

        # Record the decomposition tracks, when there are any.
        tracks = state.decomposition.get("tracks", [])
        if not tracks:
            return state

        joined_titles = "、".join(track.get("title", "") for track in tracks)
        self.store.add_fact(
            content=f"报告《{report_title}》的研究轨道:{joined_titles}",
            client_id=owner,
            category="knowledge",
            confidence=0.6,
            source=state.id,
        )
        return state
|
||||
46
app/middleware/token_budget.py
Normal file
46
app/middleware/token_budget.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""TokenBudget middleware — tracks and limits token usage across the pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.state import ReportState
|
||||
from .base import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenBudgetMiddleware(Middleware):
    """Manage the token budget to prevent context overflow.

    Before: logs the token budget carried by the state.
    After: logs the estimated token usage of the generated draft.
    """

    name = "token_budget"

    # Rough char-to-token ratio for Chinese text
    CHARS_PER_TOKEN = 1.5

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text`` from its character length."""
        return int(len(text) / self.CHARS_PER_TOKEN)

    async def before(self, state: ReportState) -> ReportState:
        """Log the budget; the state is returned unchanged."""
        # Default budget is already set in ReportState (120k)
        # Could adjust based on model selection
        logger.info("[token_budget] budget = %s tokens", state.token_budget)
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Log an estimate of the tokens in the draft's chapters and summary.

        The state is returned unchanged.
        """
        draft = state.draft or {}
        # `or ""` / `or []` tolerate keys that are present but set to None.
        total_chars = sum(
            len(chapter.get("content") or "")
            for chapter in draft.get("chapters") or []
        )
        total_chars += len(draft.get("executive_summary") or "")

        # Convert the character count itself. Passing str(total_chars) to
        # _estimate_tokens would measure the number of digits in the count,
        # not the text.
        estimated = int(total_chars / self.CHARS_PER_TOKEN)
        logger.info(
            "[token_budget] estimated output tokens: %s, budget: %s",
            estimated,
            state.token_budget,
        )
        return state
|
||||
0
app/output/__init__.py
Normal file
0
app/output/__init__.py
Normal file
0
app/pipeline/__init__.py
Normal file
0
app/pipeline/__init__.py
Normal file
43
app/pipeline/orchestrator.py
Normal file
43
app/pipeline/orchestrator.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Pipeline orchestrator — builds the graph + middleware chain and executes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.graph.builder import ReportGraph
|
||||
from app.graph.state import ReportState
|
||||
from app.middleware.chain import MiddlewareChain
|
||||
from app.middleware.client_context import ClientContextMiddleware
|
||||
from app.middleware.token_budget import TokenBudgetMiddleware
|
||||
from app.middleware.compliance import ComplianceMiddleware
|
||||
from app.middleware.memory import MemoryMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_middleware() -> MiddlewareChain:
    """Create the middleware chain, registering stages in execution order."""
    # Order matters: each stage is instantiated and added in this sequence.
    stages = (
        ClientContextMiddleware,  # 1. load client profile
        MemoryMiddleware,         # 2. inject memory facts
        TokenBudgetMiddleware,    # 3. set token budget
        ComplianceMiddleware,     # 4. scan output for PII (after only)
    )
    chain = MiddlewareChain()
    for stage_cls in stages:
        chain.add(stage_cls())
    return chain
|
||||
|
||||
|
||||
class PipelineOrchestrator:
    """Top-level entry point: wires graph and middleware, runs a report end-to-end."""

    def __init__(self):
        self.middleware = build_middleware()
        self.graph = ReportGraph(middleware=self.middleware)

    async def run(self, state: ReportState) -> ReportState:
        """Run the graph on ``state`` and return the resulting state.

        Logs one line when the run starts and one summary line when it ends.
        """
        logger.info(f"[orchestrator] starting report {state.id}")
        result = await self.graph.run(state)

        outcome = "FAILED" if result.error else "OK"
        file_count = len(result.generated_files)
        logger.info(
            f"[orchestrator] finished report {result.id} — "
            f"status={outcome}, "
            f"files={file_count}"
        )
        return result
|
||||
22
app/pipeline/task.py
Normal file
22
app/pipeline/task.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Pipeline task model — thin wrapper, main state is now ReportState."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.graph.state import ReportState
|
||||
|
||||
|
||||
def create_report_state(
    requirement: str,
    report_type: str = "行业分析报告",
    extra_data: str = "",
    output_formats: list[str] | None = None,
    client_id: str | None = None,
) -> ReportState:
    """Build a ReportState from user input.

    When ``output_formats`` is missing or empty, the state is created with
    ``["docx"]`` as its output formats.
    """
    if output_formats:
        formats = output_formats
    else:
        formats = ["docx"]

    fields = {
        "requirement": requirement,
        "report_type": report_type,
        "extra_data": extra_data,
        "output_formats": formats,
        "client_id": client_id,
    }
    return ReportState(**fields)
|
||||
0
app/templates/__init__.py
Normal file
0
app/templates/__init__.py
Normal file
Reference in New Issue
Block a user