109 lines
3.7 KiB
Python
109 lines
3.7 KiB
Python
"""Graph builder — assembles nodes into an executable report generation graph.

No LangGraph dependency: pure asyncio with a simple node-runner pattern.

v2: domain-aware and bilingual (adds a translate node).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Callable, Awaitable
|
|
|
|
from app.middleware.chain import MiddlewareChain
|
|
from .state import ReportState, NodeStatus
|
|
from .nodes import (
|
|
DecomposeNode,
|
|
ParallelResearchNode,
|
|
WriteNode,
|
|
TranslateNode,
|
|
DataNode,
|
|
ReviewNode,
|
|
FormatNode,
|
|
)
|
|
|
|
# Contract every graph node satisfies: a callable that receives the shared
# ReportState and returns an awaitable resolving to a ReportState.
NodeFn = Callable[[ReportState], Awaitable[ReportState]]

# Module logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ReportGraph:
    """Executable graph for report generation.

    Graph structure (v2)::

        decompose → parallel_research → write → translate → data → review →
            ├─ pass   → format → END
            └─ revise → write → translate → data → review → ...

    The revise loop is bounded: once ``state.revision_count`` reaches
    ``state.max_revisions`` the current draft goes on to ``format`` even if
    the reviewer did not pass it.
    """

    def __init__(self, middleware: MiddlewareChain | None = None):
        """Build the graph with one instance of each pipeline node.

        Args:
            middleware: Optional chain whose ``before`` hook runs at the start
                of :meth:`run` and whose ``after`` hook runs at its end.
        """
        self.middleware = middleware
        self.decompose = DecomposeNode()
        self.parallel_research = ParallelResearchNode()
        self.write = WriteNode()
        self.translate = TranslateNode()
        self.data = DataNode()
        self.review = ReviewNode()
        self.format = FormatNode()

    async def _run_node(self, name: str, node: NodeFn, state: ReportState) -> ReportState:
        """Await a single node, recording any failure on ``state``.

        Args:
            name: Node name used in log messages and in ``state.log_node``.
            node: Callable taking the report state and returning an awaitable
                that resolves to the next state.
            state: State passed to ``node``.

        Returns:
            The state returned by ``node``.

        Raises:
            Exception: Whatever ``node`` raised, re-raised unchanged after
                ``state.error`` is set, the failure is logged on the state via
                ``state.log_node(name, NodeStatus.FAILED, ...)``, and the
                traceback is written to the module logger.
        """
        logger.info("[graph] entering node: %s", name)
        try:
            state = await node(state)
        except Exception as e:
            state.error = f"Node '{name}' failed: {e}"
            state.log_node(name, NodeStatus.FAILED, str(e))
            logger.exception("[graph] node '%s' failed", name)
            raise
        logger.info("[graph] completed node: %s", name)
        return state

    async def run(self, state: ReportState) -> ReportState:
        """Execute the full graph and return the final state.

        Errors inside the pipeline do not propagate to the caller. A failing
        node records its error on ``state.error`` (see ``_run_node``), and the
        remaining nodes are skipped. Any other exception raised while driving
        the pipeline is also recorded on ``state.error`` unless an error is
        already set there. The middleware ``after`` hook runs in both cases.
        Exceptions raised by the middleware hooks themselves do propagate.
        """
        # --- Middleware: before ---
        if self.middleware:
            state = await self.middleware.before(state)

        try:
            # 1. Decompose requirement into domain-tagged parallel tracks
            state = await self._run_node("decompose", self.decompose, state)

            # 2. Run parallel research (each track uses domain-optimal model)
            state = await self._run_node("parallel_research", self.parallel_research, state)

            # 3-6. Write → Translate → Data → Review (with revision loop).
            # Kept inline so ``state`` here always refers to the latest state
            # object, which is where a failing node records its error.
            while True:
                state = await self._run_node("write", self.write, state)
                state = await self._run_node("translate", self.translate, state)
                state = await self._run_node("data", self.data, state)
                state = await self._run_node("review", self.review, state)

                # state.review is the dict-like review result (self.review is
                # the node). A missing review or verdict counts as a pass.
                verdict = (state.review or {}).get("verdict", "pass")
                if verdict == "pass":
                    break
                if state.revision_count >= state.max_revisions:
                    logger.warning(
                        "[graph] forcing pass after %s revisions", state.revision_count
                    )
                    break

                # Revise — loop back to write
                state.revision_count += 1
                logger.info(
                    "[graph] revision %s/%s", state.revision_count, state.max_revisions
                )

            # 7. Format bilingual output files
            state = await self._run_node("format", self.format, state)

        except Exception as exc:
            # Node failures were already recorded and logged by _run_node.
            # Anything else (e.g. a malformed review result) would otherwise
            # vanish without a trace, so record and log it here.
            if not getattr(state, "error", None):
                state.error = f"Graph execution failed: {exc}"
                logger.exception("[graph] unexpected failure outside node execution")

        # --- Middleware: after ---
        if self.middleware:
            state = await self.middleware.after(state)

        return state
|