"""Report generation graph state — the shared context that flows through all nodes."""

from __future__ import annotations

import uuid
from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class NodeStatus(str, Enum):
    """Lifecycle status of a graph node or a research subtask."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class ContentDomain(str, Enum):
    """Content domain — determines which model and language to use."""

    GLOBAL = "global"  # International markets, global trends → English-native
    CHINA = "china"  # Chinese market, domestic policy → Chinese-native
    REASONING = "reasoning"  # Synthesis, review, strategy → strongest reasoning
    FAST = "fast"  # Data processing, charts → cost-effective
    TRANSLATION = "translation"  # EN↔ZH translation


class SubtaskResult(BaseModel):
    """Result from a parallel research subtask."""

    # Short random id: the first 8 hex chars of a uuid4.
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
    description: str = ""
    domain: ContentDomain = ContentDomain.GLOBAL
    native_language: str = "en"  # "en" or "zh" — the original writing language
    status: NodeStatus = NodeStatus.PENDING
    content: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_ms(self) -> int | None:
        """Return the elapsed time in whole milliseconds.

        Returns ``None`` unless both ``started_at`` and ``completed_at`` are
        set. Fractional milliseconds are truncated by ``int()``.
        """
        if self.started_at is None or self.completed_at is None:
            return None
        elapsed = self.completed_at - self.started_at
        return int(elapsed.total_seconds() * 1000)


class ReportState(BaseModel):
    """Full state for a report generation run.

    This is the single source of truth that all graph nodes read from and
    write to.
    """

    # --- Identity ---
    # Short random id: the first 12 hex chars of a uuid4.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Naive local timestamp (datetime.now() with no tz).
    created_at: datetime = Field(default_factory=datetime.now)

    # --- Input (set once at start) ---
    requirement: str = ""
    report_type: str = "行业分析报告"  # default: "industry analysis report"
    extra_data: str = ""
    # default_factory (not a list literal) keeps these consistent with every
    # other mutable field here and makes per-instance lists explicit.
    output_formats: list[str] = Field(default_factory=lambda: ["docx"])
    # Produce both language versions by default.
    output_languages: list[str] = Field(default_factory=lambda: ["zh", "en"])
    template_name: str | None = None
    client_id: str | None = None  # for multi-tenant isolation

    # --- Middleware injections ---
    client_context: str = ""  # injected by ClientContextMiddleware
    memory_facts: list[str] = Field(default_factory=list)  # injected by MemoryMiddleware
    token_budget: int = 120000  # managed by TokenBudgetMiddleware

    # --- Lead Agent output ---
    # e.g. {"tracks": [{"title": "<policy environment>", "prompt": "<research ...>"}, ...]}
    decomposition: dict[str, Any] = Field(default_factory=dict)

    # --- Parallel research results ---
    research_results: list[SubtaskResult] = Field(default_factory=list)

    # --- Writer output ---
    draft: dict[str, Any] = Field(default_factory=dict)  # primary language version
    draft_translated: dict[str, Any] = Field(default_factory=dict)  # translated version

    # --- Data Agent output ---
    data_assets: dict[str, Any] = Field(default_factory=dict)

    # --- Reviewer output ---
    review: dict[str, Any] = Field(default_factory=dict)
    # NOTE(review): presumably the graph stops revising once revision_count
    # reaches max_revisions; that limit is not enforced in this model.
    revision_count: int = 0
    max_revisions: int = 2

    # --- Formatter output ---
    generated_files: list[str] = Field(default_factory=list)

    # --- Execution tracking ---
    current_node: str = ""
    node_history: list[dict[str, Any]] = Field(default_factory=list)
    error: str | None = None

    def log_node(
        self,
        node_name: str,
        status: NodeStatus | str,
        detail: str = "",
    ) -> None:
        """Append an execution record to ``node_history``.

        Args:
            node_name: Name of the graph node being recorded.
            status: A ``NodeStatus`` member or its string value (e.g.
                ``"running"``). The value is normalized through
                ``NodeStatus(...)``, so an unknown string raises ``ValueError``.
            detail: Optional free-form note stored with the entry.

        The entry's ``timestamp`` is a naive local ISO-8601 string.
        ``current_node`` is not modified.
        """
        self.node_history.append(
            {
                "node": node_name,
                # Accept raw strings as well as enum members (NodeStatus is a
                # str enum); NodeStatus(member) returns the member unchanged.
                "status": NodeStatus(status).value,
                "detail": detail,
                "timestamp": datetime.now().isoformat(),
            }
        )