feat: stream subject packs by generation batch

This commit is contained in:
2026-05-19 21:31:47 +08:00
parent 47299396dc
commit 00df9d01fe
6 changed files with 531 additions and 81 deletions

View File

@@ -289,6 +289,7 @@ AssetSize = Literal["source", "1024", "1536", "2048"]
# Only one quality tier is currently declared.
AssetQuality = Literal["hd"]
SubjectKind = Literal["object", "living"]
# View identifiers are plain strings rather than a closed Literal set.
SubjectView = str
# Generation state of a single subject asset. Placeholder assets are created
# as "queued", the background worker switches them to "in_progress", and an
# asset ends up either "completed" or "failed".
SubjectAssetStatus = Literal["queued", "in_progress", "completed", "failed"]
SceneMode = Literal["remove_subject", "similar", "style"]
SceneStyle = Literal["source", "premium_product", "clean_studio", "warm_lifestyle", "cinematic"]
SceneAssetRole = Literal["scene", "first_frame", "last_frame"]
@@ -462,6 +463,13 @@ class SubjectAsset(BaseModel):
size: AssetSize = "source"
source_frame_indices: list[int] = Field(default_factory=list)
ai_completed: bool = True
status: SubjectAssetStatus = "completed"
progress: int = 100
error: str = ""
pack_id: str = ""
pack_label: str = ""
pack_mode: str = ""
pack_created_at: float = 0.0
created_at: float = 0.0
@@ -1371,6 +1379,26 @@ async def lifespan(_: FastAPI):
audio_script=audio_script,
message="服务重启 · 上次音频处理已中断,可重新处理",
)
subject_generation_interrupted = False
recovered_frames = []
for f in job.frames:
for e in f.elements or []:
recovered_assets = []
for asset in e.subject_assets or []:
if asset.status in {"queued", "in_progress"}:
recovered_assets.append(asset.model_copy(update={
"status": "failed",
"progress": 100,
"error": "服务重启 · 上次主体生成已中断,可重新生成",
"ai_completed": False,
}))
subject_generation_interrupted = True
else:
recovered_assets.append(asset)
e.subject_assets = recovered_assets
recovered_frames.append(f)
if subject_generation_interrupted:
update(job, frames=recovered_frames, message="服务重启 · 上次主体生成已中断,可重新生成")
JOBS[p.name] = job
except Exception:
pass
@@ -4793,6 +4821,11 @@ class GenerateSubjectAssetsReq(BaseModel):
subject_profile: SubjectProfilePreference | None = None
prompt: str = ""
replace_views: bool = False
source_subject_brief: str = ""
pack_id: str = ""
pack_label: str = ""
pack_mode: str = ""
pack_created_at: float = 0.0
def _subject_profile_prompt_clause(profile: SubjectProfilePreference | None) -> str:
@@ -5252,8 +5285,195 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
return job
def _subject_source_indices(req: GenerateSubjectAssetsReq, idx: int) -> list[int]:
source_indices = [int(x) for x in (req.source_frame_indices or [idx]) if isinstance(x, int) or str(x).isdigit()]
if idx not in source_indices:
source_indices = [idx] + source_indices
return list(dict.fromkeys(source_indices))[:12]
def _normalize_subject_pack_id(value: str, idx: int, element_id: str) -> str:
cleaned = "".join(ch for ch in (value or "").strip() if ch.isalnum() or ch in {"_", "-"})
return cleaned[:96] or f"subject_pack_{idx:03d}_{element_id}_{uuid.uuid4().hex[:8]}"
def _update_subject_asset_status(
    job_id: str,
    idx: int,
    element_id: str,
    asset_id: str,
    *,
    status: SubjectAssetStatus,
    progress: int,
    error: str = "",
    message: str = "",
) -> None:
    """Set the status fields of one subject asset and publish the job update.

    Does nothing when the job is not present in ``JOBS``. Otherwise the asset
    with ``asset_id`` on the given frame/element is replaced by a copy carrying
    the new ``status``, a ``progress`` clamped to 0..100, ``error`` and an
    ``ai_completed`` flag that is true only for ``"completed"``. ``update`` is
    then called with the frame list; the job-level message falls back to the
    current one when ``message`` is empty, and the job-level error is only
    overwritten when ``status`` is ``"failed"``.
    """
    job = JOBS.get(job_id)
    if not job:
        return

    def _patched(asset):
        # Assets other than the target pass through untouched.
        if asset.id != asset_id:
            return asset
        return asset.model_copy(update={
            "status": status,
            "progress": max(0, min(100, int(progress))),
            "error": error,
            "ai_completed": status == "completed",
        })

    frames = []
    for frame in job.frames:
        if frame.index == idx:
            for element in frame.elements:
                if element.id != element_id:
                    continue
                element.subject_assets = [_patched(a) for a in element.subject_assets or []]
        frames.append(frame)

    job_error = error if status == "failed" else job.error
    update(job, frames=frames, message=message or job.message, error=job_error)
def _subject_assets_background_worker(
    job_id: str,
    idx: int,
    element_id: str,
    req: GenerateSubjectAssetsReq,
    queued: list[tuple[SubjectView, str, str]],
) -> None:
    """Generate the queued subject views one after another.

    ``queued`` holds ``(view, view_label, placeholder_asset_id)`` tuples. For
    each entry the placeholder is marked ``in_progress``, then a single-view
    copy of ``req`` (with ``replace_views`` forced on) is passed to
    ``_generate_subject_assets_sync``. Any exception from that call marks the
    placeholder ``failed`` and the loop moves on to the next view.

    In ``"similar"`` mode with no brief supplied, the source-subject brief is
    computed once up front and stored on ``req`` so every per-view copy carries
    it; a failure there is only logged.
    """
    brief_missing = not req.source_subject_brief.strip()
    if req.reconstruction_mode == "similar" and brief_missing:
        try:
            req.source_subject_brief = _describe_source_subject(job_id, _subject_source_indices(req, idx))
        except Exception as e:
            print(f"[subject assets] source brief failed job={job_id} error={e}", flush=True)

    total = len(queued)
    for position, (view, view_label, placeholder_id) in enumerate(queued, start=1):
        _update_subject_asset_status(
            job_id,
            idx,
            element_id,
            placeholder_id,
            status="in_progress",
            progress=10,
            message=f"主体资产生成中 · {view_label} · {position}/{total}",
        )

        single_view_req = req.model_copy(deep=True)
        single_view_req.views = [view]
        single_view_req.replace_views = True

        try:
            _generate_subject_assets_sync(job_id, idx, element_id, single_view_req)
        except Exception as exc:
            # HTTPException carries its text in ``detail``; everything else
            # is reported through its string form.
            failure_detail = str(exc.detail) if isinstance(exc, HTTPException) else str(exc)
            _update_subject_asset_status(
                job_id,
                idx,
                element_id,
                placeholder_id,
                status="failed",
                progress=100,
                error=failure_detail,
                message=f"主体资产生成失败 · {view_label}",
            )
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/subject-assets", response_model=Job)
def generate_subject_assets(job_id: str, idx: int, element_id: str, req: GenerateSubjectAssetsReq) -> Job:
    """Submit a multi-view subject generation task and return placeholder cards immediately.

    A background thread then generates the views one by one and writes each
    result back as it finishes.

    Raises:
        HTTPException: 404 when the job or the element cannot be found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = _find_frame(job, idx)
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    views = _subject_view_labels(req.subject_kind, req.views)
    source_indices = _subject_source_indices(req, idx)
    target_views = {view for view, _label in views}
    now = time.time()

    # Pack metadata: client-supplied values win, otherwise fall back to defaults.
    explicit_pack_id = bool((req.pack_id or "").strip())
    pack_id = _normalize_subject_pack_id(req.pack_id, idx, element_id)
    pack_label = (req.pack_label or "").strip()[:120] or f"{el.name_zh} · 主体套图"
    pack_mode = (req.pack_mode or "").strip()[:40] or req.subject_style
    pack_created_at = req.pack_created_at or now

    # One queued placeholder per requested view; the worker later addresses
    # each placeholder by its generated id.
    placeholders: list[SubjectAsset] = []
    queued: list[tuple[SubjectView, str, str]] = []
    for view, view_label in views:
        asset_id = f"subject_{idx:03d}_{element_id}_{view}_{uuid.uuid4().hex[:8]}"
        placeholder = SubjectAsset(
            id=asset_id,
            view=view,
            label=f"{el.name_zh} · {view_label}",
            url="",
            width=0,
            height=0,
            background=req.background,
            quality=req.quality,
            size=req.size,
            source_frame_indices=source_indices,
            ai_completed=False,
            status="queued",
            progress=0,
            error="",
            pack_id=pack_id,
            pack_label=pack_label,
            pack_mode=pack_mode,
            pack_created_at=pack_created_at,
            created_at=now,
        )
        placeholders.append(placeholder)
        queued.append((view, view_label, asset_id))

    def _is_superseded(asset: SubjectAsset) -> bool:
        # An existing asset is replaced when its view is being regenerated
        # and, if the client named a pack, it belongs to that same pack.
        if asset.view not in target_views:
            return False
        return asset.pack_id == pack_id if explicit_pack_id else True

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for e in f.elements:
                if e.id != element_id:
                    continue
                e.subject_kind = req.subject_kind
                e.cutout_background = req.background
                current_assets = e.subject_assets or []
                if req.replace_views:
                    survivors = []
                    for existing in current_assets:
                        if not _is_superseded(existing):
                            survivors.append(existing)
                        elif existing.url:
                            # Only assets that actually have a URL trigger
                            # the file-deletion helper.
                            _delete_subject_asset_file(job_id, existing.id)
                    current_assets = survivors
                e.subject_assets = current_assets + placeholders
        new_frames.append(f)
    update(job, frames=new_frames, message=f"主体资产已提交 · {el.name_zh} · {len(placeholders)} 张逐张生成中", error="")

    # The worker gets its own copy with the resolved views and pack metadata.
    worker_req = req.model_copy(deep=True)
    worker_req.views = [view for view, _label in views]
    worker_req.pack_id = pack_id
    worker_req.pack_label = pack_label
    worker_req.pack_mode = pack_mode
    worker_req.pack_created_at = pack_created_at

    worker = threading.Thread(
        target=_subject_assets_background_worker,
        args=(job_id, idx, element_id, worker_req, queued),
        daemon=True,
    )
    worker.start()
    return job
def _generate_subject_assets_sync(job_id: str, idx: int, element_id: str, req: GenerateSubjectAssetsReq) -> Job:
"""为一个主体生成多视角资产包。
如果传入 source_frame_indices 或内置 character_id则把多张参考图作为独立 image[] 证据提交。"""
import time as _time
@@ -5265,10 +5485,7 @@ def generate_subject_assets(job_id: str, idx: int, element_id: str, req: Generat
if not el:
raise HTTPException(404, "element not found")
source_indices = [int(x) for x in (req.source_frame_indices or [idx]) if isinstance(x, int) or str(x).isdigit()]
if idx not in source_indices:
source_indices = [idx] + source_indices
source_indices = list(dict.fromkeys(source_indices))[:12]
source_indices = _subject_source_indices(req, idx)
similar_mode = req.reconstruction_mode == "similar"
character_reference_paths: list[Path] = []
@@ -5311,7 +5528,11 @@ def generate_subject_assets(job_id: str, idx: int, element_id: str, req: Generat
tmp_focus: Path | None = None
model_src: Path | list[Path] | None = None
frame_reference_paths = [p for p in (_source_frame_path(job_id, i) for i in source_indices) if p.exists()]
source_subject_brief = _describe_source_subject(job_id, source_indices) if similar_mode else ""
source_subject_brief = (
_ensure_english(req.source_subject_brief.strip())
if similar_mode and req.source_subject_brief.strip()
else (_describe_source_subject(job_id, source_indices) if similar_mode else "")
)
source_subject_clause = (
f"Source video role brief from selected keyframes: {source_subject_brief}. "
"Use this brief to preserve role category, creator-ad energy, camera readability, and broad styling, while creating a new non-identical subject. "
@@ -5484,6 +5705,13 @@ def generate_subject_assets(job_id: str, idx: int, element_id: str, req: Generat
quality=req.quality,
size=req.size,
source_frame_indices=source_indices,
status="completed",
progress=100,
error="",
pack_id=req.pack_id,
pack_label=req.pack_label,
pack_mode=req.pack_mode,
pack_created_at=req.pack_created_at or _time.time(),
created_at=_time.time(),
))
finally:
@@ -5509,10 +5737,21 @@ def generate_subject_assets(job_id: str, idx: int, element_id: str, req: Generat
current_assets = e.subject_assets or []
if req.replace_views:
replaced_views = {asset.view for asset in generated}
replace_pack_id = (req.pack_id or "").strip()
for old_asset in current_assets:
if old_asset.view in replaced_views:
should_replace = old_asset.view in replaced_views and (
old_asset.pack_id == replace_pack_id if replace_pack_id else True
)
if should_replace:
_delete_subject_asset_file(job_id, old_asset.id)
current_assets = [asset for asset in current_assets if asset.view not in replaced_views]
current_assets = [
asset for asset in current_assets
if not (
asset.view in replaced_views and (
asset.pack_id == replace_pack_id if replace_pack_id else True
)
)
]
final_assets = current_assets + generated
e.subject_assets = final_assets
if req.subject_kind == "living":