From c7533eada2b2b981fc76f84aa5475e9a5656c86e Mon Sep 17 00:00:00 2001 From: kang Date: Sat, 25 Apr 2026 19:25:22 +0800 Subject: [PATCH] init repo --- .gitignore | 24 + .project.json | 17 + AGENTS.md | 21 + BRIEF.md | 123 ++++ CLAUDE.md | 21 + RULES.md | 37 ++ app/__init__.py | 0 app/agents/__init__.py | 15 + app/agents/base.py | 166 ++++++ app/agents/data_agent.py | 78 +++ app/agents/formatter.py | 669 ++++++++++++++++++++++ app/agents/researcher.py | 103 ++++ app/agents/reviewer.py | 79 +++ app/agents/writer.py | 86 +++ app/api/__init__.py | 0 app/api/routes.py | 110 ++++ app/config.py | 57 ++ app/data/__init__.py | 0 app/data/factory.py | 52 ++ app/data/router.py | 82 +++ app/data/sources/__init__.py | 0 app/data/sources/akshare_source.py | 94 +++ app/data/sources/base.py | 34 ++ app/data/sources/gpt_researcher_source.py | 61 ++ app/data/sources/worldbank_source.py | 104 ++++ app/graph/__init__.py | 0 app/graph/builder.py | 108 ++++ app/graph/nodes.py | 393 +++++++++++++ app/graph/state.py | 104 ++++ app/main.py | 48 ++ app/memory/__init__.py | 0 app/memory/store.py | 114 ++++ app/middleware/__init__.py | 0 app/middleware/base.py | 22 + app/middleware/chain.py | 35 ++ app/middleware/client_context.py | 56 ++ app/middleware/compliance.py | 61 ++ app/middleware/memory.py | 64 +++ app/middleware/token_budget.py | 46 ++ app/output/__init__.py | 0 app/pipeline/__init__.py | 0 app/pipeline/orchestrator.py | 43 ++ app/pipeline/task.py | 22 + app/templates/__init__.py | 0 data-sources-research.md | 456 +++++++++++++++ memory/global.json | 22 + project.json | 11 + pyproject.toml | 30 + test_pipeline.py | 64 +++ tests/__init__.py | 0 50 files changed, 3732 insertions(+) create mode 100644 .gitignore create mode 100644 .project.json create mode 100644 AGENTS.md create mode 100644 BRIEF.md create mode 100644 CLAUDE.md create mode 100644 RULES.md create mode 100644 app/__init__.py create mode 100644 app/agents/__init__.py create mode 100644 app/agents/base.py create mode 
100644 app/agents/data_agent.py create mode 100644 app/agents/formatter.py create mode 100644 app/agents/researcher.py create mode 100644 app/agents/reviewer.py create mode 100644 app/agents/writer.py create mode 100644 app/api/__init__.py create mode 100644 app/api/routes.py create mode 100644 app/config.py create mode 100644 app/data/__init__.py create mode 100644 app/data/factory.py create mode 100644 app/data/router.py create mode 100644 app/data/sources/__init__.py create mode 100644 app/data/sources/akshare_source.py create mode 100644 app/data/sources/base.py create mode 100644 app/data/sources/gpt_researcher_source.py create mode 100644 app/data/sources/worldbank_source.py create mode 100644 app/graph/__init__.py create mode 100644 app/graph/builder.py create mode 100644 app/graph/nodes.py create mode 100644 app/graph/state.py create mode 100644 app/main.py create mode 100644 app/memory/__init__.py create mode 100644 app/memory/store.py create mode 100644 app/middleware/__init__.py create mode 100644 app/middleware/base.py create mode 100644 app/middleware/chain.py create mode 100644 app/middleware/client_context.py create mode 100644 app/middleware/compliance.py create mode 100644 app/middleware/memory.py create mode 100644 app/middleware/token_budget.py create mode 100644 app/output/__init__.py create mode 100644 app/pipeline/__init__.py create mode 100644 app/pipeline/orchestrator.py create mode 100644 app/pipeline/task.py create mode 100644 app/templates/__init__.py create mode 100644 data-sources-research.md create mode 100644 memory/global.json create mode 100644 project.json create mode 100644 pyproject.toml create mode 100644 test_pipeline.py create mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1bdd6ff --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# OS +.DS_Store + +# Env +.env +.env.* + +# Python +__pycache__/ +.pytest_cache/ +.mypy_cache/ +.venv/ +venv/ + +# Node +node_modules/ +.next/ 
+dist/ +build/ +.nuxt/ +.output/ + +# Misc +*.log diff --git a/.project.json b/.project.json new file mode 100644 index 0000000..fc3fd43 --- /dev/null +++ b/.project.json @@ -0,0 +1,17 @@ +{ + "created" : "2026-03-27", + "description" : "AI 自动生成咨询报告系统(规划中)", + "kind" : "app", + "name" : "咨询报告生成", + "stack" : [ + "Python", + "FastAPI", + "LiteLLM", + "Pydantic" + ], + "status" : "paused", + "worklog" : { + "auto" : true, + "path" : "\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/Users\/kangwan\/Projects\/business\/20260327-咨询报告生成\/memory\/worklog.json" + } +} diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..9068422 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,21 @@ +# 咨询报告生成 Agent Rules + +## Must Read First + +- `.project.json` 是机器真源:公网链接、快捷登录、凭证引用都以它为准 +- `RULES.md` 是人工规则和部署事实:启动命令、平台、域名、注意事项都写这里 +- 不允许编造不存在的域名、账号、密码;未知就保持空白并明确标记待补充 + +## Deployment Metadata Contract + +- 任何任务只要新增、删除或修改公网地址,必须在同一次任务里更新 `.project.json` +- `urls[]` 推荐显式写 `type`:`app`、`backend`、`docs`、`admin`、`repo` +- 项目专属的网页登录信息,如果允许放进仓库,就写 `.project.json.quick_login` +- 不能直接入库的敏感登录,不要伪造 `quick_login`,改为写 `.project.json.credentials` 引用 +- 数据库密码、API Key、服务器 root 密码,不属于 `quick_login` + +## Completion Gate + +- 部署完成后,不允许在 `.project.json` 缺少最新公网链接的状态下结束任务 +- 部署完成后,必须同步更新 `RULES.md` 的部署事实 +- 如果只更新了代码但没回写部署元数据,这个任务不算完成 diff --git a/BRIEF.md b/BRIEF.md new file mode 100644 index 0000000..b0a79f6 --- /dev/null +++ b/BRIEF.md @@ -0,0 +1,123 @@ +# 咨询报告 AI 生成系统 — 项目简报 + +> 新窗口打开时,让 Claude 先读这个文件 + +## 项目定位 + +为咨询公司构建 **安全、可控** 的行业报告自动生成系统。 +- 输入:客户需求 + 行业数据 + 报告模板 +- 输出:Word/PPT/Excel/PDF 格式的专业咨询报告 +- 安全铁律:**客户数据绝不经过第三方** + +## 架构灵感(来自 Open SWE) + +借鉴 Open SWE 的多 Agent 流水线,但从"写代码"改为"写报告": + +``` +用户输入(报告需求 + 数据) + │ + ├── 
Researcher Agent → 分析需求、检索资料、梳理框架 + │ + ├── Writer Agent → 按模板撰写报告正文 + │ + ├── Data Agent → 处理数据、生成图表、制作附录 + │ + ├── Reviewer Agent → 检查质量、一致性、合规性 + │ + └── Formatter Agent → 排版输出 docx/pptx/xlsx/pdf +``` + +## 可用 Skills(已有) + +| Skill | 路径 | 能力 | +|-------|------|------| +| docx | `~/Projects/code/20260119-skills合集/anthropics_skills/skills/docx/` | Word 文档生成/编辑/批注 | +| pptx | `~/Projects/code/20260119-skills合集/anthropics_skills/skills/pptx/` | PPT 演示文稿生成 | +| xlsx | `~/Projects/code/20260119-skills合集/anthropics_skills/skills/xlsx/` | Excel 数据分析/图表/财务模型 | +| pdf | `~/Projects/code/20260119-skills合集/anthropics_skills/skills/pdf/` | PDF 生成/合并 | + +## Open SWE 安全审查结论(避坑清单) + +从 Open SWE 代码审查中得出的教训,本项目必须避免: + +| Open SWE 的问题 | 本项目的对策 | +|-----------------|-------------| +| 所有内容发送给第三方 LLM (Poe/OpenRouter) | 必须用有 DPA 的 API 或自建模型 | +| LangSmith 遥测发送完整执行数据 | 不用 LangSmith,自建 tracing 或不 trace | +| fetch_url 无 SSRF 防护 | Agent 工具严格限制外网访问 | +| 本地沙箱 inherit_env 暴露密钥 | 隔离执行环境,最小化环境变量 | +| 无多租户隔离 | 每个客户项目独立隔离 | + +## 需要用户提供的材料 + +### 必须提供(不然搭不了) +1. **报告模板** — 你们现在用的 Word/PPT 模板文件(至少 2-3 个不同类型) +2. **样例报告** — 之前交付过的成品报告(脱敏后的)2-3 份 +3. **LLM 选择** — 用哪个模型?选项: + - Poe API(现有,但数据经过 Poe → 有泄露风险) + - 本地部署 LLM(192.168.2.221 Linux 服务器,需要 GPU) + - 有 DPA 的 API(如 Azure OpenAI / AWS Bedrock) + +### 最好提供(能做更好) +4. **行业数据源** — 你们通常从哪里获取行业数据?(数据库/网站/内部资料库) +5. **报告类型清单** — 你们做哪几种报告?(市场分析/竞品分析/尽职调查/投资备忘录...) +6. **品牌规范** — 公司 logo、配色、字体规范(用于排版) +7. **审批流程** — 报告生成后谁审核?需要什么审批机制? + +### 可选 +8. **客户列表结构** — 多租户需要支持多少客户?权限怎么分? +9. **部署偏好** — 部署在 VPS (76.13.31.179) 还是本地服务器 (192.168.2.221)? 
+ +## 技术方案初步思路 + +``` +┌─────────────────────────────────────────────┐ +│ Web UI(后续做) │ +│ 用户输入需求 / 上传数据 / 下载报告 │ +└──────────────────┬──────────────────────────┘ + │ +┌──────────────────▼──────────────────────────┐ +│ 调度引擎 (Python/FastAPI) │ +│ 接收任务 → 编排 Agent → 输出报告 │ +│ - 多租户隔离 │ +│ - 任务队列 │ +│ - 模板管理 │ +└──────────────────┬──────────────────────────┘ + │ + ┌─────────────┼─────────────┐ + ▼ ▼ ▼ +┌─────────┐ ┌──────────┐ ┌──────────┐ +│Researcher│ │ Writer │ │ Data │ +│ Agent │ │ Agent │ │ Agent │ +│ 检索分析 │ │ 撰写正文 │ │ 图表数据 │ +└────┬─────┘ └────┬─────┘ └────┬─────┘ + └─────────────┼─────────────┘ + ▼ + ┌─────────────────┐ + │ Reviewer Agent │ + │ 质量审查 │ + └────────┬────────┘ + ▼ + ┌─────────────────┐ + │ Formatter Agent │ + │ docx/pptx/pdf │ + │ (用现有 Skills) │ + └─────────────────┘ +``` + +## 相关项目参考 + +- Open SWE 源码:`~/Projects/research/20260327-open-swe/source/` +- Open SWE 安全审查:在本项目创建对话中完成 +- Skills 合集:`~/Projects/code/20260119-skills合集/` +- GPT Researcher MCP:已配置在 `~/.claude.json`(可用于资料检索) + +## 状态 + +- [x] 需求明确:咨询公司内部效率工具 + 客户增值服务 +- [x] 安全需求明确:客户数据不得外泄 +- [x] 可用 Skills 盘点完成 +- [ ] **等待用户提供:报告模板 + 样例报告 + LLM 选择** +- [ ] 架构设计 +- [ ] 开发 +- [ ] 部署 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9068422 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,21 @@ +# 咨询报告生成 Agent Rules + +## Must Read First + +- `.project.json` 是机器真源:公网链接、快捷登录、凭证引用都以它为准 +- `RULES.md` 是人工规则和部署事实:启动命令、平台、域名、注意事项都写这里 +- 不允许编造不存在的域名、账号、密码;未知就保持空白并明确标记待补充 + +## Deployment Metadata Contract + +- 任何任务只要新增、删除或修改公网地址,必须在同一次任务里更新 `.project.json` +- `urls[]` 推荐显式写 `type`:`app`、`backend`、`docs`、`admin`、`repo` +- 项目专属的网页登录信息,如果允许放进仓库,就写 `.project.json.quick_login` +- 不能直接入库的敏感登录,不要伪造 `quick_login`,改为写 `.project.json.credentials` 引用 +- 数据库密码、API Key、服务器 root 密码,不属于 `quick_login` + +## Completion Gate + +- 部署完成后,不允许在 `.project.json` 缺少最新公网链接的状态下结束任务 +- 部署完成后,必须同步更新 `RULES.md` 的部署事实 +- 如果只更新了代码但没回写部署元数据,这个任务不算完成 diff --git a/RULES.md b/RULES.md new file mode 100644 index 0000000..ee3fb74 
--- /dev/null +++ b/RULES.md @@ -0,0 +1,37 @@ +# 咨询报告生成 + +## 启动 +- `待补充` + +## 部署事实 +- 平台:待定 +- 发布状态:未部署 +- 主站 / 前端:待定 +- API / 后端:待定 +- 文档 / 解析:待定 +- 管理后台:待定 +- 代码仓:待定 + +## 快捷登录 +- 登录地址:待补充 +- 用户名:待补充 +- 密码:待补充 +- 说明:这里只写项目专属网页登录;数据库密码、API Key、服务器 root 密码不要写这里 + +## 元数据回写清单 +- 新增或变更公网地址后,必须同步更新 `.project.json.urls` +- 如果有网页后台登录: + - 可直接入库:写 `.project.json.quick_login` + - 不应入库:写 `.project.json.credentials` 引用 +- 部署完成后,`RULES.md` 和 `.project.json` 必须同一次任务一起更新 + +## 环境变量 +- 待补充 + +## 规则 +- 不允许编造不存在的部署域名、账号、密码 +- 没有公网地址时,`.project.json.urls` 保持空数组 +- 任何部署或域名变化,都要先改元数据,再视为任务完成 + +## 注意事项 +- 待补充 diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/agents/__init__.py b/app/agents/__init__.py new file mode 100644 index 0000000..1a15946 --- /dev/null +++ b/app/agents/__init__.py @@ -0,0 +1,15 @@ +from .base import BaseAgent +from .researcher import ResearcherAgent +from .writer import WriterAgent +from .data_agent import DataAgent +from .reviewer import ReviewerAgent +from .formatter import FormatterAgent + +__all__ = [ + "BaseAgent", + "ResearcherAgent", + "WriterAgent", + "DataAgent", + "ReviewerAgent", + "FormatterAgent", +] diff --git a/app/agents/base.py b/app/agents/base.py new file mode 100644 index 0000000..44bd134 --- /dev/null +++ b/app/agents/base.py @@ -0,0 +1,166 @@ +"""Base agent with LLM calling via litellm.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import litellm + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Disable litellm telemetry +litellm.telemetry = False + + +class BaseAgent: + """Base class for all pipeline agents.""" + + name: str = "base" + description: str = "" + system_prompt: str = "" + model: str = "" # empty = use default from config + + def __init__(self, model: str | None = None): + if model: + self.model = model + + def get_model(self) -> str: + return self.model or settings.llm_model + + async 
def call_llm( + self, + prompt: str, + *, + system: str | None = None, + temperature: float = 0.3, + max_tokens: int = 4096, + response_format: dict | None = None, + ) -> str: + """Call LLM via litellm. Returns the text response.""" + messages = [] + sys_prompt = system or self.system_prompt + if sys_prompt: + messages.append({"role": "system", "content": sys_prompt}) + messages.append({"role": "user", "content": prompt}) + + kwargs: dict[str, Any] = { + "model": self.get_model(), + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + if settings.llm_api_key: + kwargs["api_key"] = settings.llm_api_key + if settings.llm_api_base: + kwargs["api_base"] = settings.llm_api_base + if response_format: + kwargs["response_format"] = response_format + + logger.info(f"[{self.name}] calling {self.get_model()}") + response = await litellm.acompletion(**kwargs) + content = response.choices[0].message.content + logger.info(f"[{self.name}] got {len(content)} chars") + return content + + async def call_llm_json(self, prompt: str, **kwargs) -> dict: + """Call LLM and parse response as JSON.""" + raw = await self.call_llm( + prompt, + response_format={"type": "json_object"}, + **kwargs, + ) + # Strip markdown code fences if present + text = raw.strip() + if text.startswith("```"): + first_nl = text.find("\n") + if first_nl != -1: + text = text[first_nl + 1:] + if text.endswith("```"): + text = text[: text.rfind("```")] + text = text.strip() + + # Sanitize control characters inside JSON string values + # (models sometimes emit literal newlines/tabs inside strings) + import re + def _clean_json_string(s: str) -> str: + # Replace unescaped control chars within JSON strings + # This is a best-effort fix for common model outputs + result = [] + in_string = False + escape = False + for ch in s: + if escape: + result.append(ch) + escape = False + continue + if ch == '\\': + result.append(ch) + escape = True + continue + if ch == '"': + in_string = not 
in_string + result.append(ch) + continue + if in_string and ord(ch) < 32: + # Replace control chars with escaped versions + if ch == '\n': + result.append('\\n') + elif ch == '\r': + result.append('\\r') + elif ch == '\t': + result.append('\\t') + else: + result.append(f'\\u{ord(ch):04x}') + continue + result.append(ch) + return ''.join(result) + + # Try parsing with multiple strategies + for attempt, candidate in enumerate([text, _clean_json_string(text)]): + try: + return json.loads(candidate) + except json.JSONDecodeError: + continue + + # Last resort: try to extract the largest valid JSON object + # (model may have appended commentary after the JSON) + brace_depth = 0 + start = text.find('{') + if start == -1: + raise json.JSONDecodeError("No JSON object found", text, 0) + + cleaned = _clean_json_string(text) + for i, ch in enumerate(cleaned[start:], start): + if ch == '{': + brace_depth += 1 + elif ch == '}': + brace_depth -= 1 + if brace_depth == 0: + try: + return json.loads(cleaned[start:i + 1]) + except json.JSONDecodeError: + continue + + # If all else fails, use json_repair library or raise + try: + import json_repair + return json_repair.loads(text) + except (ImportError, Exception): + raise json.JSONDecodeError( + f"Failed to parse JSON after multiple attempts", text, 0 + ) + + async def run(self, context: dict[str, Any]) -> dict[str, Any]: + """Execute this agent's task. Override in subclasses. + + Args: + context: Shared pipeline context (accumulated by previous agents). + + Returns: + Dict of new keys to merge into context. 
+ """ + raise NotImplementedError diff --git a/app/agents/data_agent.py b/app/agents/data_agent.py new file mode 100644 index 0000000..d721e54 --- /dev/null +++ b/app/agents/data_agent.py @@ -0,0 +1,78 @@ +"""Data Agent — processes data, generates chart specs and table data.""" + +from __future__ import annotations + +import json +from typing import Any + +from .base import BaseAgent +from app.config import settings + + +class DataAgent(BaseAgent): + name = "data" + description = "处理数据、生成图表规格和表格数据" + system_prompt = """\ +你是一位数据分析专家。你的任务是根据报告草稿中标注的图表和表格需求, +生成具体的数据和图表规格。 + +输出要求(JSON 格式): +{ + "charts": [ + { + "id": "chart_1", + "title": "图表标题", + "type": "bar|line|pie|area|scatter", + "description": "图表说明", + "data": { + "labels": ["标签1", "标签2"], + "datasets": [ + {"label": "数据集名", "data": [100, 200]} + ] + } + } + ], + "tables": [ + { + "id": "table_1", + "title": "表格标题", + "headers": ["列1", "列2", "列3"], + "rows": [["数据1", "数据2", "数据3"]] + } + ] +}""" + + def __init__(self): + super().__init__(model=settings.model_for_domain("fast")) + + async def run(self, context: dict[str, Any]) -> dict[str, Any]: + draft = context["draft"] + extra_data = context.get("extra_data", "") + + # Collect chart/table needs from draft + chart_needs = [] + table_needs = [] + for ch in draft.get("chapters", []): + chart_needs.extend(ch.get("charts", [])) + table_needs.extend(ch.get("tables", [])) + + if not chart_needs and not table_needs: + return {"data_assets": {"charts": [], "tables": []}} + + prompt = f"""\ +## 报告标题 +{draft.get("title", "")} + +## 需要生成的图表 +{json.dumps(chart_needs, ensure_ascii=False)} + +## 需要生成的表格 +{json.dumps(table_needs, ensure_ascii=False)} + +## 补充数据源 +{extra_data if extra_data else "(无额外数据,请根据行业常识生成合理的示例数据)"} + +请为以上需求生成具体的图表规格和表格数据。输出 JSON。""" + + result = await self.call_llm_json(prompt) + return {"data_assets": result} diff --git a/app/agents/formatter.py b/app/agents/formatter.py new file mode 100644 index 0000000..bab3efc --- /dev/null +++ 
def _skills_available() -> dict[str, bool]:
    """Report which optional skill toolkits exist on disk.

    Returns:
        Mapping from skill key to whether its marker file or directory
        is present under the corresponding skills folder.
    """
    docx_unpack = DOCX_SKILLS / "ooxml" / "scripts" / "unpack.py"
    pptx_unpack = PPTX_SKILLS / "ooxml" / "scripts" / "unpack.py"
    html2pptx_js = PPTX_SKILLS / "scripts" / "html2pptx.js"

    availability: dict[str, bool] = {}
    availability["docx_js"] = (DOCX_SKILLS / "docx-js.md").exists()
    availability["html2pptx"] = html2pptx_js.exists()
    availability["recalc"] = (XLSX_SKILLS / "recalc.py").exists()
    availability["ooxml_docx"] = docx_unpack.exists()
    availability["ooxml_pptx"] = pptx_unpack.exists()
    # The PDF skill is detected by its scripts directory, not a single file.
    availability["pdf_scripts"] = (PDF_SKILLS / "scripts").is_dir()
    return availability
async def _render_docx_baseline(
    self, draft: dict, data_assets: dict, output_dir: Path, title: str
) -> Path:
    """Render the report to ``<output_dir>/<title>.docx`` with python-docx.

    Layout: title, optional executive summary, one level-1 heading per
    chapter (body rendered via ``_docx_render_markdown``), then data
    tables, then one section per chart. Charts are not drawn: each gets a
    textual placeholder plus its datasets as tables.

    Args:
        draft: Report draft; reads ``executive_summary`` and ``chapters``.
        data_assets: Reads ``tables`` and ``charts``.
        output_dir: Directory the file is written to.
        title: Report title, also used as the file stem.

    Returns:
        Path of the saved document.
    """
    from docx import Document
    from docx.enum.text import WD_ALIGN_PARAGRAPH
    from docx.oxml.ns import qn
    from docx.shared import Pt, RGBColor

    doc = Document()

    # -- Styles --
    normal = doc.styles["Normal"]
    normal.font.name = "微软雅黑"
    normal.font.size = Pt(11)
    # ``font.name`` only sets the Latin font slots. CJK text takes its
    # typeface from the East Asian slot, which must be set explicitly or
    # Chinese text falls back to the default font.
    normal.element.rPr.rFonts.set(qn("w:eastAsia"), "微软雅黑")

    # Title
    heading = doc.add_heading(title, level=0)
    heading.alignment = WD_ALIGN_PARAGRAPH.CENTER

    # Executive summary
    if summary := draft.get("executive_summary"):
        doc.add_heading("执行摘要", level=1)
        para = doc.add_paragraph()
        run = para.add_run(summary)
        run.font.size = Pt(11)
        run.font.color.rgb = RGBColor(0x33, 0x33, 0x33)

    # Chapters
    for chapter in draft.get("chapters", []):
        # A chapter without a title must not abort the whole render.
        doc.add_heading(chapter.get("title", ""), level=1)
        self._docx_render_markdown(doc, chapter.get("content", ""))

    # Tables from data assets
    for table_spec in data_assets.get("tables", []):
        doc.add_heading(table_spec.get("title", "数据表"), level=2)
        self._docx_add_table(doc, table_spec)

    # Chart descriptions as placeholders (no page break is inserted)
    for chart_spec in data_assets.get("charts", []):
        doc.add_heading(chart_spec.get("title", "图表"), level=2)
        desc = chart_spec.get("description", "")
        chart_type = chart_spec.get("type", "")
        doc.add_paragraph(f"[{chart_type.upper()} 图表] {desc}")
        # Render chart data as a table too, one table per dataset.
        chart_data = chart_spec.get("data", {})
        if labels := chart_data.get("labels"):
            for dataset in chart_data.get("datasets", []):
                values = dataset.get("data", [])
                self._docx_add_table(doc, {
                    "headers": ["项目", dataset.get("label", "数据")],
                    "rows": [
                        [str(label), str(value)]
                        for label, value in zip(labels, values)
                    ],
                })

    path = output_dir / f"{title}.docx"
    doc.save(str(path))
    return path
falling back to baseline") + return await self._render_docx_baseline(draft, data_assets, output_dir, title) + + # TODO: edit XML content in work_dir based on draft + # For now, just pack back as-is (template passthrough) + output_path = output_dir / f"{title}.docx" + proc = await asyncio.create_subprocess_exec( + "python3", str(pack_script), str(work_dir), str(output_path), + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + ) + await proc.wait() + return output_path + + def _docx_render_markdown(self, doc, content: str): + """Convert markdown-ish content to docx paragraphs.""" + from docx.shared import Pt + + for block in content.split("\n\n"): + block = block.strip() + if not block: + continue + if block.startswith("#### "): + doc.add_heading(block[5:], level=4) + elif block.startswith("### "): + doc.add_heading(block[4:], level=3) + elif block.startswith("## "): + doc.add_heading(block[3:], level=2) + elif block.startswith("- ") or block.startswith("* "): + # Bullet list + for line in block.split("\n"): + line = line.lstrip("- *").strip() + if line: + doc.add_paragraph(line, style="List Bullet") + elif block.startswith("1. 
") or block.startswith("1)"): + # Numbered list + for line in block.split("\n"): + text = line.lstrip("0123456789.)) ").strip() + if text: + doc.add_paragraph(text, style="List Number") + else: + p = doc.add_paragraph(block) + for run in p.runs: + run.font.size = Pt(11) + + def _docx_add_table(self, doc, table_spec: dict): + """Add a formatted table to the document.""" + from docx.shared import Pt, RGBColor + from docx.oxml.ns import qn + + headers = table_spec.get("headers", []) + rows = table_spec.get("rows", []) + if not headers: + return + + tbl = doc.add_table(rows=1 + len(rows), cols=len(headers)) + tbl.style = "Light Grid Accent 1" + + # Header row + for i, h in enumerate(headers): + cell = tbl.rows[0].cells[i] + cell.text = str(h) + for p in cell.paragraphs: + for run in p.runs: + run.font.bold = True + run.font.size = Pt(10) + + # Data rows + for r_idx, row in enumerate(rows): + for c_idx, cell_val in enumerate(row): + tbl.rows[r_idx + 1].cells[c_idx].text = str(cell_val) + + # ----------------------------------------------------------------------- + # PPTX — html2pptx.js (rich) or python-pptx (fallback) + # ----------------------------------------------------------------------- + + async def _render_pptx( + self, draft: dict, data_assets: dict, output_dir: Path, title: str + ) -> Path: + if self.skills["html2pptx"]: + try: + return await self._render_pptx_html2pptx(draft, data_assets, output_dir, title) + except Exception as e: + logger.warning(f"[formatter] html2pptx failed ({e}), falling back to python-pptx") + return await self._render_pptx_baseline(draft, data_assets, output_dir, title) + + async def _render_pptx_html2pptx( + self, draft: dict, data_assets: dict, output_dir: Path, title: str + ) -> Path: + """Generate PPTX using html2pptx.js skill for visual slides.""" + with tempfile.TemporaryDirectory() as tmpdir: + work = Path(tmpdir) + + # Generate HTML slides + slides_html = [] + # Title slide + slides_html.append(f""" +

{title}

+

{draft.get('executive_summary', '')[:100]}

+""") + + # Chapter slides + for ch in draft.get("chapters", []): + content_lines = ch.get("content", "")[:400].split("\n") + bullets = "".join(f"
  • {l.strip()}
  • " for l in content_lines if l.strip()) + slides_html.append(f""" +

    {ch['title']}

    + +""") + + # Write HTML files + for i, html in enumerate(slides_html): + (work / f"slide_{i}.html").write_text(html, encoding="utf-8") + + # Write conversion script + script = work / "convert.js" + html2pptx_path = PPTX_SKILLS / "scripts" / "html2pptx.js" + slide_files = [f"slide_{i}.html" for i in range(len(slides_html))] + + script.write_text(f"""\ +const pptxgen = require('pptxgenjs'); +const {{ html2pptx }} = require('{html2pptx_path}'); +const path = require('path'); + +async function main() {{ + const pptx = new pptxgen(); + pptx.layout = 'LAYOUT_16x9'; + const files = {json.dumps(slide_files)}; + for (const f of files) {{ + await html2pptx(path.join('{work}', f), pptx); + }} + await pptx.writeFile({{ fileName: '{output_dir / f"{title}.pptx"}' }}); +}} +main().catch(e => {{ console.error(e); process.exit(1); }}); +""", encoding="utf-8") + + proc = await asyncio.create_subprocess_exec( + "node", str(script), + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + cwd=str(work), + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"html2pptx failed: {stderr.decode()}") + + return output_dir / f"{title}.pptx" + + async def _render_pptx_baseline( + self, draft: dict, data_assets: dict, output_dir: Path, title: str + ) -> Path: + from pptx import Presentation + from pptx.util import Inches, Pt + from pptx.dml.color import RGBColor + + prs = Presentation() + + # Title slide + slide = prs.slides.add_slide(prs.slide_layouts[0]) + slide.shapes.title.text = title + if len(slide.placeholders) > 1: + slide.placeholders[1].text = draft.get("executive_summary", "")[:200] + + # Chapter slides + for chapter in draft.get("chapters", []): + slide = prs.slides.add_slide(prs.slide_layouts[1]) + slide.shapes.title.text = chapter["title"] + body = slide.placeholders[1] + tf = body.text_frame + tf.clear() + + content = chapter.get("content", "") + lines = [l.strip() for l in content.split("\n") if l.strip()] + for line in lines[:12]: # 
max 12 bullets per slide + p = tf.add_paragraph() + # Strip markdown markers + clean = line.lstrip("#-*0123456789.) ").strip() + p.text = clean + p.font.size = Pt(14) + p.space_after = Pt(4) + + # Data table slides + for table_spec in data_assets.get("tables", []): + slide = prs.slides.add_slide(prs.slide_layouts[5]) # blank layout + slide.shapes.title.text = table_spec.get("title", "数据表") + + headers = table_spec.get("headers", []) + rows = table_spec.get("rows", []) + if headers and rows: + n_rows = min(len(rows) + 1, 10) # limit rows per slide + n_cols = len(headers) + tbl = slide.shapes.add_table( + n_rows, n_cols, + Inches(0.5), Inches(1.5), Inches(9), Inches(4.5) + ).table + + for i, h in enumerate(headers): + tbl.cell(0, i).text = str(h) + for r_idx, row in enumerate(rows[:n_rows - 1]): + for c_idx, val in enumerate(row[:n_cols]): + tbl.cell(r_idx + 1, c_idx).text = str(val) + + path = output_dir / f"{title}.pptx" + prs.save(str(path)) + return path + + # ----------------------------------------------------------------------- + # XLSX — openpyxl + recalc.py (formula recalculation) + # ----------------------------------------------------------------------- + + async def _render_xlsx( + self, data_assets: dict, output_dir: Path, title: str + ) -> Path: + from openpyxl import Workbook + from openpyxl.styles import Font, PatternFill, Alignment, Border, Side + from openpyxl.utils import get_column_letter + + wb = Workbook() + ws = wb.active + ws.title = "数据总览" + + # Professional styling + header_font = Font(bold=True, size=11, color="FFFFFF") + header_fill = PatternFill(start_color="1A1A2E", end_color="1A1A2E", fill_type="solid") + title_font = Font(bold=True, size=14, color="1A1A2E") + thin_border = Border( + left=Side(style="thin", color="CCCCCC"), + right=Side(style="thin", color="CCCCCC"), + top=Side(style="thin", color="CCCCCC"), + bottom=Side(style="thin", color="CCCCCC"), + ) + + current_row = 1 + has_formulas = False + + for table_spec in 
data_assets.get("tables", []): + # Table title + ws.cell(row=current_row, column=1, value=table_spec.get("title", "")).font = title_font + current_row += 1 + + headers = table_spec.get("headers", []) + rows = table_spec.get("rows", []) + + if headers: + # Header row with styling + for col_idx, h in enumerate(headers, 1): + cell = ws.cell(row=current_row, column=col_idx, value=h) + cell.font = header_font + cell.fill = header_fill + cell.alignment = Alignment(horizontal="center") + cell.border = thin_border + current_row += 1 + + # Data rows + data_start = current_row + for row_data in rows: + for col_idx, val in enumerate(row_data, 1): + cell = ws.cell(row=current_row, column=col_idx, value=val) + cell.border = thin_border + # Try to convert numeric strings + if isinstance(val, str): + try: + cell.value = float(val.replace(",", "")) + except (ValueError, AttributeError): + pass + current_row += 1 + + # Auto-sum row for numeric columns + data_end = current_row - 1 + if data_end > data_start: + for col_idx in range(1, len(headers) + 1): + col_letter = get_column_letter(col_idx) + test_cell = ws.cell(row=data_start, column=col_idx) + if isinstance(test_cell.value, (int, float)): + cell = ws.cell( + row=current_row, column=col_idx, + value=f"=SUM({col_letter}{data_start}:{col_letter}{data_end})" + ) + cell.font = Font(bold=True) + cell.border = thin_border + has_formulas = True + elif col_idx == 1: + cell = ws.cell(row=current_row, column=1, value="合计") + cell.font = Font(bold=True) + cell.border = thin_border + current_row += 1 + + # Auto-fit column widths + for col_idx in range(1, len(headers) + 1): + max_len = max( + len(str(ws.cell(row=r, column=col_idx).value or "")) + for r in range(current_row - len(rows) - 2, current_row) + ) + ws.column_dimensions[get_column_letter(col_idx)].width = min(max_len + 4, 30) + + current_row += 2 # gap between tables + + # Chart data sheets + for chart_spec in data_assets.get("charts", []): + chart_ws = 
wb.create_sheet(title=chart_spec.get("title", "图表")[:31]) + chart_ws.cell(row=1, column=1, value=chart_spec.get("title", "")).font = title_font + chart_data = chart_spec.get("data", {}) + labels = chart_data.get("labels", []) + datasets = chart_data.get("datasets", []) + + # Headers: [项目, 数据集1, 数据集2, ...] + chart_ws.cell(row=2, column=1, value="项目").font = Font(bold=True) + for ds_idx, ds in enumerate(datasets, 2): + chart_ws.cell(row=2, column=ds_idx, value=ds.get("label", "")).font = Font(bold=True) + + for r_idx, label in enumerate(labels, 3): + chart_ws.cell(row=r_idx, column=1, value=label) + for ds_idx, ds in enumerate(datasets, 2): + data = ds.get("data", []) + if r_idx - 3 < len(data): + chart_ws.cell(row=r_idx, column=ds_idx, value=data[r_idx - 3]) + + path = output_dir / f"{title}.xlsx" + wb.save(str(path)) + + # Run recalc.py if we have formulas and the skill is available + if has_formulas and self.skills["recalc"]: + await self._xlsx_recalc(path) + + return path + + async def _xlsx_recalc(self, path: Path): + """Recalculate formulas using Skills recalc.py (requires LibreOffice).""" + recalc_script = XLSX_SKILLS / "recalc.py" + logger.info(f"[formatter] running recalc.py on {path}") + try: + proc = await asyncio.create_subprocess_exec( + "python3", str(recalc_script), str(path), "30", + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode == 0: + result = json.loads(stdout.decode()) + logger.info(f"[formatter] recalc result: {result.get('status')}") + else: + logger.warning(f"[formatter] recalc.py failed: {stderr.decode()[:200]}") + except Exception as e: + logger.warning(f"[formatter] recalc.py error: {e}") + + # ----------------------------------------------------------------------- + # PDF — reportlab with CJK support + fpdf2 fallback + # ----------------------------------------------------------------------- + + async def _render_pdf( + self, draft: dict, data_assets: dict, output_dir: 
Path, title: str + ) -> Path: + try: + return await self._render_pdf_reportlab(draft, data_assets, output_dir, title) + except Exception as e: + logger.warning(f"[formatter] reportlab failed ({e}), falling back to fpdf2") + return await self._render_pdf_fpdf(draft, data_assets, output_dir, title) + + async def _render_pdf_reportlab( + self, draft: dict, data_assets: dict, output_dir: Path, title: str + ) -> Path: + """Generate PDF with reportlab — better CJK support and table rendering.""" + from reportlab.lib.pagesizes import A4 + from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle + from reportlab.lib.units import mm + from reportlab.lib import colors + from reportlab.platypus import ( + SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak, + ) + from reportlab.pdfbase import pdfmetrics + from reportlab.pdfbase.ttfonts import TTFont + + # Try to register a CJK font + cjk_font = "Helvetica" + for font_path in [ + "/System/Library/Fonts/STHeiti Medium.ttc", + "/System/Library/Fonts/PingFang.ttc", + "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", + ]: + if Path(font_path).exists(): + try: + pdfmetrics.registerFont(TTFont("CJK", font_path, subfontIndex=0)) + cjk_font = "CJK" + break + except Exception: + continue + + path = output_dir / f"{title}.pdf" + doc = SimpleDocTemplate(str(path), pagesize=A4, + topMargin=25*mm, bottomMargin=25*mm) + + styles = getSampleStyleSheet() + styles.add(ParagraphStyle( + name="CJKTitle", fontName=cjk_font, fontSize=22, + spaceAfter=12, alignment=1, + )) + styles.add(ParagraphStyle( + name="CJKHeading", fontName=cjk_font, fontSize=16, + spaceAfter=8, spaceBefore=16, textColor=colors.HexColor("#1a1a2e"), + )) + styles.add(ParagraphStyle( + name="CJKBody", fontName=cjk_font, fontSize=11, + spaceAfter=6, leading=16, + )) + + elements = [] + + # Title + elements.append(Paragraph(title, styles["CJKTitle"])) + elements.append(Spacer(1, 12)) + + # Executive summary + if summary := 
draft.get("executive_summary"): + elements.append(Paragraph("执行摘要", styles["CJKHeading"])) + elements.append(Paragraph(summary, styles["CJKBody"])) + elements.append(Spacer(1, 12)) + + # Chapters + for chapter in draft.get("chapters", []): + elements.append(PageBreak()) + elements.append(Paragraph(chapter["title"], styles["CJKHeading"])) + content = chapter.get("content", "") + for para in content.split("\n\n"): + para = para.strip() + if para: + elements.append(Paragraph(para, styles["CJKBody"])) + + # Tables + for table_spec in data_assets.get("tables", []): + elements.append(Spacer(1, 12)) + elements.append(Paragraph(table_spec.get("title", ""), styles["CJKHeading"])) + headers = table_spec.get("headers", []) + rows = table_spec.get("rows", []) + if headers: + table_data = [headers] + rows + t = Table(table_data) + t.setStyle(TableStyle([ + ("BACKGROUND", (0, 0), (-1, 0), colors.HexColor("#1a1a2e")), + ("TEXTCOLOR", (0, 0), (-1, 0), colors.white), + ("FONTNAME", (0, 0), (-1, -1), cjk_font), + ("FONTSIZE", (0, 0), (-1, 0), 10), + ("FONTSIZE", (0, 1), (-1, -1), 9), + ("GRID", (0, 0), (-1, -1), 0.5, colors.grey), + ("ALIGN", (0, 0), (-1, -1), "CENTER"), + ("ROWBACKGROUNDS", (0, 1), (-1, -1), [colors.white, colors.HexColor("#f5f5f5")]), + ])) + elements.append(t) + + doc.build(elements) + return path + + async def _render_pdf_fpdf( + self, draft: dict, data_assets: dict, output_dir: Path, title: str + ) -> Path: + """Fallback PDF generation with fpdf2.""" + from fpdf import FPDF + + pdf = FPDF() + pdf.set_auto_page_break(auto=True, margin=15) + + # Try CJK font + for font_path in [ + "/System/Library/Fonts/STHeiti Medium.ttc", + "/System/Library/Fonts/PingFang.ttc", + ]: + if Path(font_path).exists(): + try: + pdf.add_font("CJK", "", font_path, uni=True) + pdf.set_font("CJK", "", 11) + break + except Exception: + pdf.set_font("Helvetica", "", 11) + else: + pdf.set_font("Helvetica", "", 11) + + pdf.add_page() + pdf.set_font_size(24) + pdf.cell(0, 20, title, 
new_x="LMARGIN", new_y="NEXT", align="C") + + pdf.set_font_size(11) + if summary := draft.get("executive_summary"): + pdf.set_font_size(16) + pdf.cell(0, 12, "执行摘要", new_x="LMARGIN", new_y="NEXT") + pdf.set_font_size(11) + pdf.multi_cell(0, 6, summary) + + for chapter in draft.get("chapters", []): + pdf.add_page() + pdf.set_font_size(16) + pdf.cell(0, 12, chapter["title"], new_x="LMARGIN", new_y="NEXT") + pdf.set_font_size(11) + pdf.multi_cell(0, 6, chapter.get("content", "")) + + path = output_dir / f"{title}.pdf" + pdf.output(str(path)) + return path diff --git a/app/agents/researcher.py b/app/agents/researcher.py new file mode 100644 index 0000000..4d3aef3 --- /dev/null +++ b/app/agents/researcher.py @@ -0,0 +1,103 @@ +"""Researcher Agent — domain-aware, bilingual research.""" + +from __future__ import annotations + +from typing import Any + +from .base import BaseAgent +from app.config import settings + +SYSTEM_EN = """\ +You are a senior industry analyst at a top-tier consulting firm. +Your task is to produce a thorough research brief based on the given instructions. + +Requirements: +1. Be specific — cite concrete data points, market sizes, growth rates, company names +2. Be structured — organize findings with clear headings and logical flow +3. Be analytical — don't just list facts, provide insights and implications +4. Flag data gaps — explicitly note where data is uncertain or unavailable + +Output (JSON): +{ + "title": "Research brief title", + "executive_summary": "2-3 sentence summary of key findings", + "sections": [ + { + "heading": "Section heading", + "content": "Detailed findings (Markdown)", + "data_points": ["key data points extracted"], + "sources_quality": "high|medium|low — how confident are you in the data" + } + ], + "data_gaps": ["areas where data is insufficient or uncertain"], + "key_insights": ["top 3-5 non-obvious insights"] +}""" + +SYSTEM_ZH = """\ +你是一位顶级咨询公司的资深行业分析师。 +你的任务是根据给定的指令,输出一份深度研究简报。 + +要求: +1. 
具体——引用具体的数据点、市场规模、增长率、企业名称 +2. 结构化——用清晰的标题和逻辑流组织发现 +3. 有分析深度——不要只罗列事实,要提供洞察和含义 +4. 标注数据缺口——明确指出数据不确定或不可获取的地方 + +输出(JSON): +{ + "title": "研究简报标题", + "executive_summary": "核心发现的2-3句总结", + "sections": [ + { + "heading": "章节标题", + "content": "详细发现(Markdown格式)", + "data_points": ["提取的关键数据点"], + "sources_quality": "high|medium|low — 对数据的置信度" + } + ], + "data_gaps": ["数据不充分或不确定的领域"], + "key_insights": ["3-5条非显而易见的洞察"] +}""" + + +class ResearcherAgent(BaseAgent): + name = "researcher" + description = "域感知研究 — 根据领域选择最优模型和语言" + + def __init__(self, model: str | None = None, language: str = "en"): + super().__init__(model=model) + self.language = language + self.system_prompt = SYSTEM_ZH if language == "zh" else SYSTEM_EN + + async def run(self, context: dict[str, Any]) -> dict[str, Any]: + requirement = context["requirement"] + report_type = context.get("report_type", "") + extra_data = context.get("extra_data", "") + + if self.language == "zh": + prompt = f"""\ +## 研究指令 +{requirement} + +## 研究方向 +{report_type} + +## 补充数据 +{extra_data if extra_data else "(无)"} + +请输出研究简报 JSON。""" + else: + prompt = f"""\ +## Research instructions +{requirement} + +## Research focus +{report_type} + +## Additional data +{extra_data if extra_data else "(none)"} + +Output the research brief as JSON.""" + + result = await self.call_llm_json(prompt, max_tokens=6144) + return {"research": result} diff --git a/app/agents/reviewer.py b/app/agents/reviewer.py new file mode 100644 index 0000000..b646bb3 --- /dev/null +++ b/app/agents/reviewer.py @@ -0,0 +1,79 @@ +"""Reviewer Agent — bilingual quality check with strongest reasoning model.""" + +from __future__ import annotations + +import json +from typing import Any + +from .base import BaseAgent +from app.config import settings + + +class ReviewerAgent(BaseAgent): + name = "reviewer" + description = "双语报告质量审查 — 使用最强推理模型" + + system_prompt = """\ +You are a senior consulting partner reviewing a report before client delivery. 
+The report has both Chinese and English versions (or will be translated). + +Review dimensions: +1. **Accuracy** — Are data points, percentages, and claims supported by the research? + Cross-check global claims against English research, Chinese claims against Chinese research. +2. **Logical consistency** — Does the narrative flow? Are there contradictions between chapters? +3. **Depth of analysis** — Is it consultancy-grade or just surface-level? Would a C-suite exec find it valuable? +4. **Bilingual quality** — If translated version exists, check for translation artifacts, + mistranslated terminology, or cultural mismatches. +5. **Data gaps honesty** — Are uncertainties acknowledged or are claims fabricated? +6. **Completeness** — Are any critical aspects of the requirement left unaddressed? + +Scoring guide: +- 90+: Publication-ready +- 80-89: Minor issues, can pass with notes +- 70-79: Needs revision (verdict: revise) +- <70: Significant problems (verdict: reject) + +Output (JSON): +{ + "overall_score": 85, + "verdict": "pass|revise|reject", + "issues": [ + { + "severity": "high|medium|low", + "chapter": "affected chapter", + "dimension": "accuracy|consistency|depth|bilingual|gaps|completeness", + "description": "issue description", + "suggestion": "specific fix suggestion" + } + ], + "strengths": ["what the report does well"], + "summary": "Overall assessment (2-3 sentences)" +}""" + + def __init__(self): + super().__init__(model=settings.model_for_domain("reasoning")) + + async def run(self, context: dict[str, Any]) -> dict[str, Any]: + draft = context["draft"] + draft_translated = context.get("draft_translated", {}) + research = context["research"] + + sections = [ + "## Research Plan (what was asked)", + json.dumps(research, ensure_ascii=False, indent=2), + "", + "## Primary Draft", + json.dumps(draft, ensure_ascii=False, indent=2), + ] + + if draft_translated: + sections.extend([ + "", + "## Translated Version", + json.dumps(draft_translated, 
ensure_ascii=False, indent=2), + ]) + + prompt = "\n".join(sections) + "\n\nReview the report. Output JSON." + + result = await self.call_llm_json(prompt, max_tokens=4096) + return {"review": result} diff --git a/app/agents/writer.py b/app/agents/writer.py new file mode 100644 index 0000000..21146da --- /dev/null +++ b/app/agents/writer.py @@ -0,0 +1,86 @@ +"""Writer Agent — synthesizes multilingual research tracks into a cohesive report.""" + +from __future__ import annotations + +import json +from typing import Any + +from .base import BaseAgent +from app.config import settings + + +class WriterAgent(BaseAgent): + name = "writer" + description = "汇聚多语言/多领域研究成果,撰写完整报告" + + system_prompt = """\ +You are an expert consulting report writer. Your task is to synthesize research +findings from MULTIPLE parallel tracks (some in English, some in Chinese) into +ONE cohesive, professional consulting report. + +CRITICAL RULES: +1. The PRIMARY output language is Chinese (中文) — this is for Chinese clients +2. For global/international sections, the analysis depth must reflect the English research +3. For China-specific sections, preserve the precision of Chinese-native research +4. Maintain professional consulting tone throughout +5. Every claim should trace back to a research track's findings +6. Mark chart/table needs: {{CHART:描述}} and {{TABLE:描述}} +7. 
If a research track flags "data_gaps", acknowledge uncertainty rather than fabricating + +Output (JSON): +{ + "title": "报告标题(中文)", + "title_en": "Report Title (English)", + "chapters": [ + { + "title": "章节标题", + "content": "章节正文(Markdown 格式,中文)", + "source_tracks": ["引用的研究轨道名称"], + "charts": ["图表需求"], + "tables": ["表格需求"] + } + ], + "executive_summary": "执行摘要(中文,300-500字)", + "executive_summary_en": "Executive Summary (English, 200-400 words)" +}""" + + def __init__(self): + super().__init__(model=settings.model_for_domain("reasoning")) + + async def run(self, context: dict[str, Any]) -> dict[str, Any]: + research = context["research"] + requirement = context["requirement"] + revision_feedback = context.get("revision_feedback", "") + + # Format multi-track, multilingual research + tracks_text = "" + for track in research.get("tracks", []): + lang_tag = f"[{track.get('native_language', '?').upper()}]" + domain_tag = f"[{track.get('domain', '?')}]" + tracks_text += f"\n### {domain_tag} {lang_tag} {track.get('track', '')}\n" + findings = track.get("findings", {}) + tracks_text += json.dumps(findings, ensure_ascii=False, indent=2) + + synthesis_guide = research.get("synthesis_guide", "") + + prompt = f"""\ +## 原始需求 / Original Requirement +{requirement} + +## 报告标题 +中文:{research.get("title_zh", "")} +English: {research.get("title_en", "")} + +## 写作指导 / Synthesis Guide +{synthesis_guide} + +## 各研究轨道成果 / Research Track Results +(注意:有些轨道是英文原版 [EN],有些是中文原版 [ZH],请综合使用) +{tracks_text} + +{f"## 审稿反馈 / Review Feedback{revision_feedback}" if revision_feedback else ""} + +请汇聚以上研究成果,撰写完整的中文报告。输出 JSON。""" + + result = await self.call_llm_json(prompt, max_tokens=8192) + return {"draft": result} diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/routes.py b/app/api/routes.py new file mode 100644 index 0000000..bcf6be8 --- /dev/null +++ b/app/api/routes.py @@ -0,0 +1,110 @@ +"""API routes for report generation.""" + 
+from __future__ import annotations + +import logging + +from fastapi import APIRouter, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from app.graph.state import ReportState +from app.pipeline.orchestrator import PipelineOrchestrator +from app.pipeline.task import create_report_state + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api") + +# In-memory store (swap for DB later) +reports: dict[str, ReportState] = {} +orchestrator = PipelineOrchestrator() + + +class CreateReportRequest(BaseModel): + requirement: str + report_type: str = "行业分析报告" + extra_data: str = "" + output_formats: list[str] = ["docx"] + client_id: str | None = None + + +class ReportResponse(BaseModel): + id: str + current_node: str + error: str | None = None + generated_files: list[str] = [] + node_history: list[dict] = [] + revision_count: int = 0 + + +def _to_response(state: ReportState) -> ReportResponse: + return ReportResponse( + id=state.id, + current_node=state.current_node, + error=state.error, + generated_files=state.generated_files, + node_history=state.node_history, + revision_count=state.revision_count, + ) + + +@router.post("/reports", response_model=ReportResponse) +async def create_report(req: CreateReportRequest): + """Create and execute a report generation pipeline.""" + state = create_report_state( + requirement=req.requirement, + report_type=req.report_type, + extra_data=req.extra_data, + output_formats=req.output_formats, + client_id=req.client_id, + ) + reports[state.id] = state + + # Run the full graph (blocking for now, add task queue later) + state = await orchestrator.run(state) + reports[state.id] = state + + if state.error: + raise HTTPException(status_code=500, detail=state.error) + + return _to_response(state) + + +@router.get("/reports/{report_id}", response_model=ReportResponse) +async def get_report(report_id: str): + """Get report status and results.""" + state = reports.get(report_id) + if not 
state: + raise HTTPException(status_code=404, detail="Report not found") + return _to_response(state) + + +@router.get("/reports") +async def list_reports(): + """List all reports.""" + return [_to_response(s) for s in reports.values()] + + +@router.get("/reports/{report_id}/detail") +async def get_report_detail(report_id: str): + """Get full report detail including draft and research.""" + state = reports.get(report_id) + if not state: + raise HTTPException(status_code=404, detail="Report not found") + return { + "id": state.id, + "requirement": state.requirement, + "decomposition": state.decomposition, + "research_results": [ + { + "description": r.description, + "status": r.status.value, + "duration_ms": r.duration_ms, + } + for r in state.research_results + ], + "draft": state.draft, + "review": state.review, + "generated_files": state.generated_files, + "node_history": state.node_history, + } diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..03040c5 --- /dev/null +++ b/app/config.py @@ -0,0 +1,57 @@ +"""Application configuration — multi-model pool by domain.""" + +from pathlib import Path +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + # --- Model Pool --- + # Global analysis (English-native): strongest global reasoning + llm_global: str = "openai/Claude-3.5-Sonnet" + # China domestic (Chinese-native): deepest Chinese market knowledge + llm_china: str = "openai/DeepSeek-R1" + # Synthesis & review: strongest reasoning for cross-domain work + llm_reasoning: str = "openai/Claude-3.5-Sonnet" + # Data & fast tasks: cost-effective structured output + llm_fast: str = "openai/Gemini-2.0-Flash" + # Translation: high-quality bidirectional EN↔ZH + llm_translation: str = "openai/Claude-3.5-Sonnet" + + # Fallback default (if domain not specified) + llm_model: str = "openai/Claude-3.5-Sonnet" + + # API config (Poe as unified gateway) + llm_api_key: str = "" + llm_api_base: str = "https://api.poe.com/bot/" + + # 
Server + host: str = "0.0.0.0" + port: int = 4200 + + # Paths + base_dir: Path = Path(__file__).resolve().parent.parent + templates_dir: Path = base_dir / "templates" + output_dir: Path = base_dir / "output" + + model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} + + def model_for_domain(self, domain: str) -> str: + """Get the best model for a content domain. + + Domains: + global — international markets, global competition, tech trends + china — Chinese market, domestic policy, local competition + reasoning — synthesis, review, strategic recommendations + fast — data processing, chart generation, structured output + translation — high-quality EN↔ZH translation + """ + return { + "global": self.llm_global, + "china": self.llm_china, + "reasoning": self.llm_reasoning, + "fast": self.llm_fast, + "translation": self.llm_translation, + }.get(domain, self.llm_model) + + +settings = Settings() diff --git a/app/data/__init__.py b/app/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data/factory.py b/app/data/factory.py new file mode 100644 index 0000000..032bbef --- /dev/null +++ b/app/data/factory.py @@ -0,0 +1,52 @@ +"""Data router factory — assembles the router with all available sources.""" + +from __future__ import annotations + +import logging + +from .router import DataRouter +from .sources.akshare_source import AKShareSource +from .sources.worldbank_source import WorldBankSource +from .sources.gpt_researcher_source import GPTResearcherSource + +logger = logging.getLogger(__name__) + + +def create_data_router() -> DataRouter: + """Create a DataRouter with all available data sources. 
+ + Priority order (lower = tried first): + 10 AKShare — free, fast, covers Chinese macro/industry + 20 World Bank — free, global macro + 90 GPT Researcher — universal fallback (web research) + + Future additions: + 30 巨潮资讯 — free tier, Chinese public companies + 40 天眼查 API — paid, company data + 50 Choice API — paid, financial data + 60 Statista API — paid, global industry stats + 70 FRED API — free, US macro + 80 UN Comtrade — free, global trade + """ + router = DataRouter() + + # --- Free tier (always available) --- + router.register(AKShareSource(), priority=10) + router.register(WorldBankSource(), priority=20) + + # --- Universal fallback --- + router.register(GPTResearcherSource(), priority=90) + + logger.info(f"[data_factory] created router with {len(router.sources)} sources") + return router + + +# Singleton +_router: DataRouter | None = None + + +def get_data_router() -> DataRouter: + global _router + if _router is None: + _router = create_data_router() + return _router diff --git a/app/data/router.py b/app/data/router.py new file mode 100644 index 0000000..a54cb1e --- /dev/null +++ b/app/data/router.py @@ -0,0 +1,82 @@ +"""DataRouter — routes data requests to the optimal source. + +Architecture: + Agent needs data → DataRouter → best available source → standardized result + +Sources are tried in priority order. If a source fails or is unavailable, +the router falls back to the next source. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from .sources.base import DataSource, DataResult + +logger = logging.getLogger(__name__) + + +class DataRouter: + """Routes data queries to the best available source. + + Sources are registered with a priority (lower = higher priority). + For each query, the router tries sources in priority order until one succeeds. 
+ """ + + def __init__(self): + self.sources: list[tuple[int, DataSource]] = [] + + def register(self, source: DataSource, priority: int = 100) -> "DataRouter": + self.sources.append((priority, source)) + self.sources.sort(key=lambda x: x[0]) + logger.info(f"[data_router] registered {source.name} (priority={priority})") + return self + + async def query( + self, + query: str, + data_type: str = "general", + country: str | None = None, + **kwargs, + ) -> DataResult: + """Query data sources in priority order. + + Args: + query: Natural language or structured data query + data_type: One of: macro, industry, company, trade, patent, general + country: ISO country code (CN, US, etc.) or None for global + **kwargs: Source-specific parameters + """ + errors = [] + + for priority, source in self.sources: + if not source.supports(data_type, country): + continue + + try: + logger.info(f"[data_router] trying {source.name} for '{query[:50]}...'") + result = await source.fetch(query, data_type=data_type, country=country, **kwargs) + if result.data: + logger.info(f"[data_router] {source.name} returned {len(str(result.data))} chars") + return result + except Exception as e: + errors.append(f"{source.name}: {e}") + logger.warning(f"[data_router] {source.name} failed: {e}") + + # All sources failed + return DataResult( + source="none", + data=None, + error=f"All sources failed: {'; '.join(errors)}", + ) + + async def query_multiple( + self, + queries: list[dict[str, Any]], + ) -> list[DataResult]: + """Run multiple queries (can be parallelized later).""" + import asyncio + return await asyncio.gather(*[ + self.query(**q) for q in queries + ]) diff --git a/app/data/sources/__init__.py b/app/data/sources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data/sources/akshare_source.py b/app/data/sources/akshare_source.py new file mode 100644 index 0000000..1f5f392 --- /dev/null +++ b/app/data/sources/akshare_source.py @@ -0,0 +1,94 @@ +"""AKShare data source — 
Chinese macro/industry data via open-source Python library. + +Covers: GDP, CPI, PMI, industrial profit, trade balance, and 30+ data categories. +All data returned as Pandas DataFrames, converted to dicts for standardization. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from .base import DataSource, DataResult + +logger = logging.getLogger(__name__) + +# Map common data requests to AKShare function names +AKSHARE_ENDPOINTS = { + "gdp": "macro_china_gdp", + "cpi": "macro_china_cpi_monthly", + "ppi": "macro_china_ppi", + "pmi": "macro_china_pmi", + "industrial_profit": "macro_china_industrial_profit", + "trade_balance": "macro_china_trade_balance", + "money_supply": "macro_china_money_supply", + "fdi": "macro_china_fdi", + "real_estate": "macro_china_real_estate", + "retail_sales": "macro_china_consumer_goods_retail", + "fixed_asset": "macro_china_fai", + "unemployment": "macro_china_urban_unemployment", + # US macro + "us_gdp": "macro_usa_gdp_monthly", + "us_cpi": "macro_usa_cpi_monthly", + "us_unemployment": "macro_usa_unemployment_rate", + # Global + "global_gdp": "macro_global_gdp", +} + + +class AKShareSource(DataSource): + name = "akshare" + description = "中国宏观经济/行业数据(免费开源,封装统计局等30+数据源)" + + def supports(self, data_type: str, country: str | None = None) -> bool: + return data_type in ("macro", "industry", "general") + + async def fetch( + self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs, + ) -> DataResult: + try: + import akshare as ak + except ImportError: + return DataResult(source=self.name, error="akshare not installed (pip install akshare)") + + # Try to match query to a known endpoint + endpoint_name = kwargs.get("endpoint") + if not endpoint_name: + query_lower = query.lower() + for key, func_name in AKSHARE_ENDPOINTS.items(): + if key in query_lower: + endpoint_name = func_name + break + + if not endpoint_name: + return DataResult(source=self.name, data=None, error=f"No 
matching AKShare endpoint for: {query}") + + try: + func = getattr(ak, endpoint_name, None) + if not func: + return DataResult(source=self.name, error=f"AKShare function not found: {endpoint_name}") + + logger.info(f"[akshare] calling ak.{endpoint_name}()") + df = func() + + # Convert to dict for serialization + # Take last N rows for recent data + limit = kwargs.get("limit", 20) + recent = df.tail(limit) + + return DataResult( + source=self.name, + data={ + "columns": list(recent.columns), + "records": recent.to_dict(orient="records"), + "total_rows": len(df), + "returned_rows": len(recent), + }, + metadata={ + "endpoint": endpoint_name, + "description": f"AKShare {endpoint_name}", + "format": "tabular", + }, + ) + except Exception as e: + return DataResult(source=self.name, error=f"AKShare call failed: {e}") diff --git a/app/data/sources/base.py b/app/data/sources/base.py new file mode 100644 index 0000000..e6b26bb --- /dev/null +++ b/app/data/sources/base.py @@ -0,0 +1,34 @@ +"""Base class for data sources.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, Field + + +class DataResult(BaseModel): + """Standardized result from any data source.""" + source: str = "" + data: Any = None + metadata: dict[str, Any] = Field(default_factory=dict) + # metadata includes: unit, time_range, update_date, confidence, etc. + error: str | None = None + cached: bool = False + + +class DataSource(ABC): + """Abstract data source.""" + name: str = "base" + description: str = "" + + def supports(self, data_type: str, country: str | None = None) -> bool: + """Return True if this source can handle this data type / country.""" + return True + + @abstractmethod + async def fetch( + self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs, + ) -> DataResult: + ... 
diff --git a/app/data/sources/gpt_researcher_source.py b/app/data/sources/gpt_researcher_source.py new file mode 100644 index 0000000..38a4a9a --- /dev/null +++ b/app/data/sources/gpt_researcher_source.py @@ -0,0 +1,61 @@ +"""GPT Researcher MCP — deep web research as fallback for any industry. + +This is the universal fallback: when structured data sources don't have +data for a niche/cold industry, deep web research fills the gap. + +Requires GPT Researcher MCP server to be running (already configured in ~/.claude.json). +For direct API use, we call the MCP tools via the subprocess approach. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import subprocess +from typing import Any + +from .base import DataSource, DataResult + +logger = logging.getLogger(__name__) + + +class GPTResearcherSource(DataSource): + name = "gpt_researcher" + description = "Deep web research — universal fallback for any industry/topic" + + def supports(self, data_type: str, country: str | None = None) -> bool: + # Supports everything — this is the universal fallback + return True + + async def fetch( + self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs, + ) -> DataResult: + mode = kwargs.get("mode", "quick") # "quick" or "deep" + + # GPT Researcher is available as MCP tools in Claude Code. + # For standalone use, we need to call it via its API. + # The MCP server runs at a local port — check if available. + + # For now, provide a structured placeholder that agents can use + # to request deep research. The actual MCP call happens at the + # agent level when integrated into the pipeline. + return DataResult( + source=self.name, + data={ + "query": query, + "mode": mode, + "status": "ready", + "note": ( + "GPT Researcher MCP is available for deep web research. " + "Call via MCP tools: deep_research() or quick_search(). " + "This source returns research-ready queries for MCP integration." 
class WorldBankSource(DataSource):
    """Data source backed by the World Bank Open Data REST API.

    Fetches one indicator time series for one country (or the world
    aggregate) and returns it as a list of ``{year, value, country,
    indicator}`` records wrapped in a ``DataResult``.
    """

    name = "worldbank"
    description = "World Bank Open Data — 1600+ indicators, 217 economies, free"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Return True for the ``macro`` and ``general`` data types."""
        return data_type in ("macro", "general")

    @staticmethod
    def _resolve_indicator(query: str) -> str:
        """Map a free-text query onto a World Bank indicator code.

        Keys are tried longest-first so that specific names such as
        ``gdp_growth`` win over the shorter ``gdp`` they contain, and
        underscores/hyphens are treated as spaces so that both
        "gdp_growth" and "gdp growth" match. Falls back to GDP.
        """
        normalized = " ".join(
            query.lower().replace("_", " ").replace("-", " ").split()
        )
        for key in sorted(INDICATORS, key=len, reverse=True):
            if key.replace("_", " ") in normalized:
                return INDICATORS[key]
        return INDICATORS["gdp"]

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch an indicator time series.

        Args:
            query: Free text used to pick an indicator when no explicit
                ``indicator`` keyword argument is given.
            data_type: Accepted for interface compatibility; not used here.
            country: Country code for the URL; defaults to ``WLD`` (World).
            **kwargs: ``indicator`` (explicit indicator code) and
                ``per_page`` (page size, default 20).

        Returns:
            A ``DataResult`` with the parsed records, or one carrying an
            ``error`` message when the request or parsing fails.
        """
        indicator_code = kwargs.get("indicator") or self._resolve_indicator(query)

        country_code = country or "WLD"  # WLD = World
        per_page = kwargs.get("per_page", 20)

        url = f"{BASE_URL}/country/{country_code}/indicator/{indicator_code}"
        params = {
            "format": "json",
            "per_page": per_page,
        }

        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(url, params=params)
                resp.raise_for_status()
                data = resp.json()

            if not data or len(data) < 2:
                return DataResult(source=self.name, data=None, error="No data returned")

            metadata_raw = data[0]
            records = data[1]
            # A null second element means "no observations"; report it the
            # same way as a missing payload instead of failing on iteration.
            if records is None:
                return DataResult(source=self.name, data=None, error="No data returned")

            # Keep only observations that actually carry a value.
            clean_records = [
                {
                    "year": r["date"],
                    "value": r["value"],
                    "country": r["country"]["value"],
                    "indicator": r["indicator"]["value"],
                }
                for r in records
                if r.get("value") is not None
            ]

            return DataResult(
                source=self.name,
                data={
                    "indicator": indicator_code,
                    "country": country_code,
                    "records": clean_records,
                },
                metadata={
                    "total": metadata_raw.get("total", 0),
                    "indicator_name": clean_records[0]["indicator"] if clean_records else "",
                    "format": "timeseries",
                },
            )
        except Exception as e:
            # Best-effort boundary: any transport or parsing failure is
            # reported to the caller through the result, never raised.
            return DataResult(source=self.name, error=f"World Bank API failed: {e}")
class ReportGraph:
    """Executable graph for report generation.

    Graph structure (v2):
        decompose → parallel_research → write → translate → data → review →
            ├─ pass   → format → END
            └─ revise → write → translate → data → review → ...

    The graph never raises to its caller: a failing node records its error
    on the state and stops the run, after which the ``after`` middleware
    still executes.
    """

    def __init__(self, middleware: MiddlewareChain | None = None):
        self.middleware = middleware
        self.decompose = DecomposeNode()
        self.parallel_research = ParallelResearchNode()
        self.write = WriteNode()
        self.translate = TranslateNode()
        self.data = DataNode()
        self.review = ReviewNode()
        self.format = FormatNode()

    async def _run_node(self, name: str, node: NodeFn, state: ReportState) -> ReportState:
        """Run a single node, recording any exception on the state.

        Re-raises after logging so that ``run`` can abort the remaining
        nodes.
        """
        try:
            logger.info(f"[graph] entering node: {name}")
            state = await node(state)
            logger.info(f"[graph] completed node: {name}")
        except Exception as e:
            state.error = f"Node '{name}' failed: {e}"
            state.log_node(name, NodeStatus.FAILED, str(e))
            logger.exception(f"[graph] node '{name}' failed")
            raise
        return state

    async def run(self, state: ReportState) -> ReportState:
        """Execute the full graph and return the final state."""

        # --- Middleware: before ---
        if self.middleware:
            state = await self.middleware.before(state)

        try:
            # 1. Decompose requirement into domain-tagged parallel tracks
            state = await self._run_node("decompose", self.decompose, state)

            # 2. Run parallel research (each track uses domain-optimal model)
            state = await self._run_node("parallel_research", self.parallel_research, state)

            if state.error:
                # The research node reports "no tracks" by setting
                # state.error and returning normally. Writing, reviewing and
                # formatting a report with no research behind it is
                # pointless, so stop here.
                logger.error(f"[graph] halting after parallel_research: {state.error}")
            else:
                # 3-6. Write → Translate → Data → Review (with revision loop)
                while True:
                    state = await self._run_node("write", self.write, state)
                    state = await self._run_node("translate", self.translate, state)
                    state = await self._run_node("data", self.data, state)
                    state = await self._run_node("review", self.review, state)

                    verdict = state.review.get("verdict", "pass")
                    if verdict == "pass" or state.revision_count >= state.max_revisions:
                        if verdict != "pass":
                            logger.warning(
                                f"[graph] forcing pass after {state.revision_count} revisions"
                            )
                        break

                    # Revise — loop back to write
                    state.revision_count += 1
                    logger.info(
                        f"[graph] revision {state.revision_count}/{state.max_revisions}"
                    )

                # 7. Format bilingual output files
                state = await self._run_node("format", self.format, state)

        except Exception:
            # Deliberate: _run_node already logged the failure and stored it
            # in state.error; the after-middleware must still run.
            pass

        # --- Middleware: after ---
        if self.middleware:
            state = await self.middleware.after(state)

        return state
+ +Node layout (v2 — domain-aware, bilingual): + + START + │ + ▼ + [decompose] — Lead Agent 分解为并行研究轨道,每轨标注 domain + language + │ + ▼ + [parallel_research] — N 个子 Agent 并行,每个用最适合该领域的模型 + │ global tracks → Claude/GPT (English) + │ china tracks → DeepSeek/Qwen (Chinese) + ▼ + [write] — Writer 汇聚 → 生成主语言版本 + │ + ▼ + [translate] — 高质量翻译 → 生成另一语言版本 + │ + ▼ + [data] — Data Agent 生成图表/表格 + │ + ▼ + [review] — Reviewer 审查(双语) + │ ├─ pass → [format] + │ └─ revise → [write] + ▼ + [format] — 输出双语版本文件 + │ + ▼ + END +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime +from typing import Any + +from app.agents.base import BaseAgent +from app.agents.researcher import ResearcherAgent +from app.agents.writer import WriterAgent +from app.agents.data_agent import DataAgent +from app.agents.reviewer import ReviewerAgent +from app.agents.formatter import FormatterAgent +from app.config import settings + +from .state import ReportState, SubtaskResult, NodeStatus, ContentDomain + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Node: decompose — Lead Agent decomposes into domain-tagged parallel tracks +# --------------------------------------------------------------------------- + +class DecomposeNode: + """Analyzes requirement and decomposes into domain-aware research tracks.""" + + def __init__(self): + self.agent = BaseAgent() + self.agent.name = "lead" + self.agent.model = settings.model_for_domain("reasoning") + + async def __call__(self, state: ReportState) -> ReportState: + state.current_node = "decompose" + state.log_node("decompose", NodeStatus.RUNNING) + + system = """\ +You are a senior consulting partner planning a global industry report. + +Your job is to decompose the client's requirement into 2-6 parallel research tracks. 
+ +CRITICAL: Each track must be tagged with a content domain and native language: + +- domain: "global" → international markets, global competition, technology trends, overseas benchmarks + → native_language: "en" (English sources are 10-100x richer for global analysis) + +- domain: "china" → Chinese domestic market, government policy, local competitors, China-specific data + → native_language: "zh" (Chinese sources are authoritative for domestic analysis) + +The PRINCIPLE: whichever language has the richest professional literature for that topic +should be the native language. The other language version will be translated later. + +Output (JSON): +{ + "title_en": "English report title", + "title_zh": "中文报告标题", + "report_type": "report type", + "tracks": [ + { + "title": "track title (in native language)", + "domain": "global|china", + "native_language": "en|zh", + "focus": "research focus description", + "prompt": "detailed research instructions (MUST be in the native_language)", + "data_needs": ["required data/charts"] + } + ], + "synthesis_guide": "How to merge all tracks into a coherent report (bilingual structure notes)", + "methodology": "Analysis methodology" +}""" + + prompt = f"""\ +## Client requirement +{state.requirement} + +## Report type +{state.report_type} + +## Additional data +{state.extra_data or "(none)"} + +## Client context +{state.client_context or "(none)"} + +Decompose into parallel research tracks with domain and language tags. 
class ParallelResearchNode:
    """Runs research subtasks in parallel, each using the optimal model for its domain."""

    # Upper bound on research tracks executing at the same time.
    MAX_CONCURRENT = 5

    @staticmethod
    def _parse_domain(raw: Any) -> ContentDomain:
        """Convert a track's ``domain`` value into a ``ContentDomain``.

        Matching ignores case and surrounding whitespace; anything that is
        not a known domain value falls back to ``ContentDomain.GLOBAL``.
        """
        try:
            return ContentDomain(str(raw).strip().lower())
        except ValueError:
            return ContentDomain.GLOBAL

    async def _run_one(self, track: dict[str, Any]) -> SubtaskResult:
        """Research a single track and return its result.

        Never raises: every failure is captured on the returned
        ``SubtaskResult`` so that one bad track cannot abort its siblings.
        """
        if not isinstance(track, dict):
            # The decomposition comes from an LLM; a non-object entry must
            # not take down the whole gather() in __call__.
            malformed = SubtaskResult(description=str(track)[:80])
            malformed.status = NodeStatus.FAILED
            malformed.error = "Malformed track: expected a JSON object"
            malformed.started_at = malformed.completed_at = datetime.now()
            return malformed

        domain = self._parse_domain(track.get("domain", "global"))
        native_lang = track.get("native_language", "en")

        result = SubtaskResult(
            description=track.get("title", ""),
            domain=domain,
            native_language=native_lang,
        )
        result.status = NodeStatus.RUNNING
        result.started_at = datetime.now()

        try:
            # Select model based on domain
            model = settings.model_for_domain(domain.value)
            agent = ResearcherAgent(model=model, language=native_lang)

            logger.info(
                f"[parallel_research] track '{track.get('title')}' "
                f"→ domain={domain.value}, lang={native_lang}, model={model}"
            )

            research = await agent.run({
                "requirement": track["prompt"],
                "report_type": track.get("focus", ""),
                "extra_data": "",
            })
            result.content = research.get("research", {})
            result.status = NodeStatus.COMPLETED
        except Exception as e:
            result.error = str(e)
            result.status = NodeStatus.FAILED
            logger.exception(f"Research track '{track.get('title')}' failed")
        finally:
            result.completed_at = datetime.now()

        return result

    async def __call__(self, state: ReportState) -> ReportState:
        """Run every decomposed track and store the results on the state.

        When the decomposition holds no tracks, the node logs FAILED, sets
        ``state.error`` and returns without raising.
        """
        state.current_node = "parallel_research"
        state.log_node("parallel_research", NodeStatus.RUNNING)

        tracks = state.decomposition.get("tracks", [])
        if not tracks:
            state.log_node("parallel_research", NodeStatus.FAILED, "no tracks")
            state.error = "Decomposition produced no research tracks"
            return state

        semaphore = asyncio.Semaphore(self.MAX_CONCURRENT)

        async def bounded(track):
            async with semaphore:
                return await self._run_one(track)

        logger.info(f"[parallel_research] launching {len(tracks)} tracks concurrently")
        results = await asyncio.gather(*[bounded(t) for t in tracks])
        state.research_results = list(results)

        succeeded = sum(1 for r in results if r.status == NodeStatus.COMPLETED)
        # Group native languages by domain for the history entry.
        domains = {}
        for r in results:
            domains.setdefault(r.domain.value, []).append(r.native_language)
        state.log_node("parallel_research", NodeStatus.COMPLETED,
                       f"{succeeded}/{len(tracks)} ok, domains={domains}")
        return state
class TranslateNode:
    """Translates the draft into the other language version."""

    def __init__(self):
        self.agent = BaseAgent()
        self.agent.name = "translator"
        self.agent.model = settings.model_for_domain("translation")

    @staticmethod
    def _language_pair(title: Any) -> tuple[str, str, str]:
        """Infer the translation direction from the draft title.

        A title containing at least one CJK unified ideograph marks the
        draft as Chinese-primary; anything else (including an empty or
        non-string title) is treated as English-primary.

        Returns:
            ``(source_language_name, target_language_name, target_code)``
            where ``target_code`` is ``"en"`` or ``"zh"``.
        """
        text = title if isinstance(title, str) else ""
        if any('\u4e00' <= c <= '\u9fff' for c in text):
            return "Chinese", "English", "en"
        return "English", "Chinese (Simplified)", "zh"

    async def __call__(self, state: ReportState) -> ReportState:
        """Produce ``state.draft_translated`` from ``state.draft``.

        The node is skipped when there is no draft or when the language the
        draft would be translated into is not among
        ``state.output_languages``.
        """
        state.current_node = "translate"
        state.log_node("translate", NodeStatus.RUNNING)

        if not state.draft:
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        # Detect primary language of draft
        source_lang, target_lang, target_code = self._language_pair(
            state.draft.get("title", "")
        )

        # Decide on the real target language rather than assuming English:
        # an English-primary draft needs "zh" requested, not "en".
        if target_code not in state.output_languages:
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        draft_json = json.dumps(state.draft, ensure_ascii=False, indent=2)

        system = f"""\
You are a world-class {source_lang} → {target_lang} translator specializing in
consulting and business reports.

Translation principles:
1. ACCURACY over fluency — every data point, percentage, and proper noun must be correct
2. Professional terminology — use standard {target_lang} business/industry terms
3. Preserve structure — keep the exact same JSON structure, only translate text values
4. Cultural adaptation — adjust phrasing for the target audience (not word-for-word)
5. Keep {{{{CHART:...}}}} and {{{{TABLE:...}}}} markers, translate their descriptions

Output the translated JSON with the exact same structure."""

        prompt = f"""\
Translate this consulting report from {source_lang} to {target_lang}.

{draft_json}

Output the translated JSON."""

        translated = await self.agent.call_llm_json(prompt, system=system, max_tokens=8192)
        state.draft_translated = translated
        state.log_node("translate", NodeStatus.COMPLETED,
                       f"{source_lang} → {target_lang}")
        return state
class FormatNode:
    """Renders the primary draft and, when one exists, its translation to files."""

    def __init__(self):
        self.agent = FormatterAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the formatter per language version and collect the file paths.

        Output goes to ``<output_dir>/<state.id>/primary`` and, if a
        translated draft is present, ``<output_dir>/<state.id>/translated``.
        """
        state.current_node = "format"
        state.log_node("format", NodeStatus.RUNNING)

        produced = []

        async def render(subdir, draft):
            # One formatter pass for a single language version.
            outcome = await self.agent.run({
                "draft": draft,
                "data_assets": state.data_assets,
                "output_dir": str(settings.output_dir / state.id / subdir),
                "output_formats": state.output_formats,
            })
            produced.extend(outcome.get("generated_files", []))

        # Primary version is always rendered first.
        await render("primary", state.draft)

        # Translated version only when the translate node produced one.
        if state.draft_translated:
            await render("translated", state.draft_translated)

        state.generated_files = produced
        state.log_node("format", NodeStatus.COMPLETED,
                       f"{len(produced)} files")
        return state
class SubtaskResult(BaseModel):
    """Outcome of one parallel research track."""

    # Short random identifier (first 8 hex chars of a UUID4).
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    description: str = ""
    domain: ContentDomain = ContentDomain.GLOBAL
    # "en" or "zh" — the language the track was originally written in.
    native_language: str = "en"
    status: NodeStatus = NodeStatus.PENDING
    content: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_ms(self) -> int | None:
        """Elapsed time in whole milliseconds, or None until both timestamps are set."""
        if not (self.started_at and self.completed_at):
            return None
        elapsed = self.completed_at - self.started_at
        return int(elapsed.total_seconds() * 1000)
{"tracks": [{"title": "政策环境", "prompt": "研究..."}, ...]} + + # --- Parallel research results --- + research_results: list[SubtaskResult] = Field(default_factory=list) + + # --- Writer output --- + draft: dict[str, Any] = Field(default_factory=dict) # primary language version + draft_translated: dict[str, Any] = Field(default_factory=dict) # translated version + + # --- Data Agent output --- + data_assets: dict[str, Any] = Field(default_factory=dict) + + # --- Reviewer output --- + review: dict[str, Any] = Field(default_factory=dict) + revision_count: int = 0 + max_revisions: int = 2 + + # --- Formatter output --- + generated_files: list[str] = Field(default_factory=list) + + # --- Execution tracking --- + current_node: str = "" + node_history: list[dict[str, Any]] = Field(default_factory=list) + error: str | None = None + + def log_node(self, node_name: str, status: NodeStatus, detail: str = ""): + self.node_history.append({ + "node": node_name, + "status": status.value, + "detail": detail, + "timestamp": datetime.now().isoformat(), + }) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..50e867a --- /dev/null +++ b/app/main.py @@ -0,0 +1,48 @@ +"""FastAPI application entry point.""" + +import logging + +from fastapi import FastAPI +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles + +from app.api.routes import router +from app.config import settings + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", +) + +app = FastAPI( + title="咨询报告 AI 生成系统", + version="0.1.0", + description="Multi-agent pipeline for generating consulting reports", +) + +app.include_router(router) + +# Serve generated files +settings.output_dir.mkdir(parents=True, exist_ok=True) +app.mount("/files", StaticFiles(directory=str(settings.output_dir)), name="files") + + +@app.get("/") +async def root(): + return { + "name": "咨询报告 AI 生成系统", + "version": "0.1.0", + "status": "running", + 
class MemoryStore:
    """Read/write persistent memory as JSON files.

    Storage layout:
        memory/
        ├── global.json          — system-wide facts
        └── <client_id>.json     — per-client facts (id sanitized for the filesystem)
    """

    def __init__(self):
        MEMORY_DIR.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _safe_name(client_id: str) -> str:
        """Turn a client id into a file stem that cannot leave the memory directory.

        Both path separators and ``..`` sequences are replaced with ``_``.
        Ids without those characters map to themselves.
        """
        return (
            client_id.replace("/", "_")
            .replace("\\", "_")  # Windows separator; the original missed this
            .replace("..", "_")
        )

    def _path(self, client_id: str = "global") -> Path:
        """Return the JSON file path backing *client_id*."""
        return MEMORY_DIR / f"{self._safe_name(client_id)}.json"

    def load(self, client_id: str = "global") -> MemoryFile:
        """Load a client's memory, returning an empty one if missing or unreadable."""
        path = self._path(client_id)
        if not path.exists():
            return MemoryFile(client_id=client_id)
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return MemoryFile(**data)
        except Exception as e:
            # Best-effort: memory is an enhancement, never a hard dependency.
            logger.warning(f"[memory] failed to load {path}: {e}")
            return MemoryFile(client_id=client_id)

    def save(self, mem: MemoryFile):
        """Persist *mem* to its JSON file.

        The payload is written to a sibling temporary file and then moved
        over the target. An interrupted write therefore cannot leave a
        truncated file, which ``load`` would treat as empty memory and the
        next save would make permanent.
        """
        path = self._path(mem.client_id)
        payload = json.dumps(mem.model_dump(), ensure_ascii=False, indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)
        logger.info(f"[memory] saved {len(mem.facts)} facts to {path}")

    def add_fact(
        self,
        content: str,
        client_id: str = "global",
        category: str = "context",
        confidence: float = 0.7,
        source: str = "",
    ) -> Fact:
        """Add a fact, or return the existing one with the same content.

        Content is compared after stripping and lower-casing. A duplicate is
        not re-added; its confidence is raised (and saved) when the new
        value is higher.
        """
        mem = self.load(client_id)

        # Deduplicate by content (normalized)
        normalized = content.strip().lower()
        for existing in mem.facts:
            if existing.content.strip().lower() == normalized:
                # Update confidence if higher
                if confidence > existing.confidence:
                    existing.confidence = confidence
                    self.save(mem)
                return existing

        fact = Fact(
            content=content,
            category=category,
            confidence=confidence,
            source=source,
        )
        mem.facts.append(fact)
        self.save(mem)
        return fact

    def get_top_facts(
        self, client_id: str = "global", limit: int = 15
    ) -> list[str]:
        """Get top N facts sorted by confidence, formatted for prompt injection."""
        mem = self.load(client_id)
        sorted_facts = sorted(mem.facts, key=lambda f: f.confidence, reverse=True)
        return [f.content for f in sorted_facts[:limit]]

    def set_preference(self, key: str, value: str, client_id: str = "global"):
        """Store one preference for the client and persist immediately."""
        mem = self.load(client_id)
        mem.preferences[key] = value
        self.save(mem)

    def get_preferences(self, client_id: str = "global") -> dict[str, str]:
        """Return the client's stored preferences."""
        return self.load(client_id).preferences
class ClientContextMiddleware(Middleware):
    """Loads client profile and injects background context into state.

    Client profiles are JSON files in the `clients/` directory:
        clients/<client_id>.json
        {
            "name": "Example Consulting Co.",
            "industry": "Finance",
            "preferences": "Prefers concise, data-driven executive summaries",
            "previous_reports": ["report_001", "report_002"]
        }
    """

    name = "client_context"

    async def before(self, state: ReportState) -> ReportState:
        """Populate ``state.client_context`` from the client's profile file.

        The state is returned unchanged when no client id is set, when the
        id is not a plain file name, when no profile exists, or when the
        profile cannot be read.
        """
        if not state.client_id:
            return state

        # client_id arrives with the request. Accept only a bare file name
        # so that values such as "../../secrets/x" cannot read files outside
        # the clients directory.
        if Path(state.client_id).name != state.client_id:
            logger.warning(
                f"[client_context] rejected unsafe client id '{state.client_id}'"
            )
            return state

        profile_path = CLIENTS_DIR / f"{state.client_id}.json"
        if not profile_path.exists():
            logger.info(f"[client_context] no profile for client '{state.client_id}'")
            return state

        try:
            profile = json.loads(profile_path.read_text(encoding="utf-8"))
            parts = []
            if name := profile.get("name"):
                parts.append(f"客户:{name}")
            if industry := profile.get("industry"):
                parts.append(f"行业:{industry}")
            if prefs := profile.get("preferences"):
                parts.append(f"偏好:{prefs}")
            state.client_context = ";".join(parts)
            logger.info(f"[client_context] loaded profile for '{state.client_id}'")
        except Exception as e:
            # Best-effort: a broken profile must not block report generation.
            logger.warning(f"[client_context] failed to load profile: {e}")

        return state
+ Logs warnings but does not block (can be made blocking later). + """ + + name = "compliance" + + def _scan_text(self, text: str) -> list[dict]: + findings = [] + for pattern, desc in SENSITIVE_PATTERNS: + matches = re.findall(pattern, text) + if matches: + findings.append({ + "type": desc, + "count": len(matches), + "samples": [m[:8] + "..." for m in matches[:3]], + }) + return findings + + async def after(self, state: ReportState) -> ReportState: + # Scan draft content + all_text = json.dumps(state.draft, ensure_ascii=False) + findings = self._scan_text(all_text) + + if findings: + logger.warning( + f"[compliance] found {len(findings)} sensitive data patterns:" + ) + for f in findings: + logger.warning(f" - {f['type']}: {f['count']} occurrences") + # Store in state for API to surface + state.review.setdefault("compliance_warnings", findings) + else: + logger.info("[compliance] no sensitive data patterns detected") + + return state diff --git a/app/middleware/memory.py b/app/middleware/memory.py new file mode 100644 index 0000000..71e345a --- /dev/null +++ b/app/middleware/memory.py @@ -0,0 +1,64 @@ +"""Memory middleware — injects persistent facts into state, saves new facts after.""" + +from __future__ import annotations + +import logging + +from app.graph.state import ReportState +from app.memory.store import MemoryStore +from .base import Middleware + +logger = logging.getLogger(__name__) + + +class MemoryMiddleware(Middleware): + """Injects top-N memory facts into state before execution. + + After execution, extracts key facts from the report for future reference. 
class TokenBudgetMiddleware(Middleware):
    """Track estimated token usage against the pipeline's token budget.

    Before: logs the budget already present on the state (the default is set
    by ReportState; nothing is adjusted here yet).
    After: logs the estimated number of tokens in the generated draft.
    """

    name = "token_budget"

    # Rough char-to-token ratio for Chinese text
    CHARS_PER_TOKEN = 1.5

    def _tokens_for_chars(self, char_count: int) -> int:
        """Convert a character count into an estimated token count."""
        return int(char_count / self.CHARS_PER_TOKEN)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the number of tokens in *text* from its length."""
        return self._tokens_for_chars(len(text))

    async def before(self, state: ReportState) -> ReportState:
        """Log the configured budget and return the state unchanged."""
        # Default budget is already set in ReportState (120k)
        # Could adjust based on model selection
        logger.info(f"[token_budget] budget = {state.token_budget} tokens")
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Log the estimated token count of the draft; state is not modified.

        Counts the characters of every chapter's ``content`` plus the
        ``executive_summary`` and converts that total to tokens.
        """
        # Missing/None pieces count as empty so this log-only hook cannot
        # fail the pipeline.
        draft = state.draft or {}
        total_chars = sum(
            len(chapter.get("content") or "")
            for chapter in draft.get("chapters") or []
        )
        total_chars += len(draft.get("executive_summary") or "")

        # Convert the character *count* itself. The previous code passed
        # str(total_chars) to _estimate_tokens, which measured the number of
        # digits in the count instead of the text length.
        estimated = self._tokens_for_chars(total_chars)
        logger.info(
            f"[token_budget] estimated output tokens: {estimated}, "
            f"budget: {state.token_budget}"
        )
        return state
class PipelineOrchestrator:
    """Top-level entry point: wires the middleware chain into the report
    graph and runs a report end-to-end."""

    def __init__(self):
        chain = build_middleware()
        self.middleware = chain
        self.graph = ReportGraph(middleware=chain)

    async def run(self, state: ReportState) -> ReportState:
        """Execute the graph for *state* and return the resulting state.

        Logs one line when the run starts and one summary line (status and
        number of generated files) when it finishes.
        """
        logger.info(f"[orchestrator] starting report {state.id}")

        result = await self.graph.run(state)

        outcome = "FAILED" if result.error else "OK"
        file_count = len(result.generated_files)
        logger.info(
            f"[orchestrator] finished report {result.id} — "
            f"status={outcome}, "
            f"files={file_count}"
        )
        return result
(data.stats.gov.cn) + +| 项目 | 详情 | +|------|------| +| **费用** | 完全免费 | +| **API 可用性** | 有非官方但稳定的 HTTP API | +| **API 地址** | `https://data.stats.gov.cn/easyquery.htm` | +| **请求方式** | POST / GET | +| **数据格式** | JSON | +| **数据覆盖** | 年度/季度/月度宏观经济数据,覆盖 GDP、CPI、PPI、工业产值、固定资产投资、社会消费品零售、进出口、人口、就业等全品类 | +| **更新频率** | 月度/季度/年度同步发布 | + +**API 关键参数:** +``` +m=getTree / QueryData +dbcode=hgnd(年度) / hgyd(月度) / hgjd(季度) +rowcode=zb (指标) +colcode=sj (时间) / reg (地区) +wds=[] (查询条件 JSON) +dfwds=[{"wdcode":"zb","valuecode":"A0101"}] (数据字段) +``` + +**咨询报告实用度: ★★★★★** +宏观经济数据的首选来源,权威且免费。缺点是没有正式 API 文档,接口可能变动。可通过 AKShare 库间接调用。 + +--- + +### 1.2 商务部公共服务资源平台 + +| 项目 | 详情 | +|------|------| +| **费用** | 免费 | +| **网址** | `http://opendata.mofcom.gov.cn/front/data` | +| **数据覆盖** | 对外贸易、外商投资、国内贸易、商务预报、国别报告 | +| **API 可用性** | 有开放数据接口,注册后可调用 | + +**咨询报告实用度: ★★★★** +外贸和商业分析类报告必备数据源。 + +--- + +### 1.3 中经数据 (国家信息中心) + +| 项目 | 详情 | +|------|------| +| **费用** | 数据流量付费模式(按调用量计费),价格未公开 | +| **网址** | `https://ceidata.cei.cn/` | +| **API 可用性** | 有正式 REST API,支持 Java / .NET / Python | +| **数据覆盖** | 全国 + 31省 + 330+城市 + 2800+县 + 200+国家/地区的宏观经济时序数据 | +| **更新频率** | 年/季/月/周/日多频率 | +| **数据来源** | 国家部委、地方政府、国际组织的权威统计数据 | + +**咨询报告实用度: ★★★★★** +数据粒度和权威性极高,从国家到县级全覆盖。适合需要区域经济分析的咨询报告。价格需联系获取。 + +--- + +### 1.4 地方政府开放数据平台 + +| 平台 | 网址 | 数据规模 | +|------|------|----------| +| **上海市** | `https://data.sh.gov.cn/` | 45个部门、2101个数据集、646个数据接口 | +| **深圳市** | `https://opendata.sz.gov.cn` | 114家单位、11,150个数据集、10,971个接口、59.86亿条数据 | +| **北京市** | `https://data.beijing.gov.cn/` | 多部门、多品类 | +| **全国 50+ 地市** | 各自平台 | 各有 API 接口,注册即可免费调用 | + +**咨询报告实用度: ★★★** +做区域性、城市级咨询报告时有用,但各平台接口标准不统一,集成成本较高。 + +--- + +### 1.5 海关总署统计数据查询平台 + +| 项目 | 详情 | +|------|------| +| **费用** | 免费查询(有限制) | +| **网址** | `http://stats.customs.gov.cn/` | +| **数据覆盖** | 按 HS 编码、进出口收发货人、贸易伙伴、贸易方式等多维组合的进出口统计 | +| **API 可用性** | 官方无正式 API,可通过网页接口解析获取 | + +**第三方商业 API:** 腾道 (tendata.cn) 等提供海关数据 API,支持按产品/HS编码/国家/时间过滤。 + +**咨询报告实用度: ★★★★** 
+进出口贸易分析类报告必用。官方免费版数据维度有限,深度分析需采购第三方。 + +--- + +## 二、国内商业金融/行业数据 API + +### 2.1 Wind 万得 + +| 项目 | 详情 | +|------|------| +| **费用** | **金融终端:** 39,800元/年/席位(单买),批量采购可降至 24,540元/年 | +| | **经济数据库:** 34,600元/年/席位 | +| | **机构数据接口:** 5-20万/年(单用户),批量可降至 2-8万/年 | +| **API 格式** | Client API (C++/C#/Java/Python),需安装 Wind 终端 | +| **数据覆盖** | 全市场金融数据、宏观经济、行业数据、公司财务、ESG、资讯舆情、专题特色数据 | +| **数据权威性** | 中国金融数据的"黄金标准",覆盖最全 | + +**咨询报告实用度: ★★★★★** +如果预算允许,这是最全面的中国金融和行业数据源。但价格昂贵,适合机构级使用。API 必须绑定 Wind 终端,无法纯云端调用。 + +--- + +### 2.2 同花顺 iFinD + +| 项目 | 详情 | +|------|------| +| **费用** | 8,800 - 28,000元/年/席位(远低于 Wind) | +| **API 格式** | Python/Java/C++ 等语言 SDK | +| **数据覆盖** | 股票、债券、外汇、期货、基金、REITs、宏观经济、企业数据库、研究报告 | +| **适用对象** | 机构投资者和专业用户,需企业级账号 | + +**咨询报告实用度: ★★★★** +Wind 的性价比替代品。数据覆盖广泛,价格约为 Wind 的 1/3 到 1/2。 + +--- + +### 2.3 东方财富 Choice + +| 项目 | 详情 | +|------|------| +| **费用** | 官方定价 38,000元/年,**推广价 5,800元/年** | +| **API 格式** | 函数调用方式,支持 Matlab/C++/C#/R/Python (Win/Linux/Mac) | +| **数据覆盖** | 基本面、财务数据、行情数据、宏观经济 | +| **API 文档** | `https://quantapi.eastmoney.com/Manual` | + +**咨询报告实用度: ★★★★** +推广价 5,800元/年是三大金融终端中最便宜的,跨平台支持好。适合中小团队。 + +--- + +### 2.4 天眼查 开放平台 + +| 项目 | 详情 | +|------|------| +| **费用** | 按次调用 + 套餐两种模式,具体价格需登录平台查看(典型范围: 0.1-2元/次,视接口而定) | +| **API 地址** | `https://open.tianyancha.com/` | +| **数据覆盖** | 企业基本信息、股东/股权、财务报表、法律诉讼、知识产权、经营异常、招投标等 | +| **认证方式** | Token + RESTful API | +| **免费额度** | 注册后有少量免费试用额度 | + +**咨询报告实用度: ★★★★★** +企业尽职调查、竞争格局分析类报告的核心数据源。覆盖全国企业工商信息。 + +--- + +### 2.5 企查查 开放平台 + +| 项目 | 详情 | +|------|------| +| **费用** | 按次计费(价格未公开,需联系销售),新用户有 20次免费测试 | +| **API 地址** | `https://openapi.qcc.com/` | +| **数据覆盖** | 企业高级搜索、工商详情、专利查询、商标查询、经营风险 | +| **计费方式** | 固定企业列表+按年计费(每企业每周期只收一次)或按次 | + +**咨询报告实用度: ★★★★** +与天眼查功能类似,二选一即可。企查查在某些企业关联数据上更全。 + +--- + +### 2.6 巨潮资讯 (cninfo.com.cn) + +| 项目 | 详情 | +|------|------| +| **费用** | 注册后 1000次免费调用;深证信平台有更多接口 | +| **API 地址** | `http://webapi.cninfo.com.cn/` | +| **公告查询** | `http://www.cninfo.com.cn/new/hisAnnouncement/query` | +| **数据覆盖** 
| 上市公司公告全文、财务数据、基金、债券 | +| **数据格式** | JSON | + +**咨询报告实用度: ★★★★** +上市公司分析的权威一手来源。公告全文可用于 LLM 提取和分析。1000次免费额度足够初期使用。 + +--- + +## 三、免费开源 Python 数据库 + +### 3.1 AKShare (强烈推荐) + +| 项目 | 详情 | +|------|------| +| **费用** | **完全免费开源** | +| **安装** | `pip install akshare` | +| **GitHub** | `https://github.com/akfamily/akshare` | +| **文档** | `https://akshare.akfamily.xyz/` | +| **数据覆盖** | **30+ 类金融产品**: A股/港股/美股行情、期货、期权、基金、债券、外汇、加密货币、**宏观经济指标**、**行业数据**、新闻舆情 | +| **数据格式** | Pandas DataFrame | +| **更新频率** | 持续更新(当前版本 1.18.47) | + +**核心宏观经济数据接口示例:** +```python +import akshare as ak + +# GDP 数据 +gdp = ak.macro_china_gdp() + +# CPI 数据 +cpi = ak.macro_china_cpi_monthly() + +# PMI 数据 +pmi = ak.macro_china_pmi() + +# 行业利润数据 +profit = ak.macro_china_industrial_profit() + +# 中国海关进出口 +trade = ak.macro_china_trade_balance() +``` + +**咨询报告实用度: ★★★★★** +**咨询报告系统的首选数据获取层。** 免费、覆盖广、接口统一、返回 DataFrame 直接可分析。作为国家统计局等公开数据的统一封装层,可替代大部分付费数据源的基础数据需求。 + +--- + +### 3.2 Tushare Pro + +| 项目 | 详情 | +|------|------| +| **费用** | 基础免费(积分制),高级接口需积分(约 500元一次性购买可获足够积分) | +| **数据覆盖** | A 股行情、财务数据、基金、期货、宏观经济 | +| **注意** | 2025年9月后部分接口调整,积分获取难度增加 | + +**咨询报告实用度: ★★★** +AKShare 的替代方案,但积分制有限制。推荐优先使用 AKShare。 + +--- + +### 3.3 Baostock + +| 项目 | 详情 | +|------|------| +| **费用** | 完全免费 | +| **数据覆盖** | A 股历史行情、财务数据(较基础) | + +**咨询报告实用度: ★★** +数据覆盖面较窄,仅 A 股基础数据。不推荐作为主力数据源。 + +--- + +## 四、国际数据源 API + +### 4.1 World Bank Open Data API + +| 项目 | 详情 | +|------|------| +| **费用** | **完全免费** | +| **API 地址** | `https://api.worldbank.org/v2/` | +| **数据格式** | JSON / XML | +| **数据覆盖** | 1,600+ 指标、217 个经济体、60+ 年历史数据 | +| **中国数据** | GDP、人口、通胀、贸易、教育、卫生、环境等全方位 | +| **更新频率** | 定期更新(最新至 2024年) | + +**已验证可用的示例调用:** +``` +GET https://api.worldbank.org/v2/country/CHN/indicator/NY.GDP.MKTP.CD?format=json&per_page=5 +``` +返回中国 GDP 时序数据(2024年: $18.74 万亿)。 + +**咨询报告实用度: ★★★★★** +国际对比和宏观经济分析的标准数据源。免费、文档完善、数据权威。 + +--- + +### 4.2 IMF Data API + +| 项目 | 详情 | +|------|------| +| **费用** | **完全免费** | +| **API 地址** | 
`https://dataservices.imf.org/REST/SDMX_JSON.svc/` | +| **数据覆盖** | International Financial Statistics (IFS)、Balance of Payments、Government Finance、Direction of Trade | +| **WEO 数据库** | 半年更新(4月/10月),含未来 2年预测 | + +**咨询报告实用度: ★★★★** +宏观经济和国际金融分析。与 World Bank 互补。 + +--- + +### 4.3 Statista API + +| 项目 | 详情 | +|------|------| +| **费用** | 基础免费;**Premium: $199-$1,299/月(年付)**;API 接入需企业级合同(价格需议) | +| **API 地址** | `https://www.statista.com/api/v2/doc/` | +| **数据覆盖** | 170+ 行业、150+ 国家、300万+ 统计数据 | +| **数据特点** | 图表友好型统计数据、行业报告、消费者调查 | + +**咨询报告实用度: ★★★★** +行业分析图表和统计数据的优质来源。API 价格不透明,需要商务洽谈。免费版可获取有限数据。 + +--- + +### 4.4 CB Insights + +| 项目 | 详情 | +|------|------| +| **费用** | **$50,000 - $265,000+/年**(无公开定价,需 demo) | +| **API 可用性** | **无 API** | +| **数据覆盖** | VC/PE 投融资、科技行业趋势、市场规模预测 | + +**咨询报告实用度: ★★★** +内容优质但极其昂贵且无 API。不适合自动化集成,只能手动引用报告。 + +--- + +### 4.5 PitchBook + +| 项目 | 详情 | +|------|------| +| **费用** | $12,000+/年(单用户起步),大型机构 $50,000+/月 | +| **API 可用性** | **有 API**,可提取 VC/PE/M&A 数据 | +| **数据覆盖** | 全球私募市场、并购、风投交易数据 | + +**咨询报告实用度: ★★★** +投融资分析类报告有用,但价格高。有 API 这一点优于 CB Insights。 + +--- + +## 五、专业/垂直数据源 + +### 5.1 专利数据库 (知识产权数据) + +| 平台 | 费用 | API | 特点 | +|------|------|-----|------| +| **CNIPA (国家知识产权局)** | 免费查询 | 有数据发布页,无正式 REST API | 官方权威,专利公告/统计 | +| **CNIPR (知识产权出版社)** | 注册免费 + 增值付费 | `https://open.cnipr.com/` REST API | 专利检索、查询、统计、分析 | +| **佰腾 (Baiten)** | 按次付费 | `https://open.baiten.cn/` | 法律状态、引用数据 | +| **专利汇 (PatentHub)** | 免费+付费 | `https://www.patenthub.cn/api/` | 基本信息、权利要求、全文、引用、相似专利 | +| **incoPat** | 商业授权 | 有 API | 全球专利数据库,分析功能强 | +| **天眼查专利模块** | 集成在天眼查 API 中 | 同天眼查 | 企业专利关联查询 | + +**咨询报告实用度: ★★★★** +技术行业分析、竞争格局分析的重要数据维度。推荐 CNIPR 或 PatentHub。 + +--- + +### 5.2 海关/贸易数据 + +| 平台 | 费用 | 特点 | +|------|------|------| +| **海关总署官方** | 免费 | `stats.customs.gov.cn`,查询维度有限 | +| **商务部数据中心** | 免费 | `data.mofcom.gov.cn`,进出口国别数据 | +| **腾道 (Tendata)** | 商业付费 | 海关数据 API,支持按 HS 编码/产品/国家过滤 | +| **UN Comtrade** | 免费 API | `https://comtrade.un.org/data/`,联合国全球贸易数据库 | + +**咨询报告实用度: 
★★★★** +外贸和产业链分析报告核心数据源。 + +--- + +### 5.3 行业协会报告 + +| 来源 | 特点 | 费用 | +|------|------|------| +| **艾瑞咨询 (iResearch)** | 互联网/科技行业研究报告 | 部分免费,深度报告付费 | +| **易观分析 (Analysys)** | 数字经济行业数据和报告 | 部分免费,会员制 | +| **前瞻产业研究院** | 全行业覆盖的研究报告 | 单报告 ¥2,000-10,000+ | +| **头豹研究院** | 新兴行业深度分析 | 会员制 | +| **智研咨询** | 传统行业研究报告 | 单报告付费 | +| **中国信通院 (CAICT)** | ICT 行业权威白皮书 | 大部分免费 | + +**注意:** 这些来源通常无 API,需要手动获取 PDF/网页报告,然后由 LLM 提取和结构化。 + +--- + +## 六、推荐方案 (成本优先) + +### 第一梯队: 免费核心数据层 (零成本) + +| 用途 | 推荐工具 | 说明 | +|------|----------|------| +| 宏观经济数据 | **AKShare** (封装国家统计局等) | 一行代码获取 GDP/CPI/PMI/行业利润 | +| 国际对比数据 | **World Bank API** | 免费、文档完善、217经济体 | +| 国际金融数据 | **IMF API** | 免费,WEO 预测数据 | +| 上市公司公告 | **巨潮资讯 API** | 1000次免费,公告全文 | +| 全球贸易数据 | **UN Comtrade API** | 免费,全球进出口 | + +**年成本: 0 元** +可覆盖: 宏观经济分析、行业趋势、国际对比、上市公司基本面 + +--- + +### 第二梯队: 低成本增强层 + +| 用途 | 推荐工具 | 年费 | +|------|----------|------| +| 金融终端数据 | **东方财富 Choice** | ~5,800元/年(推广价) | +| 企业信息 | **天眼查 API** | 按需充值,预估 2,000-10,000元/年 | +| 行业报告 | **Statista 基础版** | ~$199/月 = ~17,000元/年 | + +**年成本: ~25,000-33,000 元** +新增覆盖: 深度财务数据、企业尽调、国际行业统计 + +--- + +### 第三梯队: 专业级全覆盖 + +| 用途 | 推荐工具 | 年费 | +|------|----------|------| +| 金融数据终端 | **Wind 万得** 或 **iFinD** | 10,000-40,000元/年 | +| 专利分析 | **CNIPR / PatentHub** | 按需 | +| 海关详细数据 | **腾道** | 商业议价 | + +**年成本: 50,000+ 元** +适合: 专业咨询公司级别使用 + +--- + +## 七、技术集成建议 + +对于咨询报告自动生成系统,推荐的数据获取架构: + +``` +┌─────────────────────────────────────────┐ +│ 数据获取调度层 │ +│ (统一接口,缓存,频率控制,错误重试) │ +├─────────────┬───────────┬───────────────┤ +│ AKShare │ World Bank│ 巨潮资讯 │ ← 免费层 +│ (宏观+行业)│ (国际对比) │ (上市公司) │ +├─────────────┼───────────┼───────────────┤ +│ Choice API │ 天眼查API │ Statista │ ← 付费层(按需) +│ (金融数据) │ (企业数据) │ (行业统计) │ +├─────────────┼───────────┼───────────────┤ +│ CNIPR │ 腾道 │ Wind/iFinD │ ← 专业层(高预算) +│ (专利) │ (海关) │ (全数据) │ +└─────────────┴───────────┴───────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 数据标准化 + 缓存层 │ +│ (DataFrame → 统一格式 → 本地缓存) │ +└─────────────────────────────────────────┘ 
+ │ + ▼ +┌─────────────────────────────────────────┐ +│ LLM 报告生成层 │ +│ (数据注入 → Prompt → 报告章节生成) │ +└─────────────────────────────────────────┘ +``` + +### 关键实现要点: +1. **AKShare 优先**: 凡是 AKShare 能获取的数据,就不调付费接口 +2. **本地缓存**: 宏观数据月度更新,不需要每次实时拉取,缓存 30 天 +3. **降级策略**: 付费接口不可用时,自动降级到免费数据源 +4. **频率控制**: 国家统计局等公开接口需控制请求频率(建议 1-2秒/次) +5. **数据标准化**: 不同来源数据统一为 DataFrame + 元数据(来源、时间、单位)格式 diff --git a/memory/global.json b/memory/global.json new file mode 100644 index 0000000..bfc2e12 --- /dev/null +++ b/memory/global.json @@ -0,0 +1,22 @@ +{ + "client_id": "global", + "preferences": {}, + "facts": [ + { + "id": "4332883e", + "content": "生成过报告:2025年全球半导体行业展望:供应链重构、中国自主化进程、AI芯片竞争格局及投资策略(类型:行业分析报告)", + "category": "context", + "confidence": 0.9, + "source": "5cd9c02f1dc7", + "created_at": "2026-03-28T00:39:01.304322" + }, + { + "id": "b5d69b81", + "content": "报告《2025年全球半导体行业展望:供应链重构、中国自主化进程、AI芯片竞争格局及投资策略》的研究轨道:Global Semiconductor Supply Chain Restructuring: US-Japan-EU-China Strategic Positioning、中国半导体自主化进展与瓶颈分析、Global AI Chip Competitive Landscape and Technology Roadmap、半导体产业投资策略与机会分析", + "category": "knowledge", + "confidence": 0.6, + "source": "5cd9c02f1dc7", + "created_at": "2026-03-28T00:39:01.305048" + } + ] +} \ No newline at end of file diff --git a/project.json b/project.json new file mode 100644 index 0000000..d514f35 --- /dev/null +++ b/project.json @@ -0,0 +1,11 @@ +{ + "name": "咨询报告 AI 生成系统", + "slug": "consulting-report-gen", + "status": "planning", + "created": "2026-03-27", + "category": "business", + "ports": {}, + "description": "借鉴 Open SWE 多 Agent 架构,为咨询公司构建安全可控的行业报告自动生成系统", + "tech": ["python", "docx", "pptx", "xlsx", "pdf"], + "notes": "数据安全为第一优先级,不允许客户数据外泄" +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f3f07f4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "consulting-report-gen" +version = "0.1.0" +description = "AI-powered consulting report generation system" +requires-python = ">=3.11" 
async def main():
    """Run the full pipeline once on a sample requirement and print a summary.

    Prints the task header, then (after the run) the final node, error,
    revision count, execution trace, research tracks, generated files and
    review verdict — the optional sections only when they are non-empty.
    """
    state = ReportState(
        requirement="分析全球半导体行业2025年发展趋势,重点关注:1)全球供应链重构(美日欧中各自布局),2)中国半导体自主化进展与瓶颈,3)AI芯片竞争格局,4)投资建议",
        report_type="行业分析报告",
        output_formats=["docx"],
        output_languages=["zh", "en"],  # bilingual output
    )

    header = (
        f"[test] Task ID: {state.id}",
        f"[test] Requirement: {state.requirement[:80]}...",
        f"[test] Languages: {state.output_languages}",
    )
    for line in header:
        print(line)
    print()

    state = await PipelineOrchestrator().run(state)

    print()
    summary = (
        f"[test] Final node: {state.current_node}",
        f"[test] Error: {state.error or 'None'}",
        f"[test] Revisions: {state.revision_count}",
    )
    for line in summary:
        print(line)
    print()

    # Execution trace: one row per visited node
    print("[test] Execution trace:")
    for step in state.node_history:
        stamp = step['timestamp'][:19]
        detail = step.get('detail', '')
        print(f" {stamp} | {step['node']:20s} | {step['status']:10s} | {detail}")

    # Parallel research — show domain/language/model allocation
    tracks = state.research_results
    if tracks:
        print()
        print(f"[test] Research tracks: {len(tracks)}")
        for track in tracks:
            # Duration is appended only when it is set and non-zero
            suffix = f" ({track.duration_ms}ms)" if track.duration_ms else ""
            print(f" [{track.status.value:9s}] [{track.domain.value:6s}] [{track.native_language}] {track.description}{suffix}")

    # Output files
    files = state.generated_files
    if files:
        print()
        print(f"[test] Generated files ({len(files)}):")
        for path in files:
            print(f" → {path}")

    # Review verdict
    review = state.review
    if review:
        print()
        print(f"[test] Review: score={review.get('overall_score')}, verdict={review.get('verdict')}")