feat: make subject conversion dialog-driven
This commit is contained in:
92
api/main.py
92
api/main.py
@@ -107,6 +107,7 @@ IMAGE_MODEL = GPT_IMAGE_MODEL
|
||||
PRODUCT_VIEW_MODEL = GPT_IMAGE_MODEL
|
||||
SUBJECT_ASSET_IMAGE_MODEL = GPT_IMAGE_MODEL
|
||||
SubjectModelBundle = Literal["gpt", "gemini"]
|
||||
# Closed set of direction modes a subject-agent conversation can select;
# used as the type of the `selected_mode` fields on the state and request models.
SubjectAgentMode = Literal["realistic", "cartoon", "elements", "custom"]
|
||||
SUBJECT_AGENT_GPT_MODEL = gpt_model_env("SUBJECT_AGENT_GPT_MODEL", VISION_MODEL)
|
||||
SUBJECT_AGENT_GEMINI_MODEL = os.getenv("SUBJECT_AGENT_GEMINI_MODEL", "gemini-2.5-flash").strip() or "gemini-2.5-flash"
|
||||
SUBJECT_ASSET_IMAGE_MODELS = [GPT_IMAGE_MODEL] + (
|
||||
@@ -766,7 +767,7 @@ class SubjectAgentState(BaseModel):
|
||||
source_frame_indices: list[int] = Field(default_factory=list)
|
||||
analysis: SubjectAgentAnalysis | None = None
|
||||
messages: list[SubjectAgentMessage] = Field(default_factory=list)
|
||||
selected_mode: Literal["realistic", "cartoon", "elements", "custom"] = "custom"
|
||||
selected_mode: SubjectAgentMode = "custom"
|
||||
selected_traits: list[str] = Field(default_factory=list)
|
||||
requirements_zh: str = ""
|
||||
generation_prompt_en: str = ""
|
||||
@@ -4103,27 +4104,85 @@ def _subject_agent_analysis(job_id: str, source_indices: list[int], bundle: Subj
|
||||
)
|
||||
|
||||
|
||||
def _subject_agent_message_update(state: SubjectAgentState, user_message: str) -> tuple[str, str, str, int, list[str]]:
|
||||
# Runtime set of accepted mode names (same values as the SubjectAgentMode Literal);
# used for membership checks when validating a mode string.
_SUBJECT_AGENT_MODES: set[str] = {"realistic", "cartoon", "elements", "custom"}
|
||||
|
||||
|
||||
def _subject_agent_quantity_from_text(text: str, fallback: int) -> int:
|
||||
quantity = max(1, min(10, int(fallback or 6)))
|
||||
text = text or ""
|
||||
if re.fullmatch(r"\s*\d{1,2}\s*", text):
|
||||
return max(1, min(10, int(text.strip())))
|
||||
digit_match = re.search(r"(\d{1,2})\s*(?:张|个|视图|张图|图|views?)", text, flags=re.I)
|
||||
if digit_match:
|
||||
return max(1, min(10, int(digit_match.group(1))))
|
||||
cn_numbers = {
|
||||
"一": 1,
|
||||
"二": 2,
|
||||
"两": 2,
|
||||
"三": 3,
|
||||
"四": 4,
|
||||
"五": 5,
|
||||
"六": 6,
|
||||
"七": 7,
|
||||
"八": 8,
|
||||
"九": 9,
|
||||
"十": 10,
|
||||
}
|
||||
cn_match = re.search(r"([一二两三四五六七八九十])\s*(?:张|个|视图|张图|图)", text)
|
||||
if cn_match:
|
||||
return max(1, min(10, cn_numbers.get(cn_match.group(1), quantity)))
|
||||
return quantity
|
||||
|
||||
|
||||
def _subject_agent_mode_from_text(text: str, fallback: SubjectAgentMode = "custom") -> SubjectAgentMode:
|
||||
compact = re.sub(r"\s+", "", text or "").lower()
|
||||
if re.search(r"卡通|动画|插画|公仔|潮玩|二次元|cartoon|anime|illustration|toy|stylized", compact):
|
||||
return "cartoon"
|
||||
if re.search(r"创意复刻|创意模式|元素|参考创新|不像|换人|全新主体|全新人物|不同人|newperson|newactor|concept|element", compact):
|
||||
return "elements"
|
||||
if re.search(r"形象锁定|复刻这个人|复刻形象|同一主体|同一个人|保持这个人|保持原主体|完全复刻|source locked|same subject|sameperson", compact):
|
||||
return "realistic"
|
||||
if re.search(r"自主描述|只按文字|不依赖|不用参考|按描述|fromdescription|custom", compact):
|
||||
return "custom"
|
||||
return fallback
|
||||
|
||||
|
||||
def _subject_agent_mode_from_value(value: object, fallback: SubjectAgentMode) -> SubjectAgentMode:
    """Validate an arbitrary value as a subject-agent mode name.

    The value is stringified (falsy values become "") and stripped; it is
    returned only when it is a member of ``_SUBJECT_AGENT_MODES``, otherwise
    ``fallback`` is returned unchanged.
    """
    candidate = str(value or "").strip()
    if candidate in _SUBJECT_AGENT_MODES:
        return candidate
    return fallback
|
||||
|
||||
|
||||
def _subject_agent_message_update(state: SubjectAgentState, user_message: str) -> tuple[str, str, str, int, list[str], SubjectAgentMode]:
|
||||
current_req = state.requirements_zh.strip()
|
||||
selected_traits = state.selected_traits[:20]
|
||||
quantity = max(1, min(10, int(state.quantity or 6)))
|
||||
qty_match = re.search(r"(\d{1,2})\s*张", user_message)
|
||||
if qty_match:
|
||||
quantity = max(1, min(10, int(qty_match.group(1))))
|
||||
quantity = _subject_agent_quantity_from_text(user_message, int(state.quantity or 6))
|
||||
selected_mode = _subject_agent_mode_from_text(user_message, state.selected_mode)
|
||||
fallback_req = ";".join(part for part in [current_req, user_message.strip()] if part).strip(";")
|
||||
mode_label = {
|
||||
"realistic": "source-locked same visible subject reconstruction",
|
||||
"cartoon": "cartoon or stylized reconstruction",
|
||||
"elements": "creative element reconstruction with a different new subject",
|
||||
"custom": "custom description driven subject generation",
|
||||
}.get(selected_mode, "custom description driven subject generation")
|
||||
fallback_prompt = _ensure_english(
|
||||
"Subject image generation requirements: "
|
||||
+ (fallback_req or "create a consistent SKG ad subject pack")
|
||||
+ f". Direction mode: {mode_label}."
|
||||
+ f" Generate exactly {quantity} separate views."
|
||||
+ ". Keep one identity and one outfit bible across all generated views. "
|
||||
+ (f"Selected traits: {', '.join(selected_traits)}." if selected_traits else "")
|
||||
)
|
||||
if not LLM_API_KEY:
|
||||
return "已记录这条生图要求。继续补充要保留/删除的元素,确认后我会按当前要求生成。", fallback_req, fallback_prompt, quantity, selected_traits
|
||||
return "已记录这条生图要求。继续补充要保留/删除的元素,确认后我会按当前要求生成。", fallback_req, fallback_prompt, quantity, selected_traits, selected_mode
|
||||
system = (
|
||||
"You are an SKG subject image-generation requirements agent. Your scope is only image generation for a subject view pack. "
|
||||
"Do not answer unrelated video, audio, download, coding, copywriting, or general chat requests; redirect to subject image requirements. "
|
||||
"Normalize the user's fuzzy Chinese request into precise generation constraints. "
|
||||
"Return strict JSON with keys: assistant_message_zh, updated_requirements_zh, generation_prompt_en, quantity, selected_traits. "
|
||||
"Infer selected_mode from the conversation. Allowed selected_mode values are realistic, cartoon, elements, custom. "
|
||||
"Use realistic when the user wants to lock or replicate the visible reference subject; cartoon for stylized/cartoon/toy/illustration; "
|
||||
"elements when the user wants the creative logic but a different new subject; custom when the user wants free text generation without relying on references. "
|
||||
"Infer quantity from Chinese or English requests such as 4张, 六视图, generate 8 views. "
|
||||
"Return strict JSON with keys: assistant_message_zh, updated_requirements_zh, generation_prompt_en, quantity, selected_traits, selected_mode. "
|
||||
"generation_prompt_en must be English and must enforce: one consistent identity, one consistent outfit bible, neck/shoulder readability, no text/watermarks/UI, and legal-safe reconstruction."
|
||||
)
|
||||
user_payload = {
|
||||
@@ -4153,12 +4212,13 @@ def _subject_agent_message_update(state: SubjectAgentState, user_message: str) -
|
||||
assistant = str(data.get("assistant_message_zh") or "已记录这条生图要求。").strip()[:1200]
|
||||
updated_req = str(data.get("updated_requirements_zh") or fallback_req).strip()[:2200]
|
||||
prompt_en = _ensure_english(str(data.get("generation_prompt_en") or fallback_prompt).strip())[:2600]
|
||||
out_quantity = max(1, min(10, int(data.get("quantity") or quantity)))
|
||||
out_quantity = _subject_agent_quantity_from_text(str(data.get("quantity") or ""), quantity)
|
||||
out_traits = _list_of_strings(data.get("selected_traits"), 24) or selected_traits
|
||||
return assistant, updated_req, prompt_en, out_quantity, out_traits
|
||||
out_mode = _subject_agent_mode_from_value(data.get("selected_mode"), selected_mode)
|
||||
return assistant, updated_req, prompt_en, out_quantity, out_traits, out_mode
|
||||
except Exception as e:
|
||||
print(f"[subject agent message failed] bundle={state.model_bundle} error={e}", flush=True)
|
||||
return "已先按本地规则记录这条要求;模型回复失败时仍可直接生成。", fallback_req, fallback_prompt, quantity, selected_traits
|
||||
return "已先按本地规则记录这条要求;模型回复失败时仍可直接生成。", fallback_req, fallback_prompt, quantity, selected_traits, selected_mode
|
||||
|
||||
|
||||
# ---------- API 路由 ----------
|
||||
@@ -4179,7 +4239,7 @@ class SubjectAgentMessageReq(BaseModel):
|
||||
|
||||
model_bundle: SubjectModelBundle = "gpt"
|
||||
source_frame_indices: list[int] = Field(default_factory=list)
|
||||
selected_mode: Literal["realistic", "cartoon", "elements", "custom"] = "custom"
|
||||
selected_mode: SubjectAgentMode = "custom"
|
||||
selected_traits: list[str] = Field(default_factory=list)
|
||||
requirements_zh: str = ""
|
||||
message: str = ""
|
||||
@@ -4666,7 +4726,7 @@ def analyze_subject_agent(job_id: str, req: SubjectAgentAnalyzeReq) -> Job:
|
||||
state = job.subject_agent.model_copy(deep=True)
|
||||
assistant_text = (
|
||||
f"我已用 {req.model_bundle.upper()} 套件分析这些参考帧。"
|
||||
"你可以选择形象锁定、创意复刻、元素混合或自主描述,也可以继续告诉我要改数量、风格、服装、人物大小。"
|
||||
"接下来直接告诉我要复刻形象、卡通化、参考创意换新人,还是只按文字生成;数量、风格、服装和人物大小也都写在对话里。"
|
||||
)
|
||||
messages = (state.messages + [SubjectAgentMessage(role="assistant", content=assistant_text, created_at=time.time())])[-30:]
|
||||
state = state.model_copy(update={
|
||||
@@ -4689,10 +4749,11 @@ def message_subject_agent(job_id: str, req: SubjectAgentMessageReq) -> Job:
|
||||
raise HTTPException(404, "job not found")
|
||||
state = job.subject_agent.model_copy(deep=True)
|
||||
source_indices = [idx for idx in req.source_frame_indices if any(frame.index == idx for frame in job.frames)][:8]
|
||||
fallback_mode = req.selected_mode or state.selected_mode
|
||||
state = state.model_copy(update={
|
||||
"model_bundle": req.model_bundle,
|
||||
"source_frame_indices": source_indices or state.source_frame_indices,
|
||||
"selected_mode": req.selected_mode,
|
||||
"selected_mode": fallback_mode,
|
||||
"selected_traits": [str(item).strip()[:80] for item in req.selected_traits if str(item).strip()][:24],
|
||||
"requirements_zh": req.requirements_zh.strip()[:2200] or state.requirements_zh,
|
||||
"quantity": max(1, min(10, int(req.quantity or state.quantity or 6))),
|
||||
@@ -4700,7 +4761,7 @@ def message_subject_agent(job_id: str, req: SubjectAgentMessageReq) -> Job:
|
||||
user_message = req.message.strip()
|
||||
if not user_message:
|
||||
user_message = state.requirements_zh or "按当前设置准备主体套图生成要求"
|
||||
assistant_text, requirements_zh, prompt_en, quantity, selected_traits = _subject_agent_message_update(state, user_message)
|
||||
assistant_text, requirements_zh, prompt_en, quantity, selected_traits, selected_mode = _subject_agent_message_update(state, user_message)
|
||||
messages = (
|
||||
state.messages
|
||||
+ [SubjectAgentMessage(role="user", content=user_message, created_at=time.time())]
|
||||
@@ -4709,6 +4770,7 @@ def message_subject_agent(job_id: str, req: SubjectAgentMessageReq) -> Job:
|
||||
state = state.model_copy(update={
|
||||
"requirements_zh": requirements_zh,
|
||||
"generation_prompt_en": prompt_en,
|
||||
"selected_mode": selected_mode,
|
||||
"quantity": quantity,
|
||||
"selected_traits": selected_traits,
|
||||
"messages": messages,
|
||||
|
||||
Reference in New Issue
Block a user