feat: add subject image agent workflow

This commit is contained in:
2026-05-20 12:51:02 +08:00
parent 3d198b024b
commit 35fc088375
5 changed files with 873 additions and 370 deletions

View File

@@ -106,6 +106,9 @@ IMAGE_FALLBACK_ENABLED = os.getenv("IMAGE_FALLBACK_ENABLED", "true").strip().low
IMAGE_MODEL = GPT_IMAGE_MODEL
PRODUCT_VIEW_MODEL = GPT_IMAGE_MODEL
SUBJECT_ASSET_IMAGE_MODEL = GPT_IMAGE_MODEL
SubjectModelBundle = Literal["gpt", "gemini"]
SUBJECT_AGENT_GPT_MODEL = gpt_model_env("SUBJECT_AGENT_GPT_MODEL", VISION_MODEL)
SUBJECT_AGENT_GEMINI_MODEL = os.getenv("SUBJECT_AGENT_GEMINI_MODEL", "gemini-2.5-flash").strip() or "gemini-2.5-flash"
SUBJECT_ASSET_IMAGE_MODELS = [GPT_IMAGE_MODEL] + (
[IMAGE_FALLBACK_MODEL] if IMAGE_FALLBACK_ENABLED and IMAGE_FALLBACK_MODEL and IMAGE_FALLBACK_MODEL != GPT_IMAGE_MODEL else []
)
@@ -734,6 +737,39 @@ class AudioScript(BaseModel):
created_at: float = 0.0
class SubjectAgentAnalysis(BaseModel):
model_bundle: SubjectModelBundle = "gpt"
model: str = ""
source_frame_indices: list[int] = Field(default_factory=list)
summary_zh: str = ""
summary_en: str = ""
generation_brief_en: str = ""
trait_chips: list[str] = Field(default_factory=list)
mode_options: list[str] = Field(default_factory=list)
questions: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
created_at: float = 0.0
class SubjectAgentMessage(BaseModel):
role: Literal["user", "assistant"] = "assistant"
content: str = ""
created_at: float = 0.0
class SubjectAgentState(BaseModel):
model_bundle: SubjectModelBundle = "gpt"
source_frame_indices: list[int] = Field(default_factory=list)
analysis: SubjectAgentAnalysis | None = None
messages: list[SubjectAgentMessage] = Field(default_factory=list)
selected_mode: Literal["realistic", "cartoon", "elements", "custom"] = "custom"
selected_traits: list[str] = Field(default_factory=list)
requirements_zh: str = ""
generation_prompt_en: str = ""
quantity: int = 6
updated_at: float = 0.0
class Job(BaseModel):
id: str
url: str
@@ -751,6 +787,7 @@ class Job(BaseModel):
storyboard_images: list[StoryboardImage] = Field(default_factory=list)
generated_videos: list[GeneratedVideo] = Field(default_factory=list)
product_refs: list[dict] = Field(default_factory=list)
subject_agent: SubjectAgentState = Field(default_factory=SubjectAgentState)
error: str = ""
@@ -3892,7 +3929,7 @@ def _image_path_to_data_url(path: Path) -> str:
return f"data:{media_type};base64,{base64.b64encode(path.read_bytes()).decode('ascii')}"
def _vision_brief_from_images(image_paths: list[Path], prompt: str, max_images: int = 8) -> str:
def _vision_brief_from_images(image_paths: list[Path], prompt: str, max_images: int = 8, model: str | None = None) -> str:
paths = [path for path in image_paths if path.exists()][:max_images]
if not paths:
return ""
@@ -3903,7 +3940,7 @@ def _vision_brief_from_images(image_paths: list[Path], prompt: str, max_images:
content.append({"type": "image_url", "image_url": {"url": _image_path_to_data_url(path)}})
try:
resp = llm().chat.completions.create(
model=VISION_MODEL,
model=model or VISION_MODEL,
messages=[{"role": "user", "content": content}],
response_format={"type": "json_object"},
temperature=0.1,
@@ -3977,12 +4014,170 @@ def _describe_subject_consensus_from_images(name: str, subject_style: str, image
return _vision_brief_from_images(image_paths, prompt, max_images=10)
def _subject_agent_model(bundle: SubjectModelBundle) -> str:
return SUBJECT_AGENT_GEMINI_MODEL if bundle == "gemini" else SUBJECT_AGENT_GPT_MODEL
def _subject_agent_image_model(bundle: SubjectModelBundle) -> str:
return IMAGE_FALLBACK_MODEL if bundle == "gemini" and IMAGE_FALLBACK_MODEL else GPT_IMAGE_MODEL
def _list_of_strings(value, limit: int = 18) -> list[str]:
if isinstance(value, list):
return [str(item).strip()[:80] for item in value if str(item).strip()][:limit]
if isinstance(value, str) and value.strip():
return [part.strip()[:80] for part in re.split(r"[,;\n]", value) if part.strip()][:limit]
return []
def _subject_agent_json_from_images(job_id: str, source_indices: list[int], bundle: SubjectModelBundle) -> dict:
paths = [_source_frame_path(job_id, idx) for idx in source_indices]
paths = [path for path in paths if path.exists()][:8]
if not paths or not LLM_API_KEY:
return {}
prompt = (
"You are the image-generation requirements agent for an SKG ad-subject reconstruction workspace. "
"Only analyze the attached reference images for future subject pack generation. Do not discuss video, audio, copywriting, download, or unrelated tasks. "
"The user may later choose whether to preserve the visible subject, preserve only the creative concept with a new person, mix selected elements, or create from a new description. "
"Output strict JSON only with these keys: summary_zh, summary_en, generation_brief_en, trait_chips, mode_options, questions, warnings. "
"summary_zh: 2-4 concise Chinese sentences describing visible subject, concept, outfit/material, camera usefulness. "
"summary_en and generation_brief_en: English only. generation_brief_en is a direct image-generation brief that preserves useful traits while avoiding copyrighted/identifying replication unless user explicitly selects source-locked mode. "
"trait_chips: 8-18 short Chinese selectable traits. Include identity category, anatomy/material, clothing, color, style, framing, and useful negative constraints. "
"mode_options: short Chinese labels for likely choices. questions: 2-4 Chinese questions to clarify generation. warnings: Chinese notes about identity/copyright/consistency risk."
)
content: list[dict] = [{"type": "text", "text": prompt}]
for path in paths:
content.append({"type": "image_url", "image_url": {"url": _image_path_to_data_url(path)}})
try:
resp = llm().chat.completions.create(
model=_subject_agent_model(bundle),
messages=[{"role": "user", "content": content}],
response_format={"type": "json_object"},
temperature=0.15,
max_tokens=1600,
)
raw = (resp.choices[0].message.content or "").strip()
if not raw:
raw = (getattr(resp.choices[0].message, "reasoning_content", "") or "").strip()
match = re.search(r"\{[\s\S]*\}", raw)
raw = match.group(0) if match else raw
data = json.loads(raw)
return data if isinstance(data, dict) else {}
except Exception as e:
print(f"[subject agent analyze failed] bundle={bundle} error={e}", flush=True)
return {}
def _subject_agent_analysis(job_id: str, source_indices: list[int], bundle: SubjectModelBundle) -> SubjectAgentAnalysis:
clean_indices = list(dict.fromkeys(int(idx) for idx in source_indices if isinstance(idx, int) or str(idx).isdigit()))[:8]
model = _subject_agent_model(bundle)
data = _subject_agent_json_from_images(job_id, clean_indices, bundle)
brief_en = _ensure_english(str(data.get("generation_brief_en") or data.get("summary_en") or "").strip()) if data else ""
if not data:
data = {
"summary_zh": "已接收参考帧,但模型没有返回可用结构化分析。你仍可以在下方描述要保留或改变的主体元素。",
"summary_en": "Reference frames were received, but no structured analysis was returned.",
"generation_brief_en": "Use the selected reference frames as visual evidence for a new consistent SKG ad subject pack. Keep neck and shoulder readability clear.",
"trait_chips": ["同一主体", "服装统一", "肩颈清晰", "白底", "六视图"],
"mode_options": ["形象锁定", "创意复刻", "元素混合", "自主描述"],
"questions": ["你要保留原主体外形,还是只保留创意模式?", "是否需要改变人物年龄、性别、服装或风格?"],
"warnings": ["模型分析失败时请用文字补充关键要求。"],
}
brief_en = str(data["generation_brief_en"])
return SubjectAgentAnalysis(
model_bundle=bundle,
model=model,
source_frame_indices=clean_indices,
summary_zh=str(data.get("summary_zh") or "").strip()[:1800],
summary_en=str(data.get("summary_en") or "").strip()[:1800],
generation_brief_en=brief_en[:2200],
trait_chips=_list_of_strings(data.get("trait_chips"), 24),
mode_options=_list_of_strings(data.get("mode_options"), 8),
questions=_list_of_strings(data.get("questions"), 8),
warnings=_list_of_strings(data.get("warnings"), 8),
created_at=time.time(),
)
def _subject_agent_message_update(state: SubjectAgentState, user_message: str) -> tuple[str, str, str, int, list[str]]:
current_req = state.requirements_zh.strip()
selected_traits = state.selected_traits[:20]
quantity = max(1, min(10, int(state.quantity or 6)))
qty_match = re.search(r"(\d{1,2})\s*张", user_message)
if qty_match:
quantity = max(1, min(10, int(qty_match.group(1))))
fallback_req = "".join(part for part in [current_req, user_message.strip()] if part).strip("")
fallback_prompt = _ensure_english(
"Subject image generation requirements: "
+ (fallback_req or "create a consistent SKG ad subject pack")
+ ". Keep one identity and one outfit bible across all generated views. "
+ (f"Selected traits: {', '.join(selected_traits)}." if selected_traits else "")
)
if not LLM_API_KEY:
return "已记录这条生图要求。继续补充要保留/删除的元素,确认后我会按当前要求生成。", fallback_req, fallback_prompt, quantity, selected_traits
system = (
"You are an SKG subject image-generation requirements agent. Your scope is only image generation for a subject view pack. "
"Do not answer unrelated video, audio, download, coding, copywriting, or general chat requests; redirect to subject image requirements. "
"Normalize the user's fuzzy Chinese request into precise generation constraints. "
"Return strict JSON with keys: assistant_message_zh, updated_requirements_zh, generation_prompt_en, quantity, selected_traits. "
"generation_prompt_en must be English and must enforce: one consistent identity, one consistent outfit bible, neck/shoulder readability, no text/watermarks/UI, and legal-safe reconstruction."
)
user_payload = {
"analysis": state.analysis.model_dump() if state.analysis else None,
"current_requirements_zh": current_req,
"current_generation_prompt_en": state.generation_prompt_en,
"current_quantity": quantity,
"selected_mode": state.selected_mode,
"selected_traits": selected_traits,
"recent_messages": [m.model_dump() for m in state.messages[-8:]],
"user_message": user_message,
}
try:
resp = llm().chat.completions.create(
model=_subject_agent_model(state.model_bundle),
messages=[
{"role": "system", "content": system},
{"role": "user", "content": json.dumps(user_payload, ensure_ascii=False)},
],
response_format={"type": "json_object"},
temperature=0.2,
max_tokens=1200,
)
raw = (resp.choices[0].message.content or "").strip()
match = re.search(r"\{[\s\S]*\}", raw)
data = json.loads(match.group(0) if match else raw)
assistant = str(data.get("assistant_message_zh") or "已记录这条生图要求。").strip()[:1200]
updated_req = str(data.get("updated_requirements_zh") or fallback_req).strip()[:2200]
prompt_en = _ensure_english(str(data.get("generation_prompt_en") or fallback_prompt).strip())[:2600]
out_quantity = max(1, min(10, int(data.get("quantity") or quantity)))
out_traits = _list_of_strings(data.get("selected_traits"), 24) or selected_traits
return assistant, updated_req, prompt_en, out_quantity, out_traits
except Exception as e:
print(f"[subject agent message failed] bundle={state.model_bundle} error={e}", flush=True)
return "已先按本地规则记录这条要求;模型回复失败时仍可直接生成。", fallback_req, fallback_prompt, quantity, selected_traits
# ---------- API 路由 ----------
class CreateJobReq(BaseModel):
url: str
class SubjectAgentAnalyzeReq(BaseModel):
model_bundle: SubjectModelBundle = "gpt"
source_frame_indices: list[int] = Field(default_factory=list)
class SubjectAgentMessageReq(BaseModel):
model_bundle: SubjectModelBundle = "gpt"
source_frame_indices: list[int] = Field(default_factory=list)
selected_mode: Literal["realistic", "cartoon", "elements", "custom"] = "custom"
selected_traits: list[str] = Field(default_factory=list)
requirements_zh: str = ""
message: str = ""
quantity: int = 6
class TranslateReq(BaseModel):
text: str
target: Literal["en", "zh"] = "en"
@@ -4451,6 +4646,70 @@ def get_job(job_id: str) -> Job:
return job_with_artifacts(job)
@app.post("/jobs/{job_id}/subject-agent/analyze", response_model=Job)
def analyze_subject_agent(job_id: str, req: SubjectAgentAnalyzeReq) -> Job:
job = JOBS.get(job_id)
if not job:
raise HTTPException(404, "job not found")
source_indices = [idx for idx in req.source_frame_indices if any(frame.index == idx for frame in job.frames)][:8]
if not source_indices:
raise HTTPException(400, "source_frame_indices required")
analysis = _subject_agent_analysis(job_id, source_indices, req.model_bundle)
state = job.subject_agent.model_copy(deep=True)
assistant_text = (
f"我已用 {req.model_bundle.upper()} 套件分析这些参考帧。"
"你可以选择形象锁定、创意复刻、元素混合或自主描述,也可以继续告诉我要改数量、风格、服装、人物大小。"
)
messages = (state.messages + [SubjectAgentMessage(role="assistant", content=assistant_text, created_at=time.time())])[-30:]
state = state.model_copy(update={
"model_bundle": req.model_bundle,
"source_frame_indices": source_indices,
"analysis": analysis,
"messages": messages,
"generation_prompt_en": analysis.generation_brief_en,
"selected_traits": analysis.trait_chips[:6],
"updated_at": time.time(),
})
update(job, subject_agent=state, message="转换层分析完成")
return job_with_artifacts(job)
@app.post("/jobs/{job_id}/subject-agent/message", response_model=Job)
def message_subject_agent(job_id: str, req: SubjectAgentMessageReq) -> Job:
job = JOBS.get(job_id)
if not job:
raise HTTPException(404, "job not found")
state = job.subject_agent.model_copy(deep=True)
source_indices = [idx for idx in req.source_frame_indices if any(frame.index == idx for frame in job.frames)][:8]
state = state.model_copy(update={
"model_bundle": req.model_bundle,
"source_frame_indices": source_indices or state.source_frame_indices,
"selected_mode": req.selected_mode,
"selected_traits": [str(item).strip()[:80] for item in req.selected_traits if str(item).strip()][:24],
"requirements_zh": req.requirements_zh.strip()[:2200] or state.requirements_zh,
"quantity": max(1, min(10, int(req.quantity or state.quantity or 6))),
})
user_message = req.message.strip()
if not user_message:
user_message = state.requirements_zh or "按当前设置准备主体套图生成要求"
assistant_text, requirements_zh, prompt_en, quantity, selected_traits = _subject_agent_message_update(state, user_message)
messages = (
state.messages
+ [SubjectAgentMessage(role="user", content=user_message, created_at=time.time())]
+ [SubjectAgentMessage(role="assistant", content=assistant_text, created_at=time.time())]
)[-30:]
state = state.model_copy(update={
"requirements_zh": requirements_zh,
"generation_prompt_en": prompt_en,
"quantity": quantity,
"selected_traits": selected_traits,
"messages": messages,
"updated_at": time.time(),
})
update(job, subject_agent=state, message="转换层生图要求已更新")
return job_with_artifacts(job)
@app.delete("/jobs/{job_id}")
def delete_job(job_id: str) -> dict[str, bool | str]:
d = (JOBS_DIR / job_id).resolve()