"""Report generation graph state — the shared context that flows through all nodes."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class NodeStatus(str, Enum):
    """Lifecycle status of a node or subtask.

    Mixing in ``str`` makes every member a plain string as well, so a member
    compares equal to its value (``NodeStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
|
|
|
|
|
class ContentDomain(str, Enum):
    """Content domain — determines which model and language to use.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    GLOBAL = "global"  # International markets, global trends → English-native
    CHINA = "china"  # Chinese market, domestic policy → Chinese-native
    REASONING = "reasoning"  # Synthesis, review, strategy → strongest reasoning
    FAST = "fast"  # Data processing, charts → cost-effective
    TRANSLATION = "translation"  # EN↔ZH translation
|
|
|
|
|
|
class SubtaskResult(BaseModel):
    """Result from a parallel research subtask.

    Holds the subtask's description, its content domain and writing language,
    its status, the produced content or an error message, and the start and
    completion timestamps from which ``duration_ms`` is derived.
    """

    # Short identifier: the first 8 hex characters of a random UUID4.
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    description: str = ""
    domain: ContentDomain = ContentDomain.GLOBAL
    native_language: str = "en"  # "en" or "zh" — the original writing language
    status: NodeStatus = NodeStatus.PENDING
    content: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_ms(self) -> int | None:
        """Return the elapsed time between start and completion in milliseconds.

        Returns:
            The difference ``completed_at - started_at`` converted to
            milliseconds and truncated to an ``int``, or ``None`` when either
            timestamp has not been set.
        """
        if self.started_at is None or self.completed_at is None:
            return None
        elapsed = self.completed_at - self.started_at
        return int(elapsed.total_seconds() * 1000)
|
|
|
|
|
|
class ReportState(BaseModel):
    """Full state for a report generation run.

    This is the single source of truth that all graph nodes read from and write to.
    """

    # --- Identity ---
    # Short identifier: the first 12 hex characters of a random UUID4.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Stamped with datetime.now() (naive local time) when the instance is built.
    created_at: datetime = Field(default_factory=datetime.now)

    # --- Input (set once at start) ---
    requirement: str = ""
    report_type: str = "行业分析报告"  # default label; Chinese for "industry analysis report"
    extra_data: str = ""
    output_formats: list[str] = Field(default=["docx"])
    output_languages: list[str] = Field(default=["zh", "en"])  # produce both versions
    template_name: str | None = None
    client_id: str | None = None  # for multi-tenant isolation

    # --- Middleware injections ---
    client_context: str = ""  # injected by ClientContextMiddleware
    memory_facts: list[str] = Field(default_factory=list)  # injected by MemoryMiddleware
    token_budget: int = 120000  # managed by TokenBudgetMiddleware

    # --- Lead Agent output ---
    decomposition: dict[str, Any] = Field(default_factory=dict)
    # e.g. {"tracks": [{"title": "政策环境", "prompt": "研究..."}, ...]}
    # (sample title means "policy environment"; sample prompt begins "research ...")

    # --- Parallel research results ---
    research_results: list[SubtaskResult] = Field(default_factory=list)

    # --- Writer output ---
    draft: dict[str, Any] = Field(default_factory=dict)  # primary language version
    draft_translated: dict[str, Any] = Field(default_factory=dict)  # translated version

    # --- Data Agent output ---
    data_assets: dict[str, Any] = Field(default_factory=dict)

    # --- Reviewer output ---
    review: dict[str, Any] = Field(default_factory=dict)
    revision_count: int = 0
    max_revisions: int = 2

    # --- Formatter output ---
    generated_files: list[str] = Field(default_factory=list)

    # --- Execution tracking ---
    current_node: str = ""
    node_history: list[dict[str, Any]] = Field(default_factory=list)
    error: str | None = None

    def log_node(self, node_name: str, status: NodeStatus, detail: str = "") -> None:
        """Append one entry to ``node_history``.

        Args:
            node_name: Name of the graph node being recorded.
            status: Status to record; its ``.value`` string is what is stored.
            detail: Optional free-form text stored alongside the entry.

        The entry is a dict with the keys ``node``, ``status``, ``detail`` and
        ``timestamp``; the timestamp is ``datetime.now()`` in ISO 8601 form.
        """
        self.node_history.append(
            {
                "node": node_name,
                "status": status.value,
                "detail": detail,
                "timestamp": datetime.now().isoformat(),
            }
        )
|