init repo
This commit is contained in:
104
app/graph/state.py
Normal file
104
app/graph/state.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Report generation graph state — the shared context that flows through all nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NodeStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ContentDomain(str, Enum):
|
||||
"""Content domain — determines which model and language to use."""
|
||||
GLOBAL = "global" # International markets, global trends → English-native
|
||||
CHINA = "china" # Chinese market, domestic policy → Chinese-native
|
||||
REASONING = "reasoning" # Synthesis, review, strategy → strongest reasoning
|
||||
FAST = "fast" # Data processing, charts → cost-effective
|
||||
TRANSLATION = "translation" # EN↔ZH translation
|
||||
|
||||
|
||||
class SubtaskResult(BaseModel):
|
||||
"""Result from a parallel research subtask."""
|
||||
task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
|
||||
description: str = ""
|
||||
domain: ContentDomain = ContentDomain.GLOBAL
|
||||
native_language: str = "en" # "en" or "zh" — the original writing language
|
||||
status: NodeStatus = NodeStatus.PENDING
|
||||
content: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def duration_ms(self) -> int | None:
|
||||
if self.started_at and self.completed_at:
|
||||
return int((self.completed_at - self.started_at).total_seconds() * 1000)
|
||||
return None
|
||||
|
||||
|
||||
class ReportState(BaseModel):
|
||||
"""Full state for a report generation run.
|
||||
|
||||
This is the single source of truth that all graph nodes read from and write to.
|
||||
"""
|
||||
# Identity
|
||||
id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
# --- Input (set once at start) ---
|
||||
requirement: str = ""
|
||||
report_type: str = "行业分析报告"
|
||||
extra_data: str = ""
|
||||
output_formats: list[str] = Field(default=["docx"])
|
||||
output_languages: list[str] = Field(default=["zh", "en"]) # produce both versions
|
||||
template_name: str | None = None
|
||||
client_id: str | None = None # for multi-tenant isolation
|
||||
|
||||
# --- Middleware injections ---
|
||||
client_context: str = "" # injected by ClientContextMiddleware
|
||||
memory_facts: list[str] = Field(default_factory=list) # injected by MemoryMiddleware
|
||||
token_budget: int = 120000 # managed by TokenBudgetMiddleware
|
||||
|
||||
# --- Lead Agent output ---
|
||||
decomposition: dict[str, Any] = Field(default_factory=dict)
|
||||
# e.g. {"tracks": [{"title": "政策环境", "prompt": "研究..."}, ...]}
|
||||
|
||||
# --- Parallel research results ---
|
||||
research_results: list[SubtaskResult] = Field(default_factory=list)
|
||||
|
||||
# --- Writer output ---
|
||||
draft: dict[str, Any] = Field(default_factory=dict) # primary language version
|
||||
draft_translated: dict[str, Any] = Field(default_factory=dict) # translated version
|
||||
|
||||
# --- Data Agent output ---
|
||||
data_assets: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# --- Reviewer output ---
|
||||
review: dict[str, Any] = Field(default_factory=dict)
|
||||
revision_count: int = 0
|
||||
max_revisions: int = 2
|
||||
|
||||
# --- Formatter output ---
|
||||
generated_files: list[str] = Field(default_factory=list)
|
||||
|
||||
# --- Execution tracking ---
|
||||
current_node: str = ""
|
||||
node_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
error: str | None = None
|
||||
|
||||
def log_node(self, node_name: str, status: NodeStatus, detail: str = ""):
|
||||
self.node_history.append({
|
||||
"node": node_name,
|
||||
"status": status.value,
|
||||
"detail": detail,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
Reference in New Issue
Block a user