111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
"""API routes for report generation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.graph.state import ReportState
|
|
from app.pipeline.orchestrator import PipelineOrchestrator
|
|
from app.pipeline.task import create_report_state
|
|
|
|
logger = logging.getLogger(__name__)

# All routes in this module are mounted under the "/api" prefix.
router = APIRouter(prefix="/api")

# In-memory report store keyed by report id (swap for a DB later).
# As a module-level dict it lives only in this process: entries are lost on
# restart and are not shared between separate worker processes.
reports: dict[str, ReportState] = {}

# Single orchestrator instance, created at import time and reused by every
# create_report request handled by this module.
orchestrator = PipelineOrchestrator()
|
|
|
|
|
|
class CreateReportRequest(BaseModel):
    """Request body for ``POST /api/reports``.

    Every field is forwarded unchanged to ``create_report_state``.
    """

    # The user's requirement text for the report; the only required field.
    requirement: str
    # Report type label. The default literal is Chinese for
    # "industry analysis report".
    report_type: str = "行业分析报告"
    # Optional supplementary input; empty string when not supplied.
    extra_data: str = ""
    # Requested output formats; defaults to DOCX only.
    output_formats: list[str] = ["docx"]
    # Optional identifier of the calling client.
    client_id: str | None = None
|
|
|
|
|
|
class ReportResponse(BaseModel):
    """Summary view of a report, copied from its ``ReportState``."""

    # Report identifier; also the key used in the in-memory store.
    id: str
    # Name of the pipeline node recorded on the state.
    current_node: str
    # Failure message when the pipeline reported an error; None otherwise.
    error: str | None = None
    # Outputs produced by the pipeline (presumably file paths/names — confirm).
    # Pydantic copies mutable defaults per instance, so the [] literal is safe.
    generated_files: list[str] = []
    # Per-node history records copied from the state; schema is not defined here.
    node_history: list[dict] = []
    # Revision counter copied from the state.
    revision_count: int = 0
|
|
|
|
|
|
def _to_response(state: ReportState) -> ReportResponse:
    """Project a ``ReportState`` onto the public ``ReportResponse`` model.

    Each response field is read from the state attribute of the same name;
    nothing is transformed or defaulted here.
    """
    summary_fields = (
        "id",
        "current_node",
        "error",
        "generated_files",
        "node_history",
        "revision_count",
    )
    payload = {name: getattr(state, name) for name in summary_fields}
    return ReportResponse(**payload)
|
|
|
|
|
|
@router.post("/reports", response_model=ReportResponse)
async def create_report(req: CreateReportRequest):
    """Create a report and run its generation pipeline to completion.

    The freshly created state is stored in ``reports`` before the pipeline
    runs, so it can be looked up by id while the request is still awaiting.
    After the run it is replaced with the state returned by the orchestrator.

    Raises:
        HTTPException: 500 with ``state.error`` as detail if the pipeline
            finishes with an error recorded on the state.
        Exception: anything raised by ``orchestrator.run`` is logged and
            re-raised unchanged.
    """
    state = create_report_state(
        requirement=req.requirement,
        report_type=req.report_type,
        extra_data=req.extra_data,
        output_formats=req.output_formats,
        client_id=req.client_id,
    )
    reports[state.id] = state

    # Run the full graph inline; the request waits for completion
    # (move to a task queue later).
    try:
        state = await orchestrator.run(state)
    except Exception:
        # Otherwise the crash reaches the client as an anonymous 500 with no
        # record of which report it belonged to. The store is not updated here.
        logger.exception("Report pipeline crashed for report %s", state.id)
        raise
    reports[state.id] = state

    if state.error:
        logger.error("Report %s failed: %s", state.id, state.error)
        raise HTTPException(status_code=500, detail=state.error)

    return _to_response(state)
|
|
|
|
|
|
@router.get("/reports/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str):
    """Return the summary status of a single report.

    Raises:
        HTTPException: 404 if ``report_id`` is not present in the store.
    """
    state = reports.get(report_id)
    # Explicit None check: a missing key is the only "not found" case, and a
    # stored model object must never be rejected based on its truthiness.
    if state is None:
        raise HTTPException(status_code=404, detail="Report not found")
    return _to_response(state)
|
|
|
|
|
|
@router.get("/reports", response_model=list[ReportResponse])
async def list_reports():
    """List every report in the in-memory store as ``ReportResponse`` items.

    Declaring ``response_model`` keeps this endpoint consistent with the other
    summary routes and gives it an accurate OpenAPI schema.
    """
    return [_to_response(state) for state in reports.values()]
|
|
|
|
|
|
@router.get("/reports/{report_id}/detail")
async def get_report_detail(report_id: str):
    """Return the full stored detail of a single report.

    Besides the summary fields, this also exposes the original requirement,
    the decomposition, a digest of each research result (description, status,
    duration), the draft and the review.

    Raises:
        HTTPException: 404 if ``report_id`` is not present in the store.
    """
    state = reports.get(report_id)
    if state is None:
        raise HTTPException(status_code=404, detail="Report not found")

    research = [
        {
            "description": r.description,
            # ``status`` exposes ``.value`` (presumably an Enum); send the raw value.
            "status": r.status.value,
            "duration_ms": r.duration_ms,
        }
        for r in state.research_results
    ]
    return {
        "id": state.id,
        # Status fields shared with ReportResponse, so a failed report's error
        # is also visible from the detail view.
        "current_node": state.current_node,
        "error": state.error,
        "revision_count": state.revision_count,
        "requirement": state.requirement,
        "decomposition": state.decomposition,
        "research_results": research,
        "draft": state.draft,
        "review": state.review,
        "generated_files": state.generated_files,
        "node_history": state.node_history,
    }
|