394 lines
14 KiB
Python
394 lines
14 KiB
Python
"""Graph nodes — each node is an async callable: ReportState → ReportState.

Node layout (v2 — domain-aware, bilingual):

    START
      │
      ▼
    [decompose]         — Lead Agent splits the work into parallel research
      │                   tracks, each tagged with domain + language
      ▼
    [parallel_research] — N sub-agents run in parallel, each using the model
      │                   best suited to its domain
      │   global tracks → Claude/GPT (English)
      │   china tracks  → DeepSeek/Qwen (Chinese)
      ▼
    [write]             — Writer merges the tracks → primary-language version
      │
      ▼
    [translate]         — high-quality translation → other-language version
      │
      ▼
    [data]              — Data Agent generates charts/tables
      │
      ▼
    [review]            — Reviewer checks the report (bilingual)
      │   ├─ pass   → [format]
      │   └─ revise → [write]
      ▼
    [format]            — writes the bilingual output files
      │
      ▼
    END
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from app.agents.base import BaseAgent
|
|
from app.agents.researcher import ResearcherAgent
|
|
from app.agents.writer import WriterAgent
|
|
from app.agents.data_agent import DataAgent
|
|
from app.agents.reviewer import ReviewerAgent
|
|
from app.agents.formatter import FormatterAgent
|
|
from app.config import settings
|
|
|
|
from .state import ReportState, SubtaskResult, NodeStatus, ContentDomain
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: decompose — Lead Agent decomposes into domain-tagged parallel tracks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DecomposeNode:
    """Analyzes the requirement and decomposes it into domain-aware research tracks.

    The node asks the "lead" agent for a JSON plan and stores it on
    ``state.decomposition``. Downstream nodes read the ``tracks``,
    ``title_en``, ``title_zh``, ``methodology`` and ``synthesis_guide`` keys
    from that plan.
    """

    # System prompt sent verbatim to the lead agent; it fixes the JSON schema
    # of the decomposition.
    _SYSTEM_PROMPT = """\
You are a senior consulting partner planning a global industry report.

Your job is to decompose the client's requirement into 2-6 parallel research tracks.

CRITICAL: Each track must be tagged with a content domain and native language:

- domain: "global" → international markets, global competition, technology trends, overseas benchmarks
→ native_language: "en" (English sources are 10-100x richer for global analysis)

- domain: "china" → Chinese domestic market, government policy, local competitors, China-specific data
→ native_language: "zh" (Chinese sources are authoritative for domestic analysis)

The PRINCIPLE: whichever language has the richest professional literature for that topic
should be the native language. The other language version will be translated later.

Output (JSON):
{
"title_en": "English report title",
"title_zh": "中文报告标题",
"report_type": "report type",
"tracks": [
{
"title": "track title (in native language)",
"domain": "global|china",
"native_language": "en|zh",
"focus": "research focus description",
"prompt": "detailed research instructions (MUST be in the native_language)",
"data_needs": ["required data/charts"]
}
],
"synthesis_guide": "How to merge all tracks into a coherent report (bilingual structure notes)",
"methodology": "Analysis methodology"
}"""

    def __init__(self):
        self.agent = BaseAgent()
        self.agent.name = "lead"
        self.agent.model = settings.model_for_domain("reasoning")

    @staticmethod
    def _count_tracks(decomposition: Any) -> int:
        """Return how many tracks a decomposition holds, tolerating bad shapes.

        Returns 0 when ``decomposition`` is not a dict or when its ``tracks``
        entry is missing, ``None`` or not a list/tuple, so that the log line
        built from it can never raise.
        """
        if not isinstance(decomposition, dict):
            return 0
        tracks = decomposition.get("tracks")
        return len(tracks) if isinstance(tracks, (list, tuple)) else 0

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the decomposition and store the plan on ``state.decomposition``.

        Args:
            state: Report state; ``requirement``, ``report_type``,
                ``extra_data`` and ``client_context`` are read to build the
                prompt.

        Returns:
            The same ``state`` object, with ``decomposition`` populated.
        """
        state.current_node = "decompose"
        state.log_node("decompose", NodeStatus.RUNNING)

        prompt = f"""\
## Client requirement
{state.requirement}

## Report type
{state.report_type}

## Additional data
{state.extra_data or "(none)"}

## Client context
{state.client_context or "(none)"}

Decompose into parallel research tracks with domain and language tags. Output JSON."""

        result = await self.agent.call_llm_json(prompt, system=self._SYSTEM_PROMPT)

        # Later nodes call ``state.decomposition.get(...)``; a non-dict reply
        # (e.g. a bare JSON list) would crash them, so fall back to an empty
        # plan and let the research node report "no tracks".
        if not isinstance(result, dict):
            logger.warning(
                "[decompose] expected a JSON object, got %s; using empty decomposition",
                type(result).__name__,
            )
            result = {}

        state.decomposition = result
        state.log_node("decompose", NodeStatus.COMPLETED,
                       f"{self._count_tracks(result)} tracks")
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: parallel_research — domain-aware parallel execution
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ParallelResearchNode:
    """Runs research subtasks in parallel, each using the optimal model for its domain."""

    # Upper bound on research tracks executing at the same time.
    MAX_CONCURRENT = 5

    @staticmethod
    def _resolve_domain(raw: Any) -> ContentDomain:
        """Map a track's raw ``domain`` tag to a ``ContentDomain`` member.

        Lookup is by enum value, so it works whether or not ``ContentDomain``
        mixes in ``str``. A string tag is retried stripped and lower-cased
        (LLM output such as ``" China"``). Anything unrecognised falls back
        to ``ContentDomain.GLOBAL``.
        """
        candidates = [raw]
        if isinstance(raw, str):
            candidates.append(raw.strip().lower())
        for candidate in candidates:
            try:
                return ContentDomain(candidate)
            except ValueError:
                continue
        logger.warning(
            "[parallel_research] unknown domain %r; defaulting to %s",
            raw, ContentDomain.GLOBAL.value,
        )
        return ContentDomain.GLOBAL

    async def _run_one(self, track: dict[str, Any]) -> SubtaskResult:
        """Execute a single research track and capture its outcome.

        Args:
            track: One entry of the decomposition's ``tracks`` list. ``prompt``
                is required; ``title``, ``domain``, ``native_language`` and
                ``focus`` are optional.

        Returns:
            A ``SubtaskResult`` whose status is COMPLETED or FAILED. Failures
            while running the track are recorded on the result (``error``)
            rather than raised, so one bad track cannot cancel its siblings.
        """
        domain = self._resolve_domain(track.get("domain", "global"))
        native_lang = track.get("native_language", "en")

        result = SubtaskResult(
            description=track.get("title", ""),
            domain=domain,
            native_language=native_lang,
        )
        result.status = NodeStatus.RUNNING
        result.started_at = datetime.now()

        try:
            # Select model based on domain
            model = settings.model_for_domain(domain.value)
            agent = ResearcherAgent(model=model, language=native_lang)

            logger.info(
                "[parallel_research] track '%s' → domain=%s, lang=%s, model=%s",
                track.get("title"), domain.value, native_lang, model,
            )

            research = await agent.run({
                "requirement": track["prompt"],
                "report_type": track.get("focus", ""),
                "extra_data": "",
            })
            result.content = research.get("research", {})
            result.status = NodeStatus.COMPLETED
        except Exception as e:
            # Deliberately broad: this is the per-track isolation boundary.
            result.error = str(e)
            result.status = NodeStatus.FAILED
            logger.exception(f"Research track '{track.get('title')}' failed")
        finally:
            result.completed_at = datetime.now()

        return result

    async def __call__(self, state: ReportState) -> ReportState:
        """Run every decomposed track, at most ``MAX_CONCURRENT`` at a time.

        Stores the per-track results on ``state.research_results``. The node
        is logged as FAILED, and ``state.error`` is set, when there are no
        tracks to run or when every track failed.

        Returns:
            The same ``state`` object.
        """
        state.current_node = "parallel_research"
        state.log_node("parallel_research", NodeStatus.RUNNING)

        tracks = (state.decomposition or {}).get("tracks") or []
        if not tracks:
            state.log_node("parallel_research", NodeStatus.FAILED, "no tracks")
            state.error = "Decomposition produced no research tracks"
            return state

        semaphore = asyncio.Semaphore(self.MAX_CONCURRENT)

        async def bounded(track):
            async with semaphore:
                return await self._run_one(track)

        logger.info(f"[parallel_research] launching {len(tracks)} tracks concurrently")
        results = await asyncio.gather(*[bounded(t) for t in tracks])
        state.research_results = list(results)

        succeeded = sum(1 for r in results if r.status == NodeStatus.COMPLETED)
        # Plain dict (not defaultdict) so the logged repr stays "{...}".
        domains = {}
        for r in results:
            domains.setdefault(r.domain.value, []).append(r.native_language)
        summary = f"{succeeded}/{len(tracks)} ok, domains={domains}"

        if succeeded == 0:
            # Nothing usable was produced; surface it the same way as the
            # "no tracks" case instead of reporting success.
            state.log_node("parallel_research", NodeStatus.FAILED, summary)
            state.error = "All research tracks failed"
            return state

        state.log_node("parallel_research", NodeStatus.COMPLETED, summary)
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: write — synthesize research into primary-language draft
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class WriteNode:
    """Synthesizes the completed research tracks into the primary-language draft."""

    def __init__(self):
        self.agent = WriterAgent()

    @staticmethod
    def _format_review_feedback(revision_count: int, review: dict[str, Any] | None) -> str:
        """Render reviewer issues as a markdown section for the writer prompt.

        Args:
            revision_count: Number of revisions requested so far.
            review: Review payload; its ``issues`` entry is read.

        Returns:
            An empty string on the first pass (``revision_count <= 0``) or
            when there is no review; otherwise a heading followed by one
            bullet per issue. A missing or ``None`` ``issues`` entry yields
            the heading only, and a non-dict issue is rendered as plain text.
        """
        if revision_count <= 0 or not review:
            return ""

        parts = [f"\n\n## Review feedback (revision {revision_count})\n"]
        for issue in review.get("issues") or []:
            if isinstance(issue, dict):
                parts.append(
                    f"- [{issue.get('severity')}] {issue.get('description')} → {issue.get('suggestion')}\n"
                )
            else:
                parts.append(f"- {issue}\n")
        return "".join(parts)

    async def __call__(self, state: ReportState) -> ReportState:
        """Ask the writer agent for a draft and store it on ``state.draft``.

        Only research results whose status is COMPLETED are passed on. On a
        revision pass the previous review's issues are included as
        ``revision_feedback``.

        Returns:
            The same ``state`` object.
        """
        state.current_node = "write"
        state.log_node("write", NodeStatus.RUNNING)

        research_merged = [
            {
                "track": r.description,
                "domain": r.domain.value,
                "native_language": r.native_language,
                "findings": r.content,
            }
            for r in state.research_results
            if r.status == NodeStatus.COMPLETED
        ]

        decomposition = state.decomposition
        result = await self.agent.run({
            "requirement": state.requirement,
            "research": {
                "title_en": decomposition.get("title_en", ""),
                "title_zh": decomposition.get("title_zh", ""),
                "methodology": decomposition.get("methodology", ""),
                "tracks": research_merged,
                "synthesis_guide": decomposition.get("synthesis_guide", ""),
            },
            "revision_feedback": self._format_review_feedback(
                state.revision_count, state.review
            ),
        })

        state.draft = result.get("draft", {})
        state.log_node("write", NodeStatus.COMPLETED)
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: translate — produce the other language version
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TranslateNode:
    """Translates the draft into the other language version."""

    def __init__(self):
        self.agent = BaseAgent()
        self.agent.name = "translator"
        self.agent.model = settings.model_for_domain("translation")

    @staticmethod
    def _contains_cjk(text: Any) -> bool:
        """Return True if ``text`` is a string with a character in U+4E00–U+9FFF.

        Non-string input (for example a ``None`` title from the LLM) returns
        False instead of raising.
        """
        if not isinstance(text, str):
            return False
        return any("\u4e00" <= ch <= "\u9fff" for ch in text)

    async def __call__(self, state: ReportState) -> ReportState:
        """Translate ``state.draft`` and store the result on ``state.draft_translated``.

        The step is skipped (and logged as such) when there is no draft or
        when ``"en"`` is not among ``state.output_languages``. The source
        language is inferred from the draft's ``title``: any character in
        U+4E00–U+9FFF means Chinese → English, otherwise English → Chinese.

        Returns:
            The same ``state`` object.
        """
        state.current_node = "translate"
        state.log_node("translate", NodeStatus.RUNNING)

        # NOTE(review): the skip only looks for "en", even though the target
        # may be Chinese when the draft is English — confirm this is intended.
        if not state.draft or "en" not in state.output_languages:
            state.log_node("translate", NodeStatus.COMPLETED, "skipped")
            return state

        draft_json = json.dumps(state.draft, ensure_ascii=False, indent=2)

        # Detect primary language of draft
        if self._contains_cjk(state.draft.get("title", "")):
            source_lang, target_lang = "Chinese", "English"
        else:
            source_lang, target_lang = "English", "Chinese (Simplified)"

        # The quadrupled braces render as literal "{{CHART:...}}" markers.
        system = f"""\
You are a world-class {source_lang} → {target_lang} translator specializing in
consulting and business reports.

Translation principles:
1. ACCURACY over fluency — every data point, percentage, and proper noun must be correct
2. Professional terminology — use standard {target_lang} business/industry terms
3. Preserve structure — keep the exact same JSON structure, only translate text values
4. Cultural adaptation — adjust phrasing for the target audience (not word-for-word)
5. Keep {{{{CHART:...}}}} and {{{{TABLE:...}}}} markers, translate their descriptions

Output the translated JSON with the exact same structure."""

        prompt = f"""\
Translate this consulting report from {source_lang} to {target_lang}.

{draft_json}

Output the translated JSON."""

        translated = await self.agent.call_llm_json(prompt, system=system, max_tokens=8192)
        state.draft_translated = translated
        state.log_node("translate", NodeStatus.COMPLETED,
                       f"{source_lang} → {target_lang}")
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: data — generate charts and tables
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DataNode:
    """Graph node that has the data agent produce charts/tables for the draft."""

    def __init__(self):
        self.agent = DataAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the data agent and store its ``data_assets`` on the state.

        The agent receives the current draft and the caller-supplied extra
        data; a missing ``data_assets`` key in its reply becomes ``{}``.

        Returns:
            The same ``state`` object.
        """
        node_name = "data"
        state.current_node = node_name
        state.log_node(node_name, NodeStatus.RUNNING)

        payload = {"draft": state.draft, "extra_data": state.extra_data}
        outcome = await self.agent.run(payload)
        state.data_assets = outcome.get("data_assets", {})

        state.log_node(node_name, NodeStatus.COMPLETED)
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: review — bilingual quality check
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ReviewNode:
    """Graph node that submits both language versions to the reviewer agent."""

    def __init__(self):
        self.agent = ReviewerAgent()

    async def __call__(self, state: ReportState) -> ReportState:
        """Run the reviewer and store its ``review`` payload on the state.

        The reviewer is given the primary draft, the translated draft and the
        decomposition plan (passed under the ``research`` key). The verdict
        is included in the completion log entry, or ``?`` when absent.

        Returns:
            The same ``state`` object.
        """
        node_name = "review"
        state.current_node = node_name
        state.log_node(node_name, NodeStatus.RUNNING)

        payload = {
            "draft": state.draft,
            "draft_translated": state.draft_translated,
            "research": state.decomposition,
        }
        outcome = await self.agent.run(payload)
        state.review = outcome.get("review", {})

        verdict = state.review.get("verdict", "?")
        state.log_node(node_name, NodeStatus.COMPLETED, f"verdict={verdict}")
        return state
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node: format — render bilingual output files
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class FormatNode:
    """Graph node that renders the primary and translated drafts to files."""

    def __init__(self):
        self.agent = FormatterAgent()

    async def _render(self, state: ReportState, draft: Any, subdir: str) -> list:
        """Format one draft into ``<output_dir>/<state.id>/<subdir>``.

        Returns the ``generated_files`` reported by the formatter agent, or
        an empty list when the reply has no such key.
        """
        outcome = await self.agent.run({
            "draft": draft,
            "data_assets": state.data_assets,
            "output_dir": str(settings.output_dir / state.id / subdir),
            "output_formats": state.output_formats,
        })
        return outcome.get("generated_files", [])

    async def __call__(self, state: ReportState) -> ReportState:
        """Render every available language version and record the files.

        The primary draft is always rendered; the translated draft is
        rendered only when ``state.draft_translated`` is truthy. All file
        entries are collected, primary first, into ``state.generated_files``.

        Returns:
            The same ``state`` object.
        """
        node_name = "format"
        state.current_node = node_name
        state.log_node(node_name, NodeStatus.RUNNING)

        produced = []

        # Primary version
        produced.extend(await self._render(state, state.draft, "primary"))

        # Translated version (if available)
        if state.draft_translated:
            produced.extend(
                await self._render(state, state.draft_translated, "translated")
            )

        state.generated_files = produced
        state.log_node(node_name, NodeStatus.COMPLETED, f"{len(produced)} files")
        return state
|