6123 lines
262 KiB
Python
6123 lines
262 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import secrets
|
||
import shutil
|
||
import subprocess
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Literal
|
||
|
||
import httpx
|
||
from dotenv import load_dotenv
|
||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, Request, Response, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Populate the process environment from a .env file (python-dotenv) before the
# os.getenv() lookups below are evaluated.
load_dotenv()

# Root folder for job data (env JOBS_DIR, default "./jobs"); created at import time.
JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)
# Comma-separated origin list; entries are stripped and empty entries dropped.
CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290,http://127.0.0.1:4290").split(",") if o.strip()]
# Product library folder; defaults to product_library/skg-products next to this file.
PRODUCT_LIBRARY_DIR = Path(
    os.getenv("PRODUCT_LIBRARY_DIR", Path(__file__).resolve().parent / "product_library" / "skg-products")
).resolve()
PRODUCT_LIBRARY_MANIFEST = PRODUCT_LIBRARY_DIR / "manifest.json"
# Character library folder; defaults to character_library/skg-characters next to this file.
CHARACTER_LIBRARY_DIR = Path(
    os.getenv("CHARACTER_LIBRARY_DIR", Path(__file__).resolve().parent / "character_library" / "skg-characters")
).resolve()
CHARACTER_LIBRARY_MANIFEST = CHARACTER_LIBRARY_DIR / "manifest.json"
# Subject templates default to JOBS_DIR/_subject_templates; the images
# sub-folder is created at import time.
SUBJECT_TEMPLATE_DIR = Path(os.getenv("SUBJECT_TEMPLATE_DIR", JOBS_DIR / "_subject_templates")).resolve()
SUBJECT_TEMPLATE_IMAGE_DIR = SUBJECT_TEMPLATE_DIR / "images"
SUBJECT_TEMPLATE_MANIFEST = SUBJECT_TEMPLATE_DIR / "manifest.json"
SUBJECT_TEMPLATE_IMAGE_DIR.mkdir(parents=True, exist_ok=True)

# LLM gateway URL and key; both default to the empty string when unset.
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
# A blank ASR_FALLBACK_MODEL falls back to the built-in default.
ASR_FALLBACK_MODEL = os.getenv("ASR_FALLBACK_MODEL", "gemini-2.5-flash").strip() or "gemini-2.5-flash"
# Clamped to at least 15 seconds; a non-integer value raises ValueError at import.
ASR_TIMEOUT_SECONDS = max(15, int(os.getenv("ASR_TIMEOUT_SECONDS", "45")))
LOCAL_ASR_BIN = os.getenv("LOCAL_ASR_BIN", "").strip()
LOCAL_ASR_MODEL = os.getenv("LOCAL_ASR_MODEL", "mlx-community/whisper-tiny").strip() or "mlx-community/whisper-tiny"
# Clamped to at least 30 seconds.
LOCAL_ASR_TIMEOUT_SECONDS = max(30, int(os.getenv("LOCAL_ASR_TIMEOUT_SECONDS", "180")))
TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash")
# Fallback text model used by gpt_model_env() when a specific variable is unusable.
DEFAULT_GPT_TEXT_MODEL = os.getenv("GPT_TEXT_MODEL", "gpt-4o").strip() or "gpt-4o"
|
||
|
||
|
||
def gpt_model_env(name: str, default: str | None = None) -> str:
    """Read a text-model name from the environment, refusing Gemini models.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, blank, or names a
            ``gemini-`` model; when falsy, ``DEFAULT_GPT_TEXT_MODEL`` is used.

    Returns:
        The stripped variable value, or the fallback described above.
    """
    fallback = default or DEFAULT_GPT_TEXT_MODEL
    configured = os.getenv(name, fallback).strip()
    # Blank values and any "gemini-*" id (case-insensitive) are replaced.
    unusable = configured == "" or configured.lower().startswith("gemini-")
    return fallback if unusable else configured
|
||
|
||
|
||
# Text models resolved through gpt_model_env().
REWRITE_MODEL = gpt_model_env("REWRITE_MODEL")
VISION_MODEL = gpt_model_env("VISION_MODEL")
# Image gateway settings default to the LLM gateway settings.
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", LLM_BASE_URL).strip()
IMAGE_API_KEY = os.getenv("IMAGE_API_KEY", LLM_API_KEY).strip()
# First non-empty proxy variable wins, in the order listed; "" when none is set.
AI_HTTP_PROXY = (
    os.getenv("AI_HTTP_PROXY")
    or os.getenv("IMAGE_HTTP_PROXY")
    or os.getenv("HTTPS_PROXY")
    or os.getenv("https_proxy")
    or os.getenv("HTTP_PROXY")
    or os.getenv("http_proxy")
    or ""
).strip()
# Product decision: every image-generation/editing path is locked to gpt-image-2.
# Environment variables may still choose the gateway URL/key, but not the model.
GPT_IMAGE_MODEL = "gpt-image-2"
IMAGE_MODEL = GPT_IMAGE_MODEL
PRODUCT_VIEW_MODEL = GPT_IMAGE_MODEL
SUBJECT_ASSET_IMAGE_MODEL = GPT_IMAGE_MODEL
SUBJECT_ASSET_IMAGE_MODELS = [GPT_IMAGE_MODEL]
# Product-asset limits; each is clamped to the floor (and, for JPEG quality,
# the 80..95 range) shown in the expression.
PRODUCT_ASSET_MAX_SIDE = max(1024, int(os.getenv("PRODUCT_ASSET_MAX_SIDE", "1600")))
PRODUCT_ASSET_MIN_LONG_SIDE = max(512, int(os.getenv("PRODUCT_ASSET_MIN_LONG_SIDE", "900")))
PRODUCT_ASSET_MIN_SHORT_SIDE = max(320, int(os.getenv("PRODUCT_ASSET_MIN_SHORT_SIDE", "600")))
PRODUCT_ASSET_JPEG_QUALITY = max(80, min(95, int(os.getenv("PRODUCT_ASSET_JPEG_QUALITY", "92"))))
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"
YTDLP_COOKIES_FILE = os.getenv("YTDLP_COOKIES_FILE", "").strip()
YTDLP_COOKIES_FROM_BROWSER = os.getenv("YTDLP_COOKIES_FROM_BROWSER", "").strip()
# Default product brief (Chinese text, kept verbatim because it is runtime data).
AUDIO_PRODUCT_BRIEF = os.getenv(
    "AUDIO_PRODUCT_BRIEF",
    "SKG 智能按摩产品,主打日常肩颈、腰背、眼部、膝盖或足部放松;广告表达要高级、干净、可信,不做医疗疗效承诺。",
).strip()
# Falls back to REWRITE_MODEL when AUDIO_REWRITE_MODEL is unusable.
AUDIO_REWRITE_MODEL = gpt_model_env("AUDIO_REWRITE_MODEL", REWRITE_MODEL)
VOICE_PROVIDER = "azure_openai"
# Trailing slashes are removed from the base URL.
AZURE_OPENAI_BASE_URL = os.getenv("AZURE_OPENAI_BASE_URL", "https://ai.skg.com/azure").strip().rstrip("/")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", LLM_API_KEY).strip()
AZURE_TTS_MODEL = os.getenv("AZURE_TTS_MODEL", "gpt-4o-mini-tts").strip() or "gpt-4o-mini-tts"
AZURE_TTS_VOICE_ID = os.getenv("AZURE_TTS_VOICE_ID", "alloy").strip() or "alloy"
DEFAULT_AZURE_TTS_VOICE_POOL = ["alloy", "verse", "shimmer"]
# Comma-separated voice list; blank entries are dropped.
AZURE_TTS_VOICE_POOL = [
    v.strip()
    for v in os.getenv("AZURE_TTS_VOICE_POOL", ",".join(DEFAULT_AZURE_TTS_VOICE_POOL)).split(",")
    if v.strip()
]
AZURE_TTS_PATH = os.getenv("AZURE_TTS_PATH", "/audio/speech").strip() or "/audio/speech"
# Comma-separated endpoint paths; the default list starts with AZURE_TTS_PATH
# and may therefore contain duplicates.
AZURE_TTS_PATHS = [
    p.strip()
    for p in os.getenv("AZURE_TTS_PATHS", f"{AZURE_TTS_PATH},/audio/speech,/v1/audio/speech").split(",")
    if p.strip()
]

POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1"
POE_API_KEY = os.getenv("POE_API_KEY", "").strip()
|
||
|
||
|
||
def env_video_model(name: str, default: str) -> str:
    """Resolve a concrete video model id from an environment variable.

    Args:
        name: Environment variable to read.
        default: Model id returned when the variable is unset, blank, or holds
            one of the legacy business aliases.

    Returns:
        The stripped variable value, or ``default``.
    """
    legacy_aliases = ("seedance", "kling", "veo", "veo3", "voe")
    configured = os.getenv(name, "").strip()
    # Older local envs used business aliases as model IDs. Those aliases stay
    # usable by mapping them to the concrete default model id.
    if configured and configured.lower() not in legacy_aliases:
        return configured
    return default
|
||
|
||
|
||
# Business alias -> concrete model id. "veo3", "veo" and "voe" all read the
# same VIDEO_MODEL_VEO3 variable.
VIDEO_MODEL_ALIASES = {
    "seedance": env_video_model("VIDEO_MODEL_SEEDANCE", "seedance-2-fast"),
    "kling": env_video_model("VIDEO_MODEL_KLING", "kling-omni"),
    "veo3": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
    "veo": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
    "voe": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
}
VIDEO_API_BASE_URL = os.getenv("VIDEO_API_BASE_URL", "").strip()
VIDEO_API_KEY = os.getenv("VIDEO_API_KEY", "").strip()
# Web login settings; all three credentials must be non-empty for
# WEB_AUTH_CONFIGURED to be true.
WEB_AUTH_USERNAME = os.getenv("WEB_AUTH_USERNAME", "").strip()
WEB_AUTH_PASSWORD = os.getenv("WEB_AUTH_PASSWORD", "").strip()
WEB_AUTH_SESSION_SECRET = os.getenv("WEB_AUTH_SESSION_SECRET", "").strip()
WEB_AUTH_COOKIE_NAME = os.getenv("WEB_AUTH_COOKIE_NAME", "skg_marketing_session").strip() or "skg_marketing_session"
# True unless the variable is explicitly "0", "false" or "no" (case-insensitive).
WEB_AUTH_COOKIE_SECURE = os.getenv("WEB_AUTH_COOKIE_SECURE", "true").strip().lower() not in {"0", "false", "no"}
WEB_AUTH_CONFIGURED = bool(WEB_AUTH_USERNAME and WEB_AUTH_PASSWORD and WEB_AUTH_SESSION_SECRET)
|
||
|
||
|
||
def default_video_gateway_paths(base_url: str) -> tuple[str, str, str]:
    """Pick default (create, status, content) endpoint paths for a gateway URL.

    The URL is stripped, has trailing slashes removed and is lower-cased, then
    matched by substring against the known gateway markers in order.

    Args:
        base_url: Video gateway base URL; may be empty.

    Returns:
        A ``(create_path, status_path, content_path)`` tuple. The status and
        content paths contain a literal ``{id}`` placeholder.
    """
    normalized = base_url.strip().rstrip("/").lower()
    known_gateways = (
        (
            "ai.skg.com/doubao",
            (
                "/api/v3/contents/generations/tasks",
                "/api/v3/contents/generations/tasks/{id}",
                "/api/v3/contents/generations/tasks/{id}/content",
            ),
        ),
        (
            "ark.cn-beijing.volces.com",
            (
                "/contents/generations/tasks",
                "/contents/generations/tasks/{id}",
                "/contents/generations/tasks/{id}/content",
            ),
        ),
    )
    for marker, paths in known_gateways:
        if marker in normalized:
            return paths
    return ("/videos", "/videos/{id}", "/videos/{id}/content")
|
||
|
||
|
||
# Gateway-specific default endpoint paths derived from VIDEO_API_BASE_URL.
DEFAULT_VIDEO_CREATE_PATH, DEFAULT_VIDEO_STATUS_PATH, DEFAULT_VIDEO_CONTENT_PATH = default_video_gateway_paths(VIDEO_API_BASE_URL)
VIDEO_CREATE_PATH = os.getenv("VIDEO_CREATE_PATH", DEFAULT_VIDEO_CREATE_PATH).strip() or DEFAULT_VIDEO_CREATE_PATH
# Comma-separated create paths. When VIDEO_CREATE_PATH is the generic "/videos",
# the default list also includes "/videos/generations" and "/video/generations".
VIDEO_CREATE_PATHS = [
    p.strip()
    for p in os.getenv(
        "VIDEO_CREATE_PATHS",
        VIDEO_CREATE_PATH if VIDEO_CREATE_PATH != "/videos" else f"{VIDEO_CREATE_PATH},/videos/generations,/video/generations",
    ).split(",")
    if p.strip()
]
VIDEO_STATUS_PATH = os.getenv("VIDEO_STATUS_PATH", DEFAULT_VIDEO_STATUS_PATH).strip() or DEFAULT_VIDEO_STATUS_PATH
VIDEO_CONTENT_PATH = os.getenv("VIDEO_CONTENT_PATH", DEFAULT_VIDEO_CONTENT_PATH).strip() or DEFAULT_VIDEO_CONTENT_PATH
VIDEO_DURATION_FIELD = os.getenv("VIDEO_DURATION_FIELD", "seconds").strip() or "seconds"
# Clamped to at least 60 seconds.
VIDEO_POLL_TIMEOUT_SECONDS = max(60, int(os.getenv("VIDEO_POLL_TIMEOUT_SECONDS", "900")))
FFMPEG_BIN = os.getenv("FFMPEG_BIN", "").strip()
FFPROBE_BIN = os.getenv("FFPROBE_BIN", "").strip()
# ffmpeg binaries bundled inside macOS applications.
# NOTE(review): presumably probed as fallbacks when FFMPEG_BIN is unset; the
# lookup code is not in this part of the file.
LOCAL_FFMPEG_CANDIDATES = [
    "/Applications/Downie 4.app/Contents/Resources/ffmpeg",
    "/Applications/Permute 3.app/Contents/Resources/ffmpeg",
    "/Applications/VideoFusion-macOS.app/Contents/Resources/ffmpeg",
]
_MEDIA_BIN_CACHE: dict[str, str] = {}

# OpenAI client (OpenAI-compatible gateway, including SKG ezlink)
from openai import OpenAI
# Lazily created singletons; see llm() and image_llm().
_llm_client: OpenAI | None = None
_image_client: OpenAI | None = None
|
||
|
||
def ai_http_client(timeout: float = 120) -> httpx.Client:
    """Build an ``httpx.Client`` for SKG AI gateway calls.

    launchd does not reliably inherit interactive-shell proxy variables, so the
    app also supports an explicit AI_HTTP_PROXY / IMAGE_HTTP_PROXY in api/.env.

    Args:
        timeout: Timeout passed to the client.

    Returns:
        A new client; it is routed through ``AI_HTTP_PROXY`` when that is set.
    """
    if AI_HTTP_PROXY:
        return httpx.Client(timeout=timeout, proxy=AI_HTTP_PROXY)
    return httpx.Client(timeout=timeout)
|
||
|
||
|
||
def openai_http_client(timeout: float = 120) -> httpx.Client | None:
    """Return a proxy-aware HTTP client for the OpenAI SDK, or None.

    None is returned when no proxy is configured, so callers can let the SDK
    create its own default transport.
    """
    if not AI_HTTP_PROXY:
        return None
    return ai_http_client(timeout=timeout)
|
||
|
||
|
||
def llm() -> OpenAI:
    """Return the shared text/vision ``OpenAI`` client, creating it on first use.

    Raises:
        RuntimeError: If ``LLM_API_KEY`` is empty when the client is first built.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    # An empty base URL is passed as None so the SDK uses its own default.
    client_options = {"base_url": LLM_BASE_URL or None, "api_key": LLM_API_KEY}
    proxied_transport = openai_http_client()
    if proxied_transport:
        client_options["http_client"] = proxied_transport
    _llm_client = OpenAI(**client_options)
    return _llm_client
|
||
|
||
def image_llm() -> OpenAI:
    """Return the shared image ``OpenAI`` client, creating it on first use.

    Raises:
        RuntimeError: If ``IMAGE_API_KEY`` is empty when the client is first built.
    """
    global _image_client
    if _image_client is not None:
        return _image_client
    if not IMAGE_API_KEY:
        raise RuntimeError("IMAGE_API_KEY 或 LLM_API_KEY 未配置")
    # An empty base URL is passed as None so the SDK uses its own default.
    client_options = {"base_url": IMAGE_BASE_URL or None, "api_key": IMAGE_API_KEY}
    proxied_transport = openai_http_client()
    if proxied_transport:
        client_options["http_client"] = proxied_transport
    _image_client = OpenAI(**client_options)
    return _image_client
|
||
|
||
def product_view_llm() -> OpenAI:
    """Return the client matching ``PRODUCT_VIEW_MODEL``.

    The image client is used when the product-view model is the locked image
    model; otherwise the general LLM client is used.
    """
    if PRODUCT_VIEW_MODEL == GPT_IMAGE_MODEL:
        return image_llm()
    return llm()
|
||
|
||
# Pipeline states:
# created → downloading → downloaded (the front-end "Start" action then goes on to trigger audio analysis)
# → splitting → frames_extracted
# → transcribing → transcribed | failed
JobStatus = Literal[
    "created", "downloading", "downloaded",
    "splitting", "frames_extracted",
    "transcribing", "transcribed", "failed",
]

# Default key-frame count; a non-integer KEYFRAME_COUNT raises ValueError at import.
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "12"))
FrameExtractTarget = Literal["transparent_human", "balanced", "subject", "transition", "expression", "motion"]
FrameExtractMode = Literal["replace", "append"]
FrameExtractQuality = Literal["auto", "fast", "accurate", "ultra"]
# Queue entry shape used by ANALYZE_QUEUE.
# NOTE(review): the leading str/int are presumably a job id and a frame count — confirm against the worker.
AnalyzeTask = tuple[str, int, FrameExtractTarget, FrameExtractMode, FrameExtractQuality]
AssetBackground = Literal["white", "black"]
AssetSize = Literal["source", "1024", "1536", "2048"]
AssetQuality = Literal["hd"]
SubjectKind = Literal["object", "living"]
SubjectView = str
SceneMode = Literal["remove_subject", "similar", "style"]
SceneStyle = Literal["source", "premium_product", "clean_studio", "warm_lifestyle", "cinematic"]
SceneAssetRole = Literal["scene", "first_frame", "last_frame"]
# Display labels (Chinese, runtime data) for each frame-extraction target.
FRAME_TARGET_LABELS: dict[FrameExtractTarget, str] = {
    "transparent_human": "透明骨架人",
    "balanced": "综合关键帧",
    "subject": "清晰主体",
    "transition": "转场变化",
    "expression": "表情瞬间",
    "motion": "动作峰值",
}

# Bilingual (English + Chinese) description of the wanted "transparent human" subject.
TRANSPARENT_HUMAN_POSITIVE_PROMPT = (
    "Target subject: transparent human character, translucent human body, glass-like human body, clear acrylic skin, "
    "transparent vinyl skin, visible clean white skeleton inside, skeleton visible inside transparent body, "
    "white bones inside clear body, non-horror skeleton character, friendly transparent humanoid, 3D commercial character, "
    "premium wellness character, transparent body with visible spine, transparent body with visible rib cage. "
    "中文目标:透明人体、半透明人体、玻璃人体、亚克力人体、果冻质感人体、外层透明皮肤、身体内部可见骨架、"
    "透明身体里的白色骨骼、干净白色骨架、非恐怖骷髅人、3D广告角色、透明骨架人、可见脊柱、可见肋骨、"
    "可见颈椎、可见骨盆、可见四肢骨骼、透明皮肤包裹骨架。"
)
# Bilingual list of content to exclude.
TRANSPARENT_HUMAN_NEGATIVE_PROMPT = (
    "Avoid: normal human, ordinary skeleton, skeleton only without transparent body, horror skeleton, gore, blood, corpse, "
    "zombie, organs, veins, autopsy, surgery, hospital, dark horror scene, blurry person, heavily occluded person, "
    "person too small, product only, background only, no visible skeleton, no transparent body, transparent clothing only. "
    "反向排除:普通真人、普通骷髅、只有骨架没有透明外壳、恐怖骷髅、血腥、腐烂、僵尸、尸体、器官、血管、"
    "解剖、医院、手术、黑暗恐怖场景、模糊人物、遮挡严重、人物太远、只有产品没有人、只有背景没有人、"
    "看不到骨架、看不到透明身体、透明衣服但不是透明身体。"
)
# Acceptance criteria text for a qualifying transparent-human frame.
TRANSPARENT_HUMAN_QUALIFIED_STANDARD = (
    "A qualified frame must satisfy all core conditions: 1) there is a humanoid character; "
    "2) the outer body is transparent or translucent; 3) a clean white skeleton is clearly visible inside the body; "
    "4) the transparent body and inner skeleton belong to the same character, not a background overlay; "
    "5) the character should occupy at least about 35% of frame height and be easy to inspect; "
    "6) no severe blur, occlusion, or deformation; 7) clean premium commercial wellness style, non-horror."
)
# Display labels (Chinese, runtime data) for each extraction quality level.
FRAME_QUALITY_LABELS: dict[FrameExtractQuality, str] = {
    "auto": "自动",
    "fast": "快速",
    "accurate": "精细",
    "ultra": "极准",
}
|
||
|
||
|
||
class GeneratedImage(BaseModel):
    """One generated image entry, as stored in ``KeyFrame.generated_images``."""

    id: str  # uuid hex 12
    prompt: str
    model: str
    mode: str = "edit"  # "edit" (with a reference image) | "text" (text only)
    url: str  # /jobs/{job_id}/frames/{idx}/gen/{id}.jpg
    selected: bool = False
    created_at: float = 0.0
|
||
|
||
|
||
class GeneratedVideo(BaseModel):
    """One generated video entry, as stored in ``Job.generated_videos``."""

    id: str
    provider_id: str = ""
    frame_idx: int
    prompt: str
    model: str = ""
    # Lifecycle state; new entries start as "queued".
    status: Literal["queued", "in_progress", "completed", "failed"] = "queued"
    url: str = ""
    poster_url: str = ""
    duration: float = 4.0
    progress: int = 0
    error: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class VideoSourceRef(BaseModel):
    """Reference to a video input: either an image or the source video."""

    kind: Literal["image", "source_video"] = "image"
    url: str = ""
|
||
|
||
|
||
class StoryboardScene(BaseModel):
    """Storyboard arrangement: each selected shot maps to one scene description.

    v2: four image slots plus a duration (copy/paste mode) — one image each for
    subject / scene / product / action.
    v1 fields are kept for compatibility (subject/product/scene/action/reference_ids).
    """

    duration: float = 0
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    subject_images: list[dict] = Field(default_factory=list)
    product_fusion_shots: list[dict] = Field(default_factory=list)
    visual_mode: Literal["person_only", "person_product", "product_only", "environment"] = "person_product"
    needs_product: bool = True
    needs_subject: bool = True
    first_frame_plan: str = ""
    last_frame_plan: str = ""
    product_placement: str = ""
    # Four image slots: each dict holds {kind, frame_idx, element_id?, cutout_id?, label}
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # v1 compatibility
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    # default_factory keeps this consistent with the other list fields of the model.
    reference_ids: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class StoryboardImage(BaseModel):
    """Image the user has "pushed up" from elsewhere into the storyboard arrangement area."""

    ref_id: str  # uuid hex 8
    kind: Literal["keyframe", "cutout", "asset"]  # asset = grouped image material such as scenes / subject views
    frame_idx: int
    element_id: str | None = None  # set for cutouts
    cutout_id: str | None = None  # set for cutouts (versioned id; legacy data may have it equal to element_id)
    label: str = ""  # display name
    created_at: float = 0.0
|
||
|
||
|
||
class QualityReport(BaseModel):
    """Image quality summary: dimensions, sharpness, a risk level and warnings."""

    width: int = 0
    height: int = 0
    short_side: int = 0
    sharpness: float = 0.0
    risk: Literal["ok", "warn", "bad"] = "ok"
    warnings: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class TransparentHumanFrameScore(BaseModel):
    """Per-frame scores for the "transparent human" target, with a verdict.

    NOTE(review): score ranges and how ``total_score`` is derived are not
    defined in this part of the file.
    """

    transparent_body_score: int = 0
    skeleton_visible_score: int = 0
    human_prominence_score: int = 0
    clarity_score: int = 0
    commercial_style_score: int = 0
    product_usefulness_score: int = 0
    total_score: int = 0
    qualified: bool = False
    reject_reason: str = ""
    notes: str = ""
|
||
|
||
|
||
class SceneAsset(BaseModel):
    """Scene image asset attached to a key frame (see ``KeyFrame.scene_assets``)."""

    id: str
    label: str = ""
    url: str = ""
    width: int = 0
    height: int = 0
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    scene_mode: SceneMode = "remove_subject"
    scene_style: SceneStyle = "source"
    asset_role: SceneAssetRole = "scene"
    quality_report: QualityReport | None = None
    created_at: float = 0.0
|
||
|
||
|
||
class SubjectAsset(BaseModel):
    """Subject view image attached to an element (see ``KeyElement.subject_assets``)."""

    id: str
    view: SubjectView  # free-form string (SubjectView is an alias of str)
    label: str = ""
    url: str = ""
    width: int = 0
    height: int = 0
    background: AssetBackground = "white"
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    source_frame_indices: list[int] = Field(default_factory=list)
    ai_completed: bool = True
    created_at: float = 0.0
|
||
|
||
|
||
class ProductLibraryItem(BaseModel):
    """One entry of the product library manifest's ``items`` list."""

    id: str
    handle: str
    title: str
    product_type: str = ""
    image_type: str = "gallery"
    image_index: int = 0
    filename: str  # joined onto PRODUCT_LIBRARY_DIR by product_library_file()
    url: str
    width: int = 0
    height: int = 0
    source_path: str = ""
    white_score: float = 0.0
    near_white_score: float = 0.0
    has_people: bool = False
    tags: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class CharacterLibraryImage(BaseModel):
    """One image belonging to a character library entry."""

    id: str
    view: str
    label: str
    filename: str
    width: int = 0
    height: int = 0
    source_path: str = ""
    url: str = ""  # overwritten by load_character_library_items() from filename
|
||
|
||
|
||
class CharacterLibraryItem(BaseModel):
    """One entry of the character library manifest's ``characters`` list."""

    id: str
    name: str
    folder: str = ""
    description: str = ""
    prompt_brief: str = ""
    primary_image: str = ""
    images: list[CharacterLibraryImage] = Field(default_factory=list)
|
||
|
||
|
||
class SubjectTemplateImage(BaseModel):
    """One image belonging to a subject template."""

    id: str
    view: str
    label: str = ""
    filename: str
    url: str = ""  # overwritten by load_subject_template_items() from filename
    width: int = 0
    height: int = 0
    background: AssetBackground = "white"
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    source_asset_id: str = ""
    source_frame_indices: list[int] = Field(default_factory=list)
    created_at: float = 0.0
|
||
|
||
|
||
class SubjectTemplateItem(BaseModel):
    """One entry of the subject template manifest's ``templates`` list."""

    id: str
    name: str
    description: str = ""
    note: str = ""
    prompt_brief: str = ""
    source: Literal["database"] = "database"
    source_job_id: str = ""
    source_frame_idx: int = -1
    source_element_id: str = ""
    subject_style: Literal["transparent_human", "source_actor"] = "transparent_human"
    primary_image: str = ""
    images: list[SubjectTemplateImage] = Field(default_factory=list)
    created_at: float = 0.0
    # load_subject_template_items() sorts by updated_at, falling back to created_at.
    updated_at: float = 0.0
|
||
|
||
|
||
class ProductFusionRegion(BaseModel):
    """Rectangle given as x, y, w, h.

    NOTE(review): units (pixels vs. normalised 0..1) are not defined in this
    part of the file.
    """

    x: float = 0
    y: float = 0
    w: float = 0
    h: float = 0
|
||
|
||
|
||
class ProductFusionShot(BaseModel):
    """Inputs for one product-fusion shot: image references, action text and models.

    NOTE(review): the image-reference dicts presumably follow the
    {kind, frame_idx, element_id?, cutout_id?, label} shape used by
    StoryboardScene — confirm against the callers.
    """

    id: str = ""
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    product_image: dict | None = None
    character_id: str = ""
    character_name: str = ""
    subject_image: dict | None = None
    subject_images: list[dict] = Field(default_factory=list)
    person_image: dict | None = None
    product_region: ProductFusionRegion | None = None
    scene_image: dict | None = None
    action_text: str = ""
    duration: float = 5
    image_model: str = "gpt-image-2"
    video_model: str = "seedance"
    guide_image: dict | None = None
|
||
|
||
|
||
class KeyElement(BaseModel):
    """Element recognised in a key frame or extracted by the user.

    Repeated extractions accumulate several images so the user can pick a
    satisfactory one.
    """

    id: str  # uuid hex 8
    name_zh: str
    name_en: str = ""
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    region: dict | None = None
    # Ids of the extracted images (each call to the cutout endpoint appends a new id)
    # → /jobs/.../elements/{element_id}/cutouts/{cutout_id}.jpg
    # default_factory keeps this consistent with the other list fields in the file.
    cutouts: list[str] = Field(default_factory=list)
    # Legacy field (v1 single image): used as a rendering fallback; new extractions no longer write it
    cutout_id: str | None = None
    cutout_background: Literal["white", "black"] = "white"
    subject_kind: SubjectKind = "object"
    subject_assets: list[SubjectAsset] = Field(default_factory=list)
    created_at: float = 0.0
|
||
|
||
|
||
class KeyFrame(BaseModel):
    """One extracted key frame together with everything derived from it."""

    index: int
    timestamp: float
    url: str
    description: dict | None = None  # vision-model recognition result {scene, objects, style, suggested_prompt}
    transparent_human_score: TransparentHumanFrameScore | None = None
    cleaned_url: str | None = None  # cleaned version awaiting application → /jobs/{id}/frames/{idx}/cleaned.jpg
    cleaned_applied: bool = False  # whether the cleaned version has replaced the original (cleaned_url becomes null afterwards)
    quality_report: QualityReport | None = None
    scene_assets: list[SceneAsset] = Field(default_factory=list)
    # default_factory keeps the list fields below consistent with scene_assets.
    elements: list[KeyElement] = Field(default_factory=list)  # extracted elements (persisted)
    storyboard: StoryboardScene | None = None  # storyboard arrangement fields
    generated_images: list[GeneratedImage] = Field(default_factory=list)
|
||
|
||
|
||
class TranscriptSegment(BaseModel):
    """One transcript segment with start/end times and en/zh text fields.

    NOTE(review): ``start``/``end`` are presumably seconds — confirm.
    """

    index: int
    start: float
    end: float
    en: str
    zh: str = ""
|
||
|
||
|
||
class AudioScript(BaseModel):
    """State of a job's audio script: source text, rewritten text and voice fields."""

    status: Literal["idle", "rewriting", "completed", "failed"] = "idle"
    source_text: str = ""
    source_zh: str = ""
    rewritten_text: str = ""
    speaker_profile: str = ""
    rhythm_profile: str = ""
    background_audio_profile: str = ""
    product_brief: str = ""
    rewrite_model: str = ""
    voice_provider: str = ""
    voice_model: str = ""
    voice_id: str = ""
    voice_url: str = ""
    error: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class Job(BaseModel):
    """Complete state of one job; serialised to state.json by save_state()."""

    id: str
    url: str
    status: JobStatus = "created"
    progress: int = 0
    message: str = ""
    video_url: str = ""  # filled in by job_with_artifacts() when empty and source.mp4 exists
    duration: float = 0.0
    width: int = 0
    height: int = 0
    source_audio_url: str = ""  # recomputed by job_with_artifacts()
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
    audio_script: AudioScript = Field(default_factory=AudioScript)
    storyboard_images: list[StoryboardImage] = Field(default_factory=list)
    generated_videos: list[GeneratedVideo] = Field(default_factory=list)
    product_refs: list[dict] = Field(default_factory=list)
    error: str = ""
|
||
|
||
|
||
class AuthLoginPayload(BaseModel):
    """Login payload: credentials plus a "remember" flag."""

    username: str
    password: str
    remember: bool = False
|
||
|
||
|
||
# In-memory, process-local state.
JOBS: dict[str, Job] = {}  # NOTE(review): presumably keyed by job id — confirm
ANALYZE_QUEUE: list[AnalyzeTask] = []
ANALYZE_WORKER_RUNNING = False
AUDIO_WORKERS_RUNNING: set[str] = set()
# NOTE(review): presumably guards AUDIO_WORKERS_RUNNING — confirm against the worker code.
AUDIO_WORKERS_LOCK = threading.Lock()
|
||
|
||
|
||
def ensure_auth_configured() -> None:
    """Abort with HTTP 503 unless web login is fully configured.

    Raises:
        HTTPException: 503 when ``WEB_AUTH_CONFIGURED`` is false.
    """
    if WEB_AUTH_CONFIGURED:
        return
    raise HTTPException(503, "WEB_AUTH_USERNAME、WEB_AUTH_PASSWORD 或 WEB_AUTH_SESSION_SECRET 未配置")
|
||
|
||
|
||
def _auth_signature(body: str) -> str:
    """Return the hex HMAC-SHA256 of ``body`` keyed with the session secret."""
    secret_key = WEB_AUTH_SESSION_SECRET.encode("utf-8")
    mac = hmac.new(secret_key, body.encode("utf-8"), hashlib.sha256)
    return mac.hexdigest()
|
||
|
||
|
||
def _encode_auth_body(payload: dict) -> str:
    """Serialise ``payload`` to compact JSON and encode it as unpadded URL-safe base64."""
    compact_json = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
    encoded = base64.urlsafe_b64encode(compact_json.encode("utf-8"))
    # Padding is dropped here; _decode_auth_body() restores it.
    return encoded.decode("ascii").rstrip("=")
|
||
|
||
|
||
def _decode_auth_body(body: str) -> dict:
    """Decode an unpadded URL-safe base64 JSON body.

    Returns:
        The decoded JSON object, or an empty dict when the JSON value is not
        an object.
    """
    # Restore the "=" padding that _encode_auth_body() strips.
    missing_padding = -len(body) % 4
    raw = base64.urlsafe_b64decode((body + "=" * missing_padding).encode("ascii"))
    parsed = json.loads(raw.decode("utf-8"))
    if isinstance(parsed, dict):
        return parsed
    return {}
|
||
|
||
|
||
def make_auth_token(username: str, ttl_seconds: int) -> str:
    """Create a signed session token of the form ``<body>.<signature>``.

    Args:
        username: Stored under the ``u`` claim.
        ttl_seconds: Added to the current time to form the ``exp`` claim.

    Returns:
        The encoded body and its HMAC signature joined by a dot.
    """
    claims = {
        "u": username,
        "exp": int(time.time()) + ttl_seconds,
        # Random value so two tokens issued in the same second differ.
        "n": secrets.token_hex(8),
    }
    encoded_body = _encode_auth_body(claims)
    signature = _auth_signature(encoded_body)
    return f"{encoded_body}.{signature}"
|
||
|
||
|
||
def verify_auth_token(token: str) -> str | None:
    """Validate a session token and return its username.

    Args:
        token: Value of the form ``<body>.<signature>`` as produced by
            ``make_auth_token``; typically taken from a request cookie.

    Returns:
        The username when the signature matches, the token has not expired and
        the username equals ``WEB_AUTH_USERNAME``; otherwise ``None``.
    """
    if not WEB_AUTH_CONFIGURED or "." not in token:
        return None
    body, supplied_sig = token.rsplit(".", 1)
    # hmac.compare_digest() raises TypeError for str arguments containing
    # non-ASCII characters. The signature is client-controlled, so compare
    # bytes and treat anything that cannot be ASCII-encoded as a mismatch
    # (a valid hex digest is always ASCII).
    try:
        expected = _auth_signature(body).encode("ascii")
        supplied = supplied_sig.encode("ascii")
    except UnicodeEncodeError:
        return None
    if not hmac.compare_digest(expected, supplied):
        return None
    try:
        payload = _decode_auth_body(body)
        username = str(payload.get("u") or "")
        expires_at = int(payload.get("exp") or 0)
    except (ValueError, TypeError, OverflowError):
        # Covers malformed base64/JSON/UTF-8 and non-numeric "exp" values.
        return None
    if username != WEB_AUTH_USERNAME or expires_at < int(time.time()):
        return None
    return username
|
||
|
||
|
||
def auth_username_from_request(request: Request) -> str | None:
    """Return the authenticated username for ``request``, or None.

    The session token is read from the cookie named ``WEB_AUTH_COOKIE_NAME``;
    a missing cookie is treated as an empty token.
    """
    return verify_auth_token(request.cookies.get(WEB_AUTH_COOKIE_NAME, ""))
|
||
|
||
|
||
def job_dir(job_id: str) -> Path:
    """Return the working directory of ``job_id``, creating it if needed."""
    target = JOBS_DIR.joinpath(job_id)
    target.mkdir(parents=True, exist_ok=True)
    return target
|
||
|
||
|
||
def source_audio_url_for(job_id: str) -> str:
    """Return the URL of the job's ``audio.wav``, or "" when the file is absent.

    The path is built from ``JOBS_DIR`` directly, so no directory is created.
    """
    if (JOBS_DIR / job_id / "audio.wav").exists():
        return f"/jobs/{job_id}/audio.wav"
    return ""
|
||
|
||
|
||
def job_with_artifacts(job: Job) -> Job:
    """Return a copy of ``job`` with artifact URLs derived from files on disk.

    ``source_audio_url`` is always recomputed. ``video_url`` is filled in only
    when it is empty and ``source.mp4`` exists in the job folder.
    """
    overrides = {"source_audio_url": source_audio_url_for(job.id)}
    source_video = JOBS_DIR / job.id / "source.mp4"
    if not job.video_url and source_video.exists():
        overrides["video_url"] = f"/jobs/{job.id}/video.mp4"
    return job.model_copy(update=overrides)
|
||
|
||
|
||
def save_state(job: Job) -> None:
    """Persist ``job`` as indented JSON in ``state.json`` inside its job folder.

    The file is written as UTF-8 explicitly. The serialised job can contain
    non-ASCII text, and relying on the platform default encoding can fail or
    produce a file other tools cannot read; this also matches the explicit
    ``encoding="utf-8"`` used for the manifest files in this module.
    """
    state_file = job_dir(job.id) / "state.json"
    state_file.write_text(job.model_dump_json(indent=2), encoding="utf-8")
|
||
|
||
|
||
def update(job: Job, **kw) -> None:
    """Set the given attributes on ``job`` and persist it.

    Args:
        job: Job to mutate in place.
        **kw: Attribute names and the values to assign.
    """
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    save_state(job)
|
||
|
||
|
||
def public_api_base() -> str:
    """Return the LLM base URL without trailing slashes.

    Falls back to ``https://api.openai.com/v1`` when ``LLM_BASE_URL`` is empty.
    """
    base = LLM_BASE_URL if LLM_BASE_URL else "https://api.openai.com/v1"
    return base.rstrip("/")
|
||
|
||
|
||
def video_uses_poe() -> bool:
    """Report whether video generation goes through the Poe API.

    With an explicit ``VIDEO_API_BASE_URL`` this is true only when it equals
    ``POE_API_BASE_URL`` (ignoring trailing slashes); without one it is true
    whenever ``POE_API_KEY`` is set.
    """
    if not VIDEO_API_BASE_URL:
        return bool(POE_API_KEY)
    return VIDEO_API_BASE_URL.rstrip("/") == POE_API_BASE_URL.rstrip("/")
|
||
|
||
|
||
def video_uses_ark() -> bool:
    """Report whether the video gateway is Volcengine Ark or the SKG doubao proxy.

    The base URL is lower-cased before matching so host-name case does not
    matter, matching how ``default_video_gateway_paths`` normalises the same
    URL before applying the same markers.
    """
    base = video_api_base().lower()
    return "ark.cn-beijing.volces.com" in base or "ai.skg.com/doubao" in base
|
||
|
||
|
||
def video_provider_name() -> str:
    """Return a short name for the active video provider.

    Returns:
        ``"poe"``, ``"doubao"``, ``"ark"`` or ``"custom"``. Poe is checked
        first; the other names are matched by substring on the base URL.
    """
    if video_uses_poe():
        return "poe"
    # Lower-cased so host-name case does not change the result, matching the
    # normalisation done in default_video_gateway_paths().
    base = video_api_base().lower()
    if "ai.skg.com/doubao" in base:
        return "doubao"
    if "ark.cn-beijing.volces.com" in base:
        return "ark"
    return "custom"
|
||
|
||
|
||
def video_api_base() -> str:
    """Return the base URL for video API calls, without trailing slashes.

    Precedence: ``VIDEO_API_BASE_URL``, then the Poe base URL when a Poe key is
    set, then ``LLM_BASE_URL``, then ``https://api.openai.com/v1``.
    """
    if VIDEO_API_BASE_URL:
        chosen = VIDEO_API_BASE_URL
    elif POE_API_KEY:
        chosen = POE_API_BASE_URL
    else:
        chosen = LLM_BASE_URL or "https://api.openai.com/v1"
    return chosen.rstrip("/")
|
||
|
||
|
||
def video_api_key() -> str:
    """Return the API key for video calls.

    Precedence: ``VIDEO_API_KEY``, then ``POE_API_KEY`` when Poe is the active
    provider, then ``LLM_API_KEY``. The result may be an empty string.
    """
    if VIDEO_API_KEY:
        return VIDEO_API_KEY
    return POE_API_KEY if video_uses_poe() else LLM_API_KEY
|
||
|
||
|
||
def video_path(template: str, **values: str) -> str:
    """Fill a path template and make sure the result starts with "/".

    Args:
        template: ``str.format`` template such as ``"/videos/{id}"``.
        **values: Values for the template placeholders.

    Returns:
        The formatted path, with a leading slash added when missing.
    """
    rendered = template.format(**values)
    if rendered.startswith("/"):
        return rendered
    return "/" + rendered
|
||
|
||
|
||
def ensure_video_api_configured() -> None:
    """Abort with HTTP 503 when no API key is available for video calls.

    Raises:
        HTTPException: 503 when ``video_api_key()`` returns an empty value.
    """
    if video_api_key():
        return
    raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API")
|
||
|
||
|
||
def storyboard_ref_path(job_id: str, ref: dict | None) -> Path | None:
    """Resolve a storyboard image reference to an existing file on disk.

    Args:
        job_id: Job whose folder is searched.
        ref: Reference dict with ``kind`` (``keyframe`` / ``cutout`` /
            ``asset``), ``frame_idx`` and, depending on the kind,
            ``element_id`` / ``cutout_id``.

    Returns:
        The path of the first matching file, or ``None`` when the reference is
        empty or malformed, or no file exists. For key frames the cleaned
        version is preferred over the original frame.
    """
    if not ref:
        return None
    try:
        kind = ref.get("kind")
        frame_idx = int(ref.get("frame_idx"))
    except (AttributeError, TypeError, ValueError, OverflowError):
        return None

    root = job_dir(job_id)

    def _clean_id(value: object) -> str:
        # Ids come from loosely typed dicts; tolerate non-string values
        # instead of failing on .strip().
        return str(value or "").strip()

    def _existing(folder: str, filename: str) -> Path | None:
        # The filename embeds ids taken from the reference. Require a single
        # path component so an id such as "../../x" cannot point outside the
        # job folder.
        if Path(filename).name != filename:
            return None
        candidate = root / folder / filename
        return candidate if candidate.exists() else None

    if kind == "keyframe":
        frame_file = f"{frame_idx:03d}.jpg"
        return _existing("cleaned", frame_file) or _existing("frames", frame_file)

    if kind == "cutout":
        element_id = _clean_id(ref.get("element_id"))
        cutout_id = _clean_id(ref.get("cutout_id"))
        if not element_id:
            return None
        filenames = []
        # Versioned cutout first, then the legacy single-image names.
        if cutout_id and cutout_id != element_id:
            filenames.append(f"{frame_idx:03d}_{element_id}_{cutout_id}.jpg")
        filenames.append(f"{frame_idx:03d}_{element_id}.jpg")
        filenames.append(f"{frame_idx:03d}_{element_id}.png")
        for filename in filenames:
            found = _existing("elements", filename)
            if found is not None:
                return found
        return None

    if kind == "asset":
        asset_id = _clean_id(ref.get("element_id") or ref.get("cutout_id"))
        if not asset_id:
            return None
        return _existing("assets", f"{asset_id}.jpg")

    return None
|
||
|
||
|
||
def load_product_library_items() -> list[ProductLibraryItem]:
    """Load every product from the product library manifest.

    Returns:
        The parsed items, or an empty list when the manifest does not exist.

    Raises:
        HTTPException: 500 when the manifest cannot be read or parsed.
    """
    if not PRODUCT_LIBRARY_MANIFEST.exists():
        return []
    try:
        manifest = json.loads(PRODUCT_LIBRARY_MANIFEST.read_text(encoding="utf-8"))
        products = []
        for entry in manifest.get("items", []):
            products.append(ProductLibraryItem(**entry))
        return products
    except Exception as e:
        raise HTTPException(500, f"product library manifest invalid: {e}")
|
||
|
||
|
||
def find_product_library_item(product_id: str) -> ProductLibraryItem:
    """Return the product whose id equals ``product_id`` (whitespace-stripped).

    Raises:
        HTTPException: 404 when no product has that id.
    """
    wanted = product_id.strip()
    match = next((entry for entry in load_product_library_items() if entry.id == wanted), None)
    if match is None:
        raise HTTPException(404, "product library item not found")
    return match
|
||
|
||
|
||
def product_library_file(item: ProductLibraryItem) -> Path:
    """Return the on-disk image file for a product library item.

    Args:
        item: Library entry whose ``filename`` is relative to
            ``PRODUCT_LIBRARY_DIR``.

    Returns:
        The resolved path of the image file.

    Raises:
        HTTPException: 400 when the resolved path leaves the library folder,
            404 when it is not an existing regular file.
    """
    p = (PRODUCT_LIBRARY_DIR / item.filename).resolve()
    try:
        p.relative_to(PRODUCT_LIBRARY_DIR)
    except ValueError:
        raise HTTPException(400, "invalid product library path") from None
    # is_file() rather than exists(): an empty or directory-like filename
    # resolves to a folder inside the library, which is not a usable image.
    if not p.is_file():
        raise HTTPException(404, "product library image missing")
    return p
|
||
|
||
|
||
def load_character_library_items() -> list[CharacterLibraryItem]:
    """Load every character from the character library manifest.

    Each image's ``url`` is rewritten to its
    ``/character-library/skg/images/<filename>`` route.

    Returns:
        The parsed characters, or an empty list when the manifest does not exist.

    Raises:
        HTTPException: 500 when the manifest cannot be read or parsed.
    """
    if not CHARACTER_LIBRARY_MANIFEST.exists():
        return []
    try:
        manifest = json.loads(CHARACTER_LIBRARY_MANIFEST.read_text(encoding="utf-8"))
        characters: list[CharacterLibraryItem] = [
            CharacterLibraryItem(**entry) for entry in manifest.get("characters", [])
        ]
        for character in characters:
            for picture in character.images:
                picture.url = f"/character-library/skg/images/{picture.filename}"
        return characters
    except Exception as e:
        raise HTTPException(500, f"character library manifest invalid: {e}")
|
||
|
||
|
||
def find_character_library_item(character_id: str) -> CharacterLibraryItem:
    """Return the character whose id equals ``character_id`` (whitespace-stripped).

    Raises:
        HTTPException: 404 when no character has that id.
    """
    wanted = character_id.strip()
    match = next((entry for entry in load_character_library_items() if entry.id == wanted), None)
    if match is None:
        raise HTTPException(404, "character library item not found")
    return match
|
||
|
||
|
||
def character_library_file(filename: str) -> Path:
    """Return the on-disk image file for a character library filename.

    Args:
        filename: Path relative to ``CHARACTER_LIBRARY_DIR``.

    Returns:
        The resolved path of the image file.

    Raises:
        HTTPException: 400 when the resolved path leaves the library folder,
            404 when it is not an existing regular file.
    """
    p = (CHARACTER_LIBRARY_DIR / filename).resolve()
    try:
        p.relative_to(CHARACTER_LIBRARY_DIR)
    except ValueError:
        raise HTTPException(400, "invalid character library path") from None
    # is_file() rather than exists(): an empty or directory-like filename
    # resolves to a folder inside the library, which is not a usable image.
    if not p.is_file():
        raise HTTPException(404, "character library image missing")
    return p
|
||
|
||
|
||
def load_subject_template_items() -> list[SubjectTemplateItem]:
    """Load all subject templates from the manifest, newest first.

    Each image's ``url`` is rewritten to its
    ``/subject-templates/images/<filename>`` route. Templates are ordered by
    ``updated_at`` (falling back to ``created_at``), descending.

    Returns:
        The parsed templates, or an empty list when the manifest does not exist.

    Raises:
        HTTPException: 500 when the manifest cannot be read or parsed.
    """
    if not SUBJECT_TEMPLATE_MANIFEST.exists():
        return []

    def _last_touched(template: SubjectTemplateItem) -> float:
        return template.updated_at or template.created_at

    try:
        manifest = json.loads(SUBJECT_TEMPLATE_MANIFEST.read_text(encoding="utf-8"))
        templates: list[SubjectTemplateItem] = [
            SubjectTemplateItem(**entry) for entry in manifest.get("templates", [])
        ]
        for template in templates:
            for picture in template.images:
                picture.url = f"/subject-templates/images/{picture.filename}"
        return sorted(templates, key=_last_touched, reverse=True)
    except Exception as e:
        raise HTTPException(500, f"subject template manifest invalid: {e}")
|
||
|
||
|
||
def save_subject_template_items(items: list[SubjectTemplateItem]) -> None:
    """Serialise ``items`` to the subject template manifest as UTF-8 JSON.

    The manifest's parent directory is created if needed. The document has a
    single ``templates`` key holding each item's ``model_dump()``.
    """
    manifest = SUBJECT_TEMPLATE_MANIFEST
    manifest.parent.mkdir(parents=True, exist_ok=True)
    dumped = [entry.model_dump() for entry in items]
    document = json.dumps({"templates": dumped}, ensure_ascii=False, indent=2)
    manifest.write_text(document, encoding="utf-8")
|
||
|
||
|
||
def find_subject_template_item(template_id: str) -> SubjectTemplateItem:
    """Return the subject template whose id equals the stripped ``template_id``.

    Raises HTTP 404 when no template matches.
    """
    wanted = template_id.strip()
    found = next(
        (entry for entry in load_subject_template_items() if entry.id == wanted),
        None,
    )
    if found is None:
        raise HTTPException(404, "subject template not found")
    return found
|
||
|
||
|
||
def subject_template_image_file(filename: str) -> Path:
    """Resolve ``filename`` inside the subject template image directory.

    Raises:
        HTTPException(400): the resolved path escapes ``SUBJECT_TEMPLATE_IMAGE_DIR``.
        HTTPException(404): the resolved path is not an existing regular file.
    """
    p = (SUBJECT_TEMPLATE_IMAGE_DIR / filename).resolve()
    try:
        p.relative_to(SUBJECT_TEMPLATE_IMAGE_DIR)
    except ValueError:
        raise HTTPException(400, "invalid subject template image path") from None
    # is_file() rather than exists(): a name that resolves to the image directory
    # (or any sub-directory) exists but is not a servable image.
    if not p.is_file():
        raise HTTPException(404, "subject template image missing")
    return p
|
||
|
||
|
||
def storyboard_ref_url(job_id: str, ref: dict | None) -> str:
    """Translate a storyboard reference dict into the URL of the image it points to.

    Supported ``kind`` values are ``keyframe``, ``cutout`` and ``asset``. An empty
    string is returned for a missing or empty ``ref``, an unknown kind, or a
    reference lacking the fields its kind requires.
    """
    if not ref:
        return ""

    kind = ref.get("kind")
    frame_idx = ref.get("frame_idx")
    element_id = ref.get("element_id")

    if kind == "keyframe":
        if frame_idx is None:
            return ""
        return f"/jobs/{job_id}/frames/{int(frame_idx)}.jpg"

    if kind == "cutout":
        if frame_idx is None or not element_id:
            return ""
        cutout_id = ref.get("cutout_id")
        # A distinct cutout id addresses one specific cutout of the element;
        # otherwise the element's default cutout is used.
        if cutout_id and cutout_id != element_id:
            return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutouts/{cutout_id}.jpg"
        return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutout.jpg"

    if kind == "asset" and element_id:
        return f"/jobs/{job_id}/assets/{ref.get('element_id')}.jpg"

    return ""
|
||
|
||
|
||
def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None:
    """Letterbox ``src`` onto a ``size`` canvas and save it to ``dst`` as JPEG.

    The image is shrunk to fit inside ``size`` (``thumbnail`` never enlarges),
    centred on a near-black canvas and written at quality 94. The parent
    directory of ``dst`` is created if needed.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Context manager releases the source file handle even if decoding fails.
    with Image.open(src) as raw:
        img = raw.convert("RGB")
    img.thumbnail(size, Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", size, (8, 8, 10))
    x = (size[0] - img.width) // 2
    y = (size[1] - img.height) // 2
    canvas.paste(img, (x, y))
    canvas.save(dst, "JPEG", quality=94)
|
||
|
||
|
||
def update_generated_video(job_id: str, video_id: str, **kw) -> None:
    """Apply field overrides ``kw`` to one generated video of a job.

    Does nothing when the job is unknown. The matching video is rebuilt as a new
    ``GeneratedVideo`` from its dumped fields plus ``kw``; the others are kept
    as-is. The resulting list is stored through ``update``.
    """
    job = JOBS.get(job_id)
    if not job:
        return
    refreshed = [
        GeneratedVideo(**{**video.model_dump(), **kw}) if video.id == video_id else video
        for video in job.generated_videos
    ]
    update(job, generated_videos=refreshed)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Restore persisted jobs from ``JOBS_DIR`` at startup, then yield to the app.

    Every sub-directory holding a ``state.json`` is loaded into ``JOBS``. Jobs
    that were mid-flight when the process stopped are moved back to the nearest
    restartable status:

    * ``created`` / ``downloading`` -> ``downloaded`` if ``source.mp4`` exists,
      otherwise ``failed``.
    * ``splitting`` -> ``frames_extracted`` if frames exist, else ``downloaded``.
    * ``transcribing`` -> ``frames_extracted``; an audio script left in
      ``rewriting`` is marked ``failed``.

    A job that cannot be restored is skipped and the reason is printed, so one
    broken state file does not stop the service from starting.
    """
    # Restore jobs from disk on startup (simplified: only scans the directory).
    for p in JOBS_DIR.iterdir():
        if not (p.is_dir() and (p / "state.json").exists()):
            continue
        try:
            job = Job.model_validate_json((p / "state.json").read_text())
            source_exists = (p / "source.mp4").exists()
            if job.status in {"created", "downloading"}:
                if source_exists:
                    update(job, status="downloaded", progress=25, error="", message="服务重启 · 视频已恢复,可重新解析")
                else:
                    update(job, status="failed", message="服务重启 · 下载任务已中断,请重新提交")
            elif job.status == "splitting":
                update(
                    job,
                    status="frames_extracted" if job.frames else "downloaded",
                    progress=70 if job.frames else 25,
                    error="",
                    message="服务重启 · 上次抽帧已中断,可重新抽帧",
                )
            elif job.status == "transcribing":
                audio_script = job.audio_script
                if audio_script.status == "rewriting":
                    audio_script = audio_script.model_copy(update={
                        "status": "failed",
                        "error": "服务重启 · 上次音频改写/配音已中断,可重新处理",
                        "created_at": audio_script.created_at or time.time(),
                    })
                update(
                    job,
                    status="frames_extracted",
                    progress=70,
                    error="",
                    audio_script=audio_script,
                    message="服务重启 · 上次音频处理已中断,可重新处理",
                )
            JOBS[p.name] = job
        except Exception as e:
            # Best-effort restore: skip the job but leave a trace of why.
            print(f"[job restore skipped] {p.name}: {e}", flush=True)
    yield
|
||
|
||
|
||
# Application instance; `lifespan` runs the job-restore pass before requests are served.
app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan)
# CORS: credentialed requests are accepted from the origins listed in CORS_ORIGINS,
# with any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.get("/auth/check")
def auth_check(request: Request) -> Response:
    """Answer 204 when the request carries a valid auth identity, else HTTP 401."""
    ensure_auth_configured()
    username = auth_username_from_request(request)
    if username:
        return Response(status_code=204)
    raise HTTPException(401, "unauthorized")
|
||
|
||
|
||
@app.post("/auth/login")
def auth_login(payload: AuthLoginPayload, response: Response) -> dict:
    """Validate the submitted credentials and set the auth cookie.

    The username is stripped before comparison; the password is used as sent.
    On success a signed token cookie is set with a 30-day lifetime when
    ``payload.remember`` is truthy, otherwise 12 hours.

    Raises:
        HTTPException(401): username or password does not match.
    """
    ensure_auth_configured()

    def _as_bytes(value: str) -> bytes:
        # compare_digest() only accepts str arguments that are pure ASCII;
        # comparing UTF-8 bytes keeps non-ASCII input from raising TypeError
        # (which would surface as a 500 instead of a 401).
        return value.encode("utf-8", "surrogatepass")

    username = payload.username.strip()
    password = payload.password
    # Both comparisons always run so a wrong username is not rejected faster.
    valid_user = hmac.compare_digest(_as_bytes(username), _as_bytes(WEB_AUTH_USERNAME))
    valid_password = hmac.compare_digest(_as_bytes(password), _as_bytes(WEB_AUTH_PASSWORD))
    if not (valid_user and valid_password):
        raise HTTPException(401, "用户名或密码不正确")

    ttl_seconds = 60 * 60 * 24 * 30 if payload.remember else 60 * 60 * 12
    response.set_cookie(
        key=WEB_AUTH_COOKIE_NAME,
        value=make_auth_token(WEB_AUTH_USERNAME, ttl_seconds),
        max_age=ttl_seconds,
        httponly=True,
        secure=WEB_AUTH_COOKIE_SECURE,
        samesite="lax",
        path="/",
    )
    return {"ok": True, "username": WEB_AUTH_USERNAME}
|
||
|
||
|
||
@app.post("/auth/logout")
def auth_logout(response: Response) -> dict:
    """Clear the auth cookie and confirm with ``{"ok": True}``."""
    cookie_options = {
        "key": WEB_AUTH_COOKIE_NAME,
        "path": "/",
        "secure": WEB_AUTH_COOKIE_SECURE,
        "samesite": "lax",
    }
    response.delete_cookie(**cookie_options)
    return {"ok": True}
|
||
|
||
|
||
# ---------- Pipeline 实现 ----------
|
||
|
||
def _binary_works(path: str) -> bool:
|
||
if not path:
|
||
return False
|
||
if os.path.sep in path and not Path(path).exists():
|
||
return False
|
||
try:
|
||
res = subprocess.run([path, "-version"], capture_output=True, text=True, timeout=5)
|
||
return res.returncode == 0
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
def media_binary(name: Literal["ffmpeg", "ffprobe"]) -> str:
    """Locate a working ``ffmpeg`` / ``ffprobe`` executable and cache the result.

    Candidates are tried in order: the configured ``*_BIN`` value, whatever
    ``shutil.which`` finds, and (for ffmpeg only) ``LOCAL_FFMPEG_CANDIDATES``.
    The first one passing ``_binary_works`` is cached in ``_MEDIA_BIN_CACHE``.

    Raises:
        RuntimeError: no candidate works.
    """
    hit = _MEDIA_BIN_CACHE.get(name)
    if hit:
        return hit

    is_ffmpeg = name == "ffmpeg"
    configured = FFMPEG_BIN if is_ffmpeg else FFPROBE_BIN
    on_path = shutil.which(name)
    search_order: list[str] = [c for c in (configured, on_path) if c]
    if is_ffmpeg:
        search_order.extend(LOCAL_FFMPEG_CANDIDATES)

    working = next((c for c in search_order if _binary_works(c)), None)
    if working is None:
        raise RuntimeError(f"{name} 不可用,请配置 {name.upper()}_BIN 或修复本机 ffmpeg 安装")
    _MEDIA_BIN_CACHE[name] = working
    return working
|
||
|
||
|
||
def _normalize_media_cmd(cmd: list[str]) -> list[str]:
|
||
if not cmd:
|
||
return cmd
|
||
if cmd[0] == "ffmpeg":
|
||
return [media_binary("ffmpeg"), *cmd[1:]]
|
||
if cmd[0] == "ffprobe":
|
||
return [media_binary("ffprobe"), *cmd[1:]]
|
||
return cmd
|
||
|
||
|
||
def run(cmd: list[str], cwd: Path | None = None) -> str:
    """Run ``cmd`` and return its stdout.

    A leading ``ffmpeg`` / ``ffprobe`` is first replaced by the resolved binary.

    Raises:
        RuntimeError: the process exits non-zero. The message carries the first
            three argv items and the tail of stderr.
    """
    argv = _normalize_media_cmd(cmd)
    proc = subprocess.run(argv, cwd=cwd, capture_output=True, text=True)
    if proc.returncode == 0:
        return proc.stdout
    # ffmpeg prints its banner to stderr; the real error is usually in the last lines.
    tail = "\n".join(proc.stderr.splitlines()[-12:]) or proc.stderr[-800:]
    raise RuntimeError(f"cmd failed: {' '.join(argv[:3])}... · {tail}")
|
||
|
||
|
||
def ytdlp_cookie_args() -> list[str]:
    """Build the yt-dlp cookie arguments from configuration.

    A configured cookies file takes precedence over the browser setting.
    Returns an empty list when neither is configured.

    Raises:
        RuntimeError: a cookies file is configured but does not exist.
    """
    if YTDLP_COOKIES_FILE:
        cookie_path = Path(YTDLP_COOKIES_FILE).expanduser()
        if cookie_path.exists():
            return ["--cookies", str(cookie_path)]
        raise RuntimeError("TikTok cookies 文件不可用,请检查 YTDLP_COOKIES_FILE 配置。")
    if YTDLP_COOKIES_FROM_BROWSER:
        return ["--cookies-from-browser", YTDLP_COOKIES_FROM_BROWSER]
    return []
|
||
|
||
|
||
def normalize_download_error(error: Exception) -> str:
    """Turn a download exception into a user-facing message.

    When the error text looks like a login / cookie requirement, a hint about
    uploading the file or configuring yt-dlp cookies is prepended to the raw
    text. Otherwise the raw text is returned unchanged.
    """
    raw = str(error)
    lowered = raw.lower()
    login_markers = (
        "log in for access" in lowered,
        ("login" in lowered) and ("cookies" in lowered),
        "cookies-from-browser" in lowered,
        ("sign in" in lowered) and ("tiktok" in lowered),
    )
    if not any(login_markers):
        return raw
    return (
        "TikTok 下载需要登录态。请上传视频文件,或在后端配置 "
        "YTDLP_COOKIES_FILE / YTDLP_COOKIES_FROM_BROWSER 后重试。"
        f"原始错误:{raw}"
    )
|
||
|
||
|
||
# ---- 启发式选帧工具 ----
|
||
import imagehash
|
||
import numpy as np
|
||
from PIL import Image, ImageChops, ImageEnhance, ImageFilter, ImageOps
|
||
|
||
|
||
def _sharpness_from_gray(g: np.ndarray) -> float:
|
||
"""Laplacian variance:值越大越清晰,模糊/转场帧值低。"""
|
||
lap = (-4 * g[1:-1, 1:-1]
|
||
+ g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:])
|
||
return float(lap.var())
|
||
|
||
|
||
def _frame_metrics(img_path: Path, idx: int, timestamp: float, metric_width: int = 160) -> dict | None:
    """Compute local scoring features for one low-resolution candidate frame.

    The features are only used for ranking. Returns ``None`` when the image
    cannot be opened, hashed or resized. ``scene_score`` and ``motion`` are
    initialised to 0.0 here.
    """
    try:
        with Image.open(img_path) as raw:
            rgb = raw.convert("RGB")
            phash = imagehash.phash(rgb)
            src_w, src_h = rgb.size
            metric_height = max(1, round(metric_width * src_h / max(src_w, 1)))
            small = rgb.resize((metric_width, metric_height))
    except Exception:
        return None

    arr = np.asarray(small, dtype=np.float32)
    red = arr[:, :, 0]
    green = arr[:, :, 1]
    blue = arr[:, :, 2]
    # Rec. 601 luma kept in the 0-255 range so it lines up with the sharpness
    # and contrast thresholds used elsewhere.
    gray = (0.299 * red + 0.587 * green + 0.114 * blue).astype(np.float32)
    gh, gw = gray.shape
    # Central half of the frame; the max() keeps the slice at least one pixel wide.
    row_start, col_start = gh // 4, gw // 4
    row_stop = max(row_start + 1, gh * 3 // 4)
    col_stop = max(col_start + 1, gw * 3 // 4)
    center = gray[row_start:row_stop, col_start:col_stop]

    rg = red - green
    yb = 0.5 * (red + green) - blue
    colorfulness = float(np.sqrt(rg.var() + yb.var()) + 0.3 * np.sqrt(rg.mean() ** 2 + yb.mean() ** 2))

    metrics = {
        "path": img_path,
        "idx": idx,
        "timestamp": timestamp,
        "hash": phash,
        "gray": gray,
        "sharp": _sharpness_from_gray(gray),
        "center_sharp": _sharpness_from_gray(center),
        "brightness": float(gray.mean()),
        "contrast": float(gray.std()),
        "colorfulness": colorfulness,
        "scene_score": 0.0,
        "motion": 0.0,
    }
    return metrics
|
||
|
||
|
||
def _physical_memory_gb() -> float:
|
||
try:
|
||
page_size = os.sysconf("SC_PAGE_SIZE")
|
||
pages = os.sysconf("SC_PHYS_PAGES")
|
||
return float(page_size * pages) / (1024 ** 3)
|
||
except Exception:
|
||
return 0.0
|
||
|
||
|
||
def _resolve_frame_quality(duration: float, quality: FrameExtractQuality) -> FrameExtractQuality:
|
||
if quality != "auto":
|
||
return quality
|
||
cores = os.cpu_count() or 4
|
||
memory_gb = _physical_memory_gb()
|
||
strong_machine = cores >= 10 and (memory_gb == 0.0 or memory_gb >= 32)
|
||
# 展示/演示时不能把本机资源打满:auto 最高只到 accurate。
|
||
# ultra 保留为手动选择项,不再由 auto 自动命中。
|
||
if strong_machine and duration <= 600:
|
||
return "accurate"
|
||
if cores >= 8 and duration <= 240:
|
||
return "accurate"
|
||
return "fast"
|
||
|
||
|
||
def _scan_profile(duration: float, quality: FrameExtractQuality) -> tuple[float, int, int, int]:
|
||
"""返回 scan_fps / scan_width / metric_width / estimated_count。"""
|
||
if quality == "ultra":
|
||
base_fps, scan_width, cap, metric_width = 12.0, 960, 1800, 320
|
||
elif quality == "accurate":
|
||
base_fps, scan_width, cap, metric_width = 8.0, 720, 900, 240
|
||
else:
|
||
base_fps, scan_width, cap, metric_width = 2.0, 360, 240, 160
|
||
|
||
estimated = max(1, min(int(duration * base_fps), cap))
|
||
scan_fps = max(0.02, min(base_fps, estimated / max(duration, 0.1)))
|
||
return scan_fps, scan_width, metric_width, estimated
|
||
|
||
|
||
def _image_quality_report(img_path: Path, region: dict | None = None) -> QualityReport:
    """Build a quality report (size, sharpness, risk, warnings) for an image.

    Sharpness is measured on a copy at most 512 px wide. When ``region`` is
    given, its relative ``w`` / ``h`` are converted to pixels to warn about
    small subjects. An unreadable image yields a ``bad`` report straight away.
    """
    try:
        with Image.open(img_path) as raw:
            rgb = raw.convert("RGB")
            width, height = rgb.size
            probe_w = min(512, width)
            probe_h = max(1, round(probe_w * height / max(width, 1)))
            probe = rgb.resize((probe_w, probe_h))
            gray = np.asarray(ImageOps.grayscale(probe), dtype=np.float32)
            sharp = _sharpness_from_gray(gray)
    except Exception:
        return QualityReport(risk="bad", warnings=["无法读取图片质量信息"])

    notes: list[str] = []
    short_side = min(width, height)
    if short_side < 720:
        notes.append(f"短边 {short_side}px 低于 720px,生视频可能偏糊")
    if sharp < 30:
        notes.append("清晰度偏低,高清增强后仍可能有细节损失")

    if region:
        # A malformed region must not fail the report, so errors are ignored.
        try:
            rw = int(float(region.get("w", 0)) * width)
            rh = int(float(region.get("h", 0)) * height)
            if min(rw, rh) < 512:
                notes.append(f"主体框约 {rw}×{rh}px,主体素材偏小")
        except Exception:
            pass

    risk: Literal["ok", "warn", "bad"]
    if short_side < 480 or sharp < 12:
        risk = "bad"
    elif any("低于" in note or "偏小" in note for note in notes):
        # Only the size-related warnings raise the risk to "warn".
        risk = "warn"
    else:
        risk = "ok"
    return QualityReport(width=width, height=height, short_side=short_side, sharpness=round(sharp, 2), risk=risk, warnings=notes)
|
||
|
||
|
||
def _asset_target_size(source_path: Path, size: AssetSize, square: bool = False) -> tuple[int, int]:
|
||
try:
|
||
with Image.open(source_path) as raw:
|
||
src_w, src_h = raw.size
|
||
except Exception:
|
||
src_w, src_h = 1024, 1024
|
||
if size == "source":
|
||
return max(1, src_w), max(1, src_h)
|
||
side = int(size)
|
||
if square:
|
||
return side, side
|
||
if src_w >= src_h:
|
||
return side, max(1, round(side * src_h / max(src_w, 1)))
|
||
return max(1, round(side * src_w / max(src_h, 1))), side
|
||
|
||
|
||
def _normalize_asset_image(
    img_bytes: bytes,
    out_path: Path,
    source_path: Path,
    size: AssetSize,
    background: AssetBackground = "white",
    square: bool = False,
    fill_subject: bool = False,
) -> tuple[int, int]:
    """Fit raw image bytes onto a solid canvas of the target size and save as JPEG.

    The target size comes from ``_asset_target_size``. With ``fill_subject`` the
    image is first cropped to the region that differs from the background colour
    (plus 6% padding) and scaled to at most 92% x 94% of the canvas; otherwise it
    is simply shrunk to fit. Returns the canvas ``(width, height)``.
    """
    import io as _io

    target_w, target_h = _asset_target_size(source_path, size, square=square)
    bg = (255, 255, 255) if background == "white" else (0, 0, 0)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    with Image.open(_io.BytesIO(img_bytes)) as raw:
        img = raw.convert("RGB")
        if fill_subject:
            # Pixels more than 18 levels away from the background count as subject.
            delta = ImageChops.difference(img, Image.new("RGB", img.size, bg))
            subject_mask = delta.convert("L").point(lambda px: 255 if px > 18 else 0)
            bbox = subject_mask.getbbox()
            if bbox:
                left, top, right, bottom = bbox
                pad_x = round((right - left) * 0.06)
                pad_y = round((bottom - top) * 0.06)
                crop_box = (
                    max(0, left - pad_x),
                    max(0, top - pad_y),
                    min(img.width, right + pad_x),
                    min(img.height, bottom + pad_y),
                )
                img = img.crop(crop_box)
            fit_box = (max(1, round(target_w * 0.92)), max(1, round(target_h * 0.94)))
        else:
            fit_box = (target_w, target_h)
        img.thumbnail(fit_box, Image.Resampling.LANCZOS)

        canvas = Image.new("RGB", (target_w, target_h), bg)
        offset = ((target_w - img.width) // 2, (target_h - img.height) // 2)
        canvas.paste(img, offset)
        canvas.save(out_path, "JPEG", quality=95)
    return target_w, target_h
|
||
|
||
|
||
def _asset_url(job_id: str, asset_id: str) -> str:
|
||
return f"/jobs/{job_id}/assets/{asset_id}.jpg"
|
||
|
||
|
||
def _delete_subject_asset_file(job_id: str, asset_id: str) -> None:
    """Delete a job's asset JPEG if it exists; OS errors while deleting are ignored."""
    if not asset_id:
        return
    target = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
    if not target.exists():
        return
    try:
        target.unlink()
    except OSError:
        # Best-effort cleanup: a file that cannot be removed is left in place.
        pass
|
||
|
||
|
||
def _find_frame(job: Job, idx: int) -> KeyFrame:
|
||
frame = next((f for f in job.frames if f.index == idx), None)
|
||
if not frame:
|
||
raise HTTPException(404, "frame not found")
|
||
return frame
|
||
|
||
|
||
def _source_frame_path(job_id: str, idx: int) -> Path:
    """Return the path of frame ``idx``, preferring the cleaned copy when present.

    The ``frames/`` path is returned as the fallback without checking that it exists.
    """
    filename = f"{idx:03d}.jpg"
    cleaned = job_dir(job_id) / "cleaned" / filename
    if cleaned.exists():
        return cleaned
    return job_dir(job_id) / "frames" / filename
|
||
|
||
|
||
def _focus_source_for_element(job_id: str, idx: int, el: KeyElement) -> tuple[Path, Path | None]:
    """Pick the image to send to the model for one element of a frame.

    When the element has a region, the frame is cropped to that region enlarged
    1.6x around its centre (clamped to the frame) and written to a temporary
    JPEG.

    Returns:
        ``(model_src, tmp_focus)``. ``tmp_focus`` is the temporary crop that the
        caller should delete afterwards, or ``None`` when the full frame is used
        (no region, a crop of 8 px or less on a side, or a failed crop).
    """
    import tempfile as _tempfile

    src = _source_frame_path(job_id, idx)
    if not el.region:
        return src, None

    tmp_focus: Path | None = None
    try:
        # Context manager closes the frame file even when decoding fails.
        with Image.open(src) as raw:
            im = raw.convert("RGB")
        W, H = im.size
        r = el.region
        x = max(0.0, min(1.0, float(r.get("x", 0))))
        y = max(0.0, min(1.0, float(r.get("y", 0))))
        w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
        h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
        cx, cy = x + w / 2, y + h / 2
        ew, eh = w * 1.6, h * 1.6
        x0 = max(0.0, cx - ew / 2)
        y0 = max(0.0, cy - eh / 2)
        x1 = min(1.0, cx + ew / 2)
        y1 = min(1.0, cy + eh / 2)
        left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
        if right - left > 8 and bottom - top > 8:
            cropped = im.crop((left, top, right, bottom))
            # Reserve the name, then close the handle before saving: writing by
            # name to a still-open temporary file fails on some platforms.
            with _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_focus = Path(tmp.name)
            cropped.save(tmp_focus, format="JPEG", quality=92)
            return tmp_focus, tmp_focus
    except Exception as e:
        print(f"[focus source crop failed, fallback to full frame] {e}", flush=True)
        # Do not leave a half-written temporary file behind.
        if tmp_focus is not None:
            try:
                tmp_focus.unlink()
            except OSError:
                pass
    return src, None
|
||
|
||
|
||
def _make_reference_contact_sheet(job_id: str, frame_indices: list[int], out_path: Path, max_items: int = 6) -> Path | None:
    """Tile several frames of a job into one contact-sheet JPEG.

    Duplicate indices and missing frame files are skipped. ``max_items`` is
    clamped to 2..12. Each frame is letterboxed into a 420x420 tile; the grid
    uses 2, 3 or 4 columns depending on the tile count.

    Returns:
        ``out_path`` after writing the sheet, or ``None`` when fewer than two
        frames are available or readable.
    """
    tile = 420
    paths: list[Path] = []
    seen: set[int] = set()
    max_items = max(2, min(12, int(max_items or 6)))
    for idx in frame_indices:
        if idx in seen:
            continue
        seen.add(idx)
        p = _source_frame_path(job_id, idx)
        if p.exists():
            paths.append(p)
        if len(paths) >= max_items:
            break
    if len(paths) <= 1:
        return None

    thumbs: list[Image.Image] = []
    for p in paths:
        try:
            # Context manager releases the file handle for every frame opened.
            with Image.open(p) as raw:
                im = raw.convert("RGB")
            im.thumbnail((tile, tile), Image.Resampling.LANCZOS)
            canvas = Image.new("RGB", (tile, tile), (245, 245, 245))
            canvas.paste(im, ((tile - im.width) // 2, (tile - im.height) // 2))
            thumbs.append(canvas)
        except Exception:
            # Unreadable frames are left out of the sheet.
            continue
    if len(thumbs) <= 1:
        return None

    cols = 4 if len(thumbs) > 6 else (3 if len(thumbs) > 2 else 2)
    rows = (len(thumbs) + cols - 1) // cols
    sheet = Image.new("RGB", (cols * tile, rows * tile), (245, 245, 245))
    for i, thumb in enumerate(thumbs):
        sheet.paste(thumb, ((i % cols) * tile, (i // cols) * tile))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sheet.save(out_path, "JPEG", quality=92)
    return out_path
|
||
|
||
|
||
def _make_paths_contact_sheet(paths: list[Path], out_path: Path, max_items: int = 10) -> Path | None:
    """Tile several image files into one contact-sheet JPEG.

    Duplicate and missing paths are skipped. ``max_items`` is clamped to 2..12.
    Each image is letterboxed into a 420x420 tile; the grid uses 2, 3 or 4
    columns depending on the tile count.

    Returns:
        ``out_path`` after writing the sheet. With fewer than two usable images
        no sheet is written: the single readable image is returned (falling back
        to the first existing path), or ``None`` when there is nothing at all.
    """
    tile = 420
    usable: list[Path] = []
    seen: set[str] = set()
    max_items = max(2, min(12, int(max_items or 10)))
    for p in paths:
        key = str(p)
        if key in seen or not p.exists():
            continue
        seen.add(key)
        usable.append(p)
        if len(usable) >= max_items:
            break
    if len(usable) <= 1:
        return usable[0] if usable else None

    thumbs: list[Image.Image] = []
    readable: list[Path] = []
    for p in usable:
        try:
            # Context manager releases the file handle for every image opened.
            with Image.open(p) as raw:
                im = raw.convert("RGB")
            im.thumbnail((tile, tile), Image.Resampling.LANCZOS)
            canvas = Image.new("RGB", (tile, tile), (245, 245, 245))
            canvas.paste(im, ((tile - im.width) // 2, (tile - im.height) // 2))
            thumbs.append(canvas)
            readable.append(p)
        except Exception:
            # Unreadable images are left out of the sheet.
            continue
    if len(thumbs) <= 1:
        # Prefer the image that actually decoded; usable[0] may be the broken one.
        return readable[0] if readable else usable[0]

    cols = 4 if len(thumbs) > 6 else (3 if len(thumbs) > 2 else 2)
    rows = (len(thumbs) + cols - 1) // cols
    sheet = Image.new("RGB", (cols * tile, rows * tile), (245, 245, 245))
    for i, thumb in enumerate(thumbs):
        sheet.paste(thumb, ((i % cols) * tile, (i // cols) * tile))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sheet.save(out_path, "JPEG", quality=92)
    return out_path
|
||
|
||
|
||
# Display label for each known subject view key. `_subject_view_labels` looks up
# caller-requested keys here and falls back to the key with underscores replaced
# by spaces when a key is not listed.
SUBJECT_VIEW_LABELS: dict[str, str] = {
    # Viewing angles
    "front": "正面",
    "back": "背面",
    "left": "左侧",
    "right": "右侧",
    "three_quarter_left": "左前 45°",
    "three_quarter_right": "右前 45°",
    "side": "侧面",
    "side_walk": "侧面走路",
    "top": "顶部视角",
    "bottom": "底部视角",
    # Facial expressions
    "expression_neutral": "中性表情",
    "expression_smile": "微笑表情",
    "expression_happy": "开心表情",
    "expression_angry": "生气表情",
    "expression_sad": "难过表情",
    "expression_relaxed": "放松表情",
    "expression_serious": "严肃表情",
    "expression_surprised": "惊讶表情",
    # Actions
    "action_walk": "走路动作",
    "action_turn": "转身动作",
    "action_sit": "坐姿动作",
    "action_hold": "手持动作",
    "action_use": "使用动作",
    # Bust / neck / back close-ups
    "bust_front": "肩颈半身正面近景",
    "bust_left_45": "肩颈左前 45° 近景",
    "bust_right_45": "肩颈右前 45° 近景",
    "back_neck_detail": "后颈/肩背特写",
    "bust": "半身近景",
    "back_detail": "背部特写",
}
|
||
|
||
|
||
def _subject_view_labels(kind: SubjectKind, requested: list[str] | None = None) -> list[tuple[SubjectView, str]]:
|
||
if requested:
|
||
normalized: list[str] = []
|
||
for raw in requested:
|
||
key = "".join(ch for ch in str(raw).strip().lower() if ch.isalnum() or ch == "_")
|
||
if key and key not in normalized:
|
||
normalized.append(key)
|
||
return [(key, SUBJECT_VIEW_LABELS.get(key, key.replace("_", " "))) for key in normalized[:10]]
|
||
if kind == "living":
|
||
return [
|
||
("front", "正面站立"),
|
||
("three_quarter_left", "左前 45° 站立"),
|
||
("left", "左侧站立"),
|
||
("back", "背面站立"),
|
||
("right", "右侧站立"),
|
||
("three_quarter_right", "右前 45° 站立"),
|
||
("bust_front", "肩颈半身正面近景"),
|
||
("bust_left_45", "肩颈左前 45° 近景"),
|
||
("bust_right_45", "肩颈右前 45° 近景"),
|
||
("back_neck_detail", "后颈/肩背特写"),
|
||
]
|
||
return [
|
||
("front", "正面"),
|
||
("back", "背面"),
|
||
("left", "左侧"),
|
||
("right", "右侧"),
|
||
("top", "顶部"),
|
||
("bottom", "底部"),
|
||
]
|
||
|
||
|
||
def _attach_temporal_metrics(items: list[dict]) -> None:
|
||
"""相邻低清帧差异:转场 / 动作目标依赖它,不需要逐帧高分辨率扫描。"""
|
||
for i, it in enumerate(items):
|
||
prev_delta = 0.0
|
||
next_delta = 0.0
|
||
if i > 0:
|
||
prev_delta = float(np.mean(np.abs(it["gray"] - items[i - 1]["gray"])) / 255.0)
|
||
if i + 1 < len(items):
|
||
next_delta = float(np.mean(np.abs(items[i + 1]["gray"] - it["gray"])) / 255.0)
|
||
it["scene_score"] = max(prev_delta, next_delta)
|
||
it["motion"] = (prev_delta + next_delta) / 2.0
|
||
|
||
|
||
def _normalize_item_metrics(items: list[dict]) -> None:
|
||
for key in ("sharp", "center_sharp", "contrast", "colorfulness", "scene_score", "motion"):
|
||
vals = [float(it.get(key, 0.0)) for it in items if float(it.get(key, 0.0)) > 0]
|
||
cap = float(np.percentile(vals, 95)) if vals else 1.0
|
||
if cap <= 0:
|
||
cap = 1.0
|
||
for it in items:
|
||
it[f"{key}_n"] = min(float(it.get(key, 0.0)) / cap, 1.0)
|
||
|
||
|
||
def _target_score(item: dict, target: FrameExtractTarget) -> float:
|
||
sharp = float(item.get("sharp_n", 0.0))
|
||
center = float(item.get("center_sharp_n", 0.0))
|
||
contrast = float(item.get("contrast_n", 0.0))
|
||
color = float(item.get("colorfulness_n", 0.0))
|
||
scene = float(item.get("scene_score_n", 0.0))
|
||
motion = float(item.get("motion_n", 0.0))
|
||
|
||
if target == "transparent_human":
|
||
# 当前抽帧阶段走本地算力:优先清晰中心主体、高对比、适度色彩和时间覆盖。
|
||
# 透明骨架人的语义判断留给后续审核/识别,不在抽帧阶段逐帧调用 Vision。
|
||
score = center * 0.45 + sharp * 0.30 + contrast * 0.15 + color * 0.10
|
||
elif target == "subject":
|
||
score = center * 0.48 + sharp * 0.25 + contrast * 0.17 + color * 0.10
|
||
elif target == "transition":
|
||
score = scene * 0.55 + sharp * 0.28 + contrast * 0.12 + color * 0.05
|
||
elif target == "expression":
|
||
# 没有额外视觉模型时,表情/动物瞬间只能用中心细节 + 清晰 + 轻微动作变化做本地近似。
|
||
score = center * 0.40 + sharp * 0.24 + motion * 0.18 + contrast * 0.12 + color * 0.06
|
||
elif target == "motion":
|
||
score = motion * 0.45 + sharp * 0.30 + center * 0.15 + contrast * 0.10
|
||
else:
|
||
score = sharp * 0.45 + scene * 0.22 + center * 0.15 + contrast * 0.12 + color * 0.06
|
||
|
||
brightness = float(item.get("brightness", 0.0))
|
||
raw_contrast = float(item.get("contrast", 0.0))
|
||
if raw_contrast < 4 or brightness < 8 or brightness > 247:
|
||
return score * 0.15
|
||
if raw_contrast < 9:
|
||
return score * 0.65
|
||
return score
|
||
|
||
|
||
def _select_keyframes(candidates: list[dict], n: int, target: FrameExtractTarget, dup_threshold: int = 8) -> list[dict]:
|
||
"""
|
||
candidates: 按时间排序的低清候选帧评分项
|
||
n: 目标帧数
|
||
dup_threshold: pHash 汉明距离 < 此值视为相似(默认 8,64bit hash 大致 ~12.5% 像素差)
|
||
"""
|
||
if len(candidates) <= n:
|
||
return candidates
|
||
|
||
_attach_temporal_metrics(candidates)
|
||
_normalize_item_metrics(candidates)
|
||
for it in candidates:
|
||
it["score"] = _target_score(it, target)
|
||
|
||
# 去重:相似帧保留当前目标下分数更高的
|
||
deduped: list[dict] = []
|
||
for it in candidates:
|
||
dup = None
|
||
for kept in deduped:
|
||
if (it["hash"] - kept["hash"]) < dup_threshold:
|
||
dup = kept
|
||
break
|
||
if dup is None:
|
||
deduped.append(it)
|
||
elif it["score"] > dup["score"]:
|
||
deduped[deduped.index(dup)] = it
|
||
|
||
# 时序分桶:把候选时间轴等分 n 段,每段取当前目标下最优的
|
||
total = len(candidates)
|
||
buckets: list[list[dict]] = [[] for _ in range(n)]
|
||
for it in deduped:
|
||
b = min(int(it["idx"] * n / total), n - 1)
|
||
buckets[b].append(it)
|
||
|
||
selected: list[dict] = []
|
||
for b in buckets:
|
||
if b:
|
||
selected.append(max(b, key=lambda x: x["score"]))
|
||
|
||
# 空桶补足:从未选的 deduped 里按目标分数补
|
||
chosen_paths = {it["path"] for it in selected}
|
||
remaining = sorted([it for it in deduped if it["path"] not in chosen_paths],
|
||
key=lambda x: -x["score"])
|
||
while len(selected) < n and remaining:
|
||
selected.append(remaining.pop(0))
|
||
|
||
# 按时间排序输出
|
||
selected.sort(key=lambda x: x["idx"])
|
||
return selected
|
||
|
||
|
||
def _rank_keyframe_candidates(candidates: list[dict], target: FrameExtractTarget, limit: int, dup_threshold: int = 8) -> list[dict]:
|
||
if not candidates:
|
||
return []
|
||
_attach_temporal_metrics(candidates)
|
||
_normalize_item_metrics(candidates)
|
||
for it in candidates:
|
||
it["score"] = _target_score(it, target)
|
||
deduped: list[dict] = []
|
||
for it in sorted(candidates, key=lambda x: -float(x.get("score", 0.0))):
|
||
if any((it["hash"] - kept["hash"]) < dup_threshold for kept in deduped):
|
||
continue
|
||
deduped.append(it)
|
||
if len(deduped) >= limit:
|
||
break
|
||
return deduped
|
||
|
||
|
||
def _score_transparent_human_frame(img_path: Path) -> TransparentHumanFrameScore:
    """Ask the vision model to grade one frame as a transparent-human keyframe.

    The model returns six sub-scores which are clamped to their maximums and
    summed into ``total_score``. A frame is ``qualified`` only when the model
    says so AND every sub-score and the total reach their minimum thresholds.

    Returns an unqualified score with a ``reject_reason`` when ``LLM_API_KEY``
    is not set, the request fails, or the reply is not a JSON object.
    """
    if not LLM_API_KEY:
        return TransparentHumanFrameScore(
            qualified=False,
            reject_reason="LLM_API_KEY 未配置,无法进行透明骨架人语义验收",
        )
    img_b64 = base64.b64encode(img_path.read_bytes()).decode("ascii")
    prompt = (
        "You are a strict keyframe quality inspector for a SKG transparent-human video recreation workflow. "
        + TRANSPARENT_HUMAN_POSITIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_QUALIFIED_STANDARD + "\n\n"
        "Score this single frame using exactly these dimensions:\n"
        "- transparent_body_score: 0-25, clear transparent/translucent outer human body shell.\n"
        "- skeleton_visible_score: 0-25, clean white skeleton clearly visible inside the body.\n"
        "- human_prominence_score: 0-15, character centered/large/easy to identify, ideally >=35% frame height.\n"
        "- clarity_score: 0-15, no severe motion blur, occlusion, or deformation.\n"
        "- commercial_style_score: 0-10, clean premium non-horror advertising/wellness style.\n"
        "- product_usefulness_score: 0-10, useful for later SKG product video generation; neck/shoulder/waist/eye/foot/knee area visible when relevant.\n"
        "Reject if any of these is true: normal human only; ordinary skeleton only; product/background only; transparent person too far; severe blur; more than half occluded; horror/corpse/autopsy/surgery/hospital; unable to judge.\n"
        "Output strict JSON only with keys: transparent_body_score, skeleton_visible_score, human_prominence_score, clarity_score, commercial_style_score, product_usefulness_score, qualified, reject_reason, notes."
    )
    try:
        resp = llm().chat.completions.create(
            model=VISION_MODEL,
            messages=[{"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
            ]}],
            response_format={"type": "json_object"},
            temperature=0.1,
            max_tokens=1200,
        )
        raw = (resp.choices[0].message.content or "").strip()
        if not raw:
            # Fall back to the reasoning channel when the content field is empty.
            raw = (getattr(resp.choices[0].message, "reasoning_content", "") or "").strip()
        # Keep only the outermost {...} span in case the model wrapped the JSON in text.
        match = re.search(r"\{[\s\S]*\}", raw)
        raw = match.group(0) if match else raw
        data = json.loads(raw)
        if not isinstance(data, dict):
            # A JSON list or scalar has no .get(); treat it as a failed scoring.
            raise ValueError("model response is not a JSON object")
    except Exception as e:
        return TransparentHumanFrameScore(qualified=False, reject_reason=f"AI 评分失败:{e}")

    def score(name: str, cap: int) -> int:
        """Read one sub-score as an int clamped to ``0..cap``; unparsable values count as 0."""
        try:
            value = int(round(float(data.get(name, 0))))
        except Exception:
            value = 0
        return max(0, min(cap, value))

    item = TransparentHumanFrameScore(
        transparent_body_score=score("transparent_body_score", 25),
        skeleton_visible_score=score("skeleton_visible_score", 25),
        human_prominence_score=score("human_prominence_score", 15),
        clarity_score=score("clarity_score", 15),
        commercial_style_score=score("commercial_style_score", 10),
        product_usefulness_score=score("product_usefulness_score", 10),
        reject_reason=str(data.get("reject_reason", "") or ""),
        notes=str(data.get("notes", "") or ""),
    )
    item.total_score = (
        item.transparent_body_score
        + item.skeleton_visible_score
        + item.human_prominence_score
        + item.clarity_score
        + item.commercial_style_score
        + item.product_usefulness_score
    )
    # The model's own verdict is necessary but not sufficient: every threshold must hold too.
    item.qualified = bool(data.get("qualified")) and (
        item.transparent_body_score >= 18
        and item.skeleton_visible_score >= 18
        and item.human_prominence_score >= 8
        and item.clarity_score >= 8
        and item.commercial_style_score >= 6
        and item.product_usefulness_score >= 4
        and item.total_score >= 72
    )
    if not item.qualified and not item.reject_reason:
        item.reject_reason = f"透明骨架人评分不足,总分 {item.total_score}/100"
    return item
|
||
|
||
|
||
def _duration_from_text(text: str) -> float:
|
||
m = re.search(r"Duration:\s*(\d+):(\d+):(\d+(?:\.\d+)?)", text)
|
||
if not m:
|
||
return 0.0
|
||
hours, minutes, seconds = m.groups()
|
||
return int(hours) * 3600 + int(minutes) * 60 + float(seconds)
|
||
|
||
|
||
def _ffmpeg_probe_text(path: Path) -> str:
    """Run ``ffmpeg -i`` on ``path`` and return its combined stdout and stderr.

    The exit code is ignored; the output is accepted when it contains ``Input #0``.

    Raises:
        RuntimeError: the output has no ``Input #0`` section. The message carries
            the last lines of the output.
    """
    probe_cmd = [media_binary("ffmpeg"), "-hide_banner", "-i", str(path)]
    proc = subprocess.run(probe_cmd, capture_output=True, text=True)
    combined = "\n".join(chunk for chunk in (proc.stdout, proc.stderr) if chunk)
    if "Input #0" in combined:
        return combined
    tail = "\n".join(combined.splitlines()[-12:])
    raise RuntimeError(f"ffmpeg 读取媒体失败:{tail}")
|
||
|
||
|
||
def _ffmpeg_meta_fallback(path: Path) -> dict:
    """Build an ffprobe-shaped metadata dict by parsing ``ffmpeg -i`` text output.

    The result has ``streams`` (one entry per `` Video:`` line that contains a
    ``WxH`` size) and ``format.duration`` (seconds, as a string).
    """
    text = _ffmpeg_probe_text(path)
    duration = _duration_from_text(text)

    # A WxH token not glued to other digits or dots (rules out values like 1.5x).
    size_pattern = r"(?<![.\d])(\d{2,5})x(\d{2,5})(?![.\d])"
    streams: list[dict] = []
    for line in (ln for ln in text.splitlines() if " Video:" in ln):
        hit = re.search(size_pattern, line)
        if hit:
            width, height = int(hit.group(1)), int(hit.group(2))
            streams.append({"codec_type": "video", "width": width, "height": height})
    return {"streams": streams, "format": {"duration": str(duration)}}
|
||
|
||
|
||
def ffprobe_meta(mp4: Path) -> dict:
    """Return stream and format metadata for ``mp4``.

    ffprobe's JSON output is used when available; on any failure the metadata
    is rebuilt from ``ffmpeg -i`` text output instead.
    """
    probe_cmd = [
        "ffprobe", "-v", "error", "-print_format", "json", "-show_streams", "-show_format", str(mp4),
    ]
    try:
        return json.loads(run(probe_cmd))
    except Exception:
        return _ffmpeg_meta_fallback(mp4)
|
||
|
||
|
||
def media_duration(path: Path) -> float:
    """Return the media duration in seconds, or 0.0 when it cannot be determined.

    ffprobe is tried first; if that fails, the duration is parsed from
    ``ffmpeg -i`` text output.
    """
    probe_cmd = [
        "ffprobe", "-v", "error", "-print_format", "json", "-show_format", str(path),
    ]
    try:
        format_info = json.loads(run(probe_cmd)).get("format", {})
        return float(format_info.get("duration") or 0)
    except Exception:
        pass

    try:
        return _duration_from_text(_ffmpeg_probe_text(path))
    except Exception:
        return 0.0
|
||
|
||
|
||
def pipeline_download(job_id: str) -> None:
    """Stage 1: download only (or skip for uploads) and land ``source.mp4``.

    When ``source.mp4`` already exists in the job directory the download is
    skipped; otherwise yt-dlp fetches ``job.url`` to that path. The file is then
    probed and the job moves to ``downloaded`` with its duration and frame size.
    Per the original note, the front-end flow triggers audio analysis once the
    job reaches ``downloaded``.

    Nothing is raised to the caller: any failure marks the job ``failed`` with a
    message telling the download stage from the metadata stage.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    stage = "download"
    try:
        mp4 = d / "source.mp4"
        if mp4.exists():
            update(job, status="downloading", message="本地上传 · 跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            cmd = [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                *ytdlp_cookie_args(),
                # "--" ends option parsing so a URL starting with "-" can never
                # be read as a yt-dlp option.
                "--",
                job.url,
            ]
            run(cmd)
            if not mp4.exists():
                raise RuntimeError("下载完成但找不到 source.mp4")

        stage = "metadata"
        meta = ffprobe_meta(mp4)
        # Read the probe result defensively: incomplete output should end in the
        # explicit "duration unreadable" error below, not in a bare KeyError.
        streams = meta.get("streams") or []
        v_stream = next((s for s in streams if s.get("codec_type") == "video"), None)
        duration = float((meta.get("format") or {}).get("duration") or 0)
        if duration <= 0:
            raise RuntimeError("视频时长读取失败")
        update(
            job,
            status="downloaded",
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            width=int(v_stream.get("width") or 0) if v_stream else 0,
            height=int(v_stream.get("height") or 0) if v_stream else 0,
            progress=25,
            error="",
            message=f"视频就绪 · {duration:.1f}s · 等待音频解析",
        )
    except Exception as e:
        message = "视频元数据解析失败" if stage == "metadata" else "下载失败"
        update(job, status="failed", error=normalize_download_error(e), message=message)
|
||
|
||
|
||
def pipeline_analyze(
    job_id: str,
    frame_count: int = KEYFRAME_COUNT,
    target: FrameExtractTarget = "transparent_human",
    mode: FrameExtractMode = "replace",
    quality: FrameExtractQuality = "auto",
) -> None:
    """Stage 2: split the audio track and extract key frames.

    ASR/translation is a separate copy track and does not block the visual
    asset flow (original note).

    Steps: reuse or extract ``audio.wav``; dump a low-resolution scan of the
    video into ``frame_scan``; score the scan frames and select candidates;
    re-extract the selected timestamps from the source at high quality into
    ``frames``; publish the resulting frame list on the job.

    Args:
        job_id: key into ``JOBS``.
        frame_count: requested number of new frames, clamped to 1..20.
        target: selection target, passed to ``_select_keyframes``.
        mode: ``"replace"`` wipes existing frames; any other value appends.
        quality: scan quality, resolved through ``_resolve_frame_quality``.

    Nothing is raised to the caller: any failure marks the job ``failed``. The
    scan directory is removed on both the success and the failure path.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    # Scratch directory for the low-res scan, defined before the try block so
    # the failure handler can remove it too.
    scan_dir = d / "frame_scan"
    try:
        mp4 = d / "source.mp4"
        if not mp4.exists():
            raise RuntimeError("source.mp4 不存在,先完成下载")

        wav = d / "audio.wav"
        audio_running = job_id in AUDIO_WORKERS_RUNNING or job.audio_script.status == "rewriting"
        if wav.exists():
            update(job, status="splitting", message="复用音轨 · 准备抽帧…", progress=35, source_audio_url=f"/jobs/{job_id}/audio.wav")
        elif audio_running:
            # The audio worker is expected to produce audio.wav itself; do not
            # race it with a second ffmpeg run.
            update(job, status="splitting", message="音频路并行处理中 · 准备抽帧…", progress=35)
        else:
            update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35)
            run([
                "ffmpeg", "-y", "-i", str(mp4),
                "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
                str(wav),
            ])
            update(job, source_audio_url=f"/jobs/{job_id}/audio.wav")
        n = max(1, min(int(frame_count), 20))
        target_label = FRAME_TARGET_LABELS.get(target, FRAME_TARGET_LABELS["balanced"])
        duration = max(float(job.duration or 1.0), 0.1)
        effective_quality = _resolve_frame_quality(duration, quality)
        effective_quality_label = FRAME_QUALITY_LABELS.get(effective_quality, FRAME_QUALITY_LABELS["accurate"])
        quality_label = f"自动·{effective_quality_label}" if quality == "auto" else effective_quality_label
        scan_fps, scan_width, metric_width, estimated_scan_count = _scan_profile(duration, effective_quality)

        update(job, message=f"本地{quality_label}扫描 · {target_label} · 约 {estimated_scan_count} 帧…", progress=45)
        frames_dir = d / "frames"
        replacing = mode == "replace"
        existing_frames = list(job.frames) if not replacing else []
        if replacing and frames_dir.exists():
            shutil.rmtree(frames_dir)
        frames_dir.mkdir(parents=True, exist_ok=True)
        if scan_dir.exists():
            shutil.rmtree(scan_dir)
        scan_dir.mkdir(parents=True)

        # 1) Low-resolution, low-frame-rate scan. Scan images are only used to
        #    score candidates; they are never used directly as key frames.
        run([
            "ffmpeg", "-y", "-i", str(mp4),
            "-vf", f"fps={scan_fps:.4f},scale={scan_width}:-2",
            "-q:v", "4",
            str(scan_dir / "s_%05d.jpg"),
        ])

        scan_paths = sorted(scan_dir.glob("s_*.jpg"))
        if not scan_paths:
            raise RuntimeError("低清扫描没有生成候选帧")

        candidates: list[dict] = []
        for i, p in enumerate(scan_paths):
            # Timestamp derived from the scan frame index, kept just inside the clip.
            t = min(i / scan_fps, max(duration - 0.05, 0.0))
            item = _frame_metrics(p, i, t, metric_width)
            if item:
                candidates.append(item)
        if not candidates:
            raise RuntimeError("候选帧评分失败")

        # 2) Target-aware selection: pHash de-duplication + sharpness / centre
        #    detail / scene change / motion intensity.
        # Frame extraction only uses local compute and does not call Vision per
        # frame; semantic review is left to later asset preparation.
        semantic_transparent = False
        # In append mode more candidates than needed are selected, because some
        # will be dropped for being too close to existing frames.
        selection_count = n if replacing else min(len(candidates), max(n * 4, n + len(existing_frames) + 2))
        update(job, message=f"{quality_label}本地筛选 · {target_label} · {n} / {len(candidates)} 张…", progress=60)
        chosen = _select_keyframes(candidates, selection_count, target)

        # 3) Extract high-quality key frames from the original video, only at
        #    the finally selected timestamps.
        renamed: list[KeyFrame] = []
        chosen_sorted = chosen if semantic_transparent else sorted(chosen, key=lambda it: float(it["timestamp"]))
        existing_timestamps = [float(f.timestamp) for f in existing_frames]
        next_idx = max((int(f.index) for f in existing_frames), default=-1) + 1
        rejected_by_ai = 0
        for attempt, item in enumerate(chosen_sorted, start=1):
            if len(renamed) >= n:
                break
            t = float(item["timestamp"])
            if not replacing and any(abs(t - old) < 0.35 for old in existing_timestamps):
                continue
            idx = next_idx + len(renamed)
            dst = frames_dir / f"{idx:03d}.jpg"
            run([
                "ffmpeg", "-y", "-ss", f"{t:.3f}", "-i", str(mp4),
                "-frames:v", "1",
                "-pix_fmt", "yuvj420p", "-q:v", "3",
                str(dst),
            ])
            if not dst.exists():
                # ffmpeg can exit successfully without writing anything when the
                # seek lands past the last video frame; never publish a key
                # frame whose image does not exist.
                continue
            transparent_score: TransparentHumanFrameScore | None = None
            if semantic_transparent:
                update(
                    job,
                    message=f"AI 验收透明骨架人 · 已通过 {len(renamed)}/{n} · 候选 {attempt}/{len(chosen_sorted)}…",
                    progress=min(68, 60 + int(attempt / max(1, len(chosen_sorted)) * 8)),
                )
                transparent_score = _score_transparent_human_frame(dst)
                if not transparent_score.qualified:
                    rejected_by_ai += 1
                    try:
                        dst.unlink()
                    except OSError:
                        pass
                    reason = transparent_score.reject_reason or f"总分 {transparent_score.total_score}/100"
                    update(job, message=f"AI 退回候选帧 · {reason[:48]} · 自动换下一帧", progress=65)
                    continue
            renamed.append(KeyFrame(
                index=idx,
                timestamp=round(t, 2),
                url=f"/jobs/{job_id}/frames/{idx}.jpg",
                transparent_human_score=transparent_score,
            ))
            existing_timestamps.append(t)

        if semantic_transparent and not renamed:
            raise RuntimeError("AI 未找到合格透明骨架人帧:需要透明/半透明人体外壳 + 清楚白色骨架 + 非恐怖广告感")

        # 4) Clean up the scan directory.
        shutil.rmtree(scan_dir, ignore_errors=True)

        merged_frames = sorted(existing_frames + renamed, key=lambda f: f.timestamp)
        action_label = "追加" if not replacing else "抽取"

        final_message = (
            f"已按「{quality_label} · {target_label}」AI验收 {action_label} {len(renamed)} 张"
            + (f" · 退回 {rejected_by_ai} 张" if semantic_transparent else "")
            + f" · 共 {len(merged_frames)} 张"
        ) if semantic_transparent else (
            f"已按「{quality_label} · {target_label}」{action_label} {len(renamed)} 张关键帧 · 共 {len(merged_frames)} 张"
        )
        update(
            job,
            status="transcribed" if job.transcript else "frames_extracted",
            frames=merged_frames,
            progress=70,
            error="",
            message=final_message,
        )

    except Exception as e:
        # Do not leave the scan images on disk when the run fails part-way.
        shutil.rmtree(scan_dir, ignore_errors=True)
        update(job, status="failed", error=str(e), message="解析失败")
|
||
|
||
|
||
def analyze_queue_worker() -> None:
    """Drain ``ANALYZE_QUEUE``, running one analysis job at a time.

    ``ANALYZE_WORKER_RUNNING`` is set while the worker is active and always
    cleared on exit. After each processed job, every job still waiting gets its
    status message refreshed with the number of tasks ahead of it.
    """
    global ANALYZE_WORKER_RUNNING
    ANALYZE_WORKER_RUNNING = True
    try:
        while ANALYZE_QUEUE:
            job_id, frames, target, mode, quality = ANALYZE_QUEUE.pop(0)
            if job_id not in JOBS:
                # Entries for jobs that no longer exist are dropped silently.
                continue
            pipeline_analyze(job_id, frames, target, mode, quality)
            for ahead, entry in enumerate(ANALYZE_QUEUE):
                waiting_job = JOBS.get(entry[0])
                if waiting_job:
                    update(waiting_job, status="splitting", progress=30, message=f"排队等待抽帧 · 前方 {ahead} 个任务")
    finally:
        ANALYZE_WORKER_RUNNING = False
|
||
|
||
|
||
# ---------- 音频转写 + 翻译 + SKG 改写 + Azure OpenAI 配音 ----------
|
||
|
||
class TranscriptionUnavailable(RuntimeError):
    """Raised when an ASR backend cannot provide a usable transcript."""
|
||
|
||
|
||
def _parse_asr_segments(content: str, duration: float) -> list[dict]:
|
||
raw = (content or "").strip()
|
||
if raw.startswith("```"):
|
||
import re as _re
|
||
match = _re.search(r"(\[[\s\S]*\]|\{[\s\S]*\})", raw)
|
||
raw = match.group(0) if match else raw
|
||
try:
|
||
data = json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
text = raw.strip()
|
||
return [{"start": 0.0, "end": duration, "text": text}] if text else []
|
||
if isinstance(data, dict):
|
||
if data.get("can_hear") is False:
|
||
raise TranscriptionUnavailable("fallback ASR could not hear the audio")
|
||
for key in ("segments", "data", "items", "result"):
|
||
if isinstance(data.get(key), list):
|
||
data = data[key]
|
||
break
|
||
else:
|
||
text = str(data.get("text") or data.get("transcript") or "").strip()
|
||
return [{"start": 0.0, "end": duration, "text": text}] if text else []
|
||
if not isinstance(data, list):
|
||
return []
|
||
segments: list[dict] = []
|
||
for i, item in enumerate(data):
|
||
if isinstance(item, str):
|
||
text = item.strip()
|
||
start = 0.0 if len(data) == 1 else duration * i / max(1, len(data))
|
||
end = duration if len(data) == 1 else duration * (i + 1) / max(1, len(data))
|
||
elif isinstance(item, dict):
|
||
text = str(item.get("text") or item.get("en") or item.get("transcript") or "").strip()
|
||
start = float(item.get("start") or item.get("start_time") or 0)
|
||
end = float(item.get("end") or item.get("end_time") or duration)
|
||
else:
|
||
continue
|
||
if text:
|
||
segments.append({"start": max(0.0, start), "end": max(start, end), "text": text})
|
||
return segments
|
||
|
||
|
||
def _clean_asr_segments(segments: list[dict], duration: float) -> list[dict]:
|
||
clean: list[dict] = []
|
||
cursor = 0.0
|
||
for item in segments:
|
||
text = str(item.get("text") or item.get("en") or item.get("transcript") or "").strip()
|
||
if not text:
|
||
continue
|
||
try:
|
||
start = float(item.get("start") if item.get("start") is not None else item.get("start_time") or 0)
|
||
end = float(item.get("end") if item.get("end") is not None else item.get("end_time") or 0)
|
||
except (TypeError, ValueError):
|
||
continue
|
||
if end <= 0 and duration > 0:
|
||
end = duration
|
||
start = max(0.0, min(start, duration if duration > 0 else start))
|
||
end = max(start + 0.05, min(end, duration if duration > 0 else end))
|
||
# Keep the timeline monotonic. Real ASR can overlap slightly, but the UI table should not jump back.
|
||
if start < cursor - 0.25:
|
||
start = cursor
|
||
end = max(end, start + 0.05)
|
||
cursor = max(cursor, end)
|
||
clean.append({"start": round(start, 2), "end": round(end, 2), "text": text})
|
||
return clean
|
||
|
||
|
||
def _segment_text_key(text: str) -> str:
|
||
return re.sub(r"[^a-z0-9]+", " ", text.lower()).strip()
|
||
|
||
|
||
def _validate_asr_segments(segments: list[dict], duration: float, source: str) -> list[dict]:
    """Clean *segments* and reject transcripts that look fabricated.

    The cleaned segments are returned with ``_source`` set to *source*.

    Raises:
        TranscriptionUnavailable: when nothing survives cleaning, the texts are
            highly repetitive, the timeline looks like a synthetic one-second
            grid, timestamps run more than 3 s past *duration*, or coverage of
            a clip longer than 10 s is too thin.
    """
    clean = _clean_asr_segments(segments, duration)
    if not clean:
        raise TranscriptionUnavailable(f"{source} did not return transcript segments")

    text_keys = [key for key in (_segment_text_key(str(s.get("text") or "")) for s in clean) if key]
    unique_ratio = len(set(text_keys)) / max(1, len(text_keys))
    one_secondish = sum(1 for s in clean if 0.75 <= (float(s["end"]) - float(s["start"])) <= 1.25)
    total = len(clean)

    if total >= 12 and unique_ratio < 0.35:
        raise TranscriptionUnavailable(f"{source} returned repetitive transcript segments")
    if total >= 20 and one_secondish / total > 0.75 and unique_ratio < 0.65:
        raise TranscriptionUnavailable(f"{source} returned synthetic one-second timeline")

    if duration > 0:
        last_end = max(float(s["end"]) for s in clean)
        word_total = sum(len(str(s.get("text") or "").split()) for s in clean)
        if total > 1 and last_end > duration + 3:
            raise TranscriptionUnavailable(f"{source} returned timestamps outside audio duration")
        if duration > 10 and last_end < duration * 0.45 and word_total < 20:
            raise TranscriptionUnavailable(f"{source} returned too little transcript coverage")

    for seg in clean:
        seg["_source"] = source
    return clean
|
||
|
||
|
||
def _local_asr_binary() -> str:
    """Return the path of a usable local ``mlx_whisper`` executable.

    Candidates are tried in order: the configured ``LOCAL_ASR_BIN``, whatever
    ``shutil.which`` finds on ``PATH``, then the Homebrew default location.

    Raises:
        TranscriptionUnavailable: when no candidate is an executable file.
    """
    candidates = [
        LOCAL_ASR_BIN,
        shutil.which("mlx_whisper") or "",
        "/opt/homebrew/bin/mlx_whisper",
    ]
    for candidate in candidates:
        # Require a regular file: a searchable directory also passes
        # os.access(..., X_OK) and would only fail later when executed.
        if candidate and Path(candidate).is_file() and os.access(candidate, os.X_OK):
            return candidate
    raise TranscriptionUnavailable("本机未找到可用 mlx_whisper")
|
||
|
||
|
||
def _transcribe_mlx_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* with the local ``mlx_whisper`` CLI.

    The CLI writes ``asr-local.json`` next to the audio file; its ``segments``
    are validated with ``_validate_asr_segments`` and returned.

    Raises:
        TranscriptionUnavailable: when no binary is found, the run times out or
            exits non-zero, the result file is missing or unreadable, or the
            transcript fails validation.
    """
    wav = wav.resolve()
    duration = media_duration(wav)
    binary = _local_asr_binary()
    output_name = "asr-local"
    output_path = wav.parent / f"{output_name}.json"
    # Remove a stale result so a failed run cannot be mistaken for a fresh one.
    if output_path.exists():
        output_path.unlink()
    env = os.environ.copy()
    try:
        # Best effort: put the ffmpeg chosen by media_binary first on the
        # child's PATH.
        ffmpeg_path = Path(media_binary("ffmpeg"))
        env["PATH"] = f"{ffmpeg_path.parent}{os.pathsep}{env.get('PATH', '')}"
    except Exception:
        # Deliberately ignored: the child then runs with the inherited PATH.
        pass
    cmd = [
        binary,
        str(wav),
        "--model", LOCAL_ASR_MODEL,
        "--output-dir", str(wav.parent),
        "--output-name", output_name,
        "--output-format", "json",
        "--verbose", "False",
        "--condition-on-previous-text", "False",
        "--word-timestamps", "True",
    ]
    try:
        result = subprocess.run(
            cmd,
            cwd=str(wav.parent),
            env=env,
            capture_output=True,
            text=True,
            timeout=LOCAL_ASR_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired as e:
        raise TranscriptionUnavailable(f"本机 ASR 超时:{LOCAL_ASR_TIMEOUT_SECONDS}s") from e
    if result.returncode != 0:
        # Report only the last output line, capped in length.
        detail = (result.stderr or result.stdout or "").strip().splitlines()[-1:] or ["本机 ASR 执行失败"]
        raise TranscriptionUnavailable(detail[0][:500])
    if not output_path.exists():
        raise TranscriptionUnavailable("本机 ASR 未生成 json 结果")
    try:
        data = json.loads(output_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise TranscriptionUnavailable(f"本机 ASR 结果无法解析:{e}") from e
    # Anything but a JSON object carries no segments; validation then rejects it.
    segments = data.get("segments") if isinstance(data, dict) else None
    return _validate_asr_segments(segments or [], duration, "mlx_whisper")
|
||
|
||
|
||
def _transcribe_gemini_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* by sending it inline to the fallback chat model.

    The whole file is base64-encoded into one ``input_audio`` message part. The
    answer is parsed with ``_parse_asr_segments`` and checked with
    ``_validate_asr_segments``. Up to three attempts are made, one second
    apart, and the last error is re-raised if all of them fail.
    """
    duration = media_duration(wav)
    audio_b64 = base64.b64encode(wav.read_bytes()).decode("ascii")
    prompt = (
        "Transcribe the attached audio. Return strict JSON only, no markdown. "
        "If you cannot truly hear the audio, return {\"can_hear\": false}. Do not guess. "
        "If you can hear it, return {\"can_hear\": true, \"segments\": "
        "[{\"start\": 0.0, \"end\": 1.2, \"text\": \"English transcript\"}]}. "
        "Use English for the transcript. Only include timestamps you can infer from the audio."
    )
    messages = [{"role": "user", "content": [
        {"type": "text", "text": prompt},
        {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "wav"}},
    ]}]
    max_attempts = 3
    failure: Exception | None = None
    for attempt in range(max_attempts):
        try:
            resp = llm().chat.completions.create(
                model=ASR_FALLBACK_MODEL,
                messages=messages,
                temperature=0,
                timeout=ASR_TIMEOUT_SECONDS,
            )
            answer = (resp.choices[0].message.content or "").strip()
            parsed = _parse_asr_segments(answer, duration)
            return _validate_asr_segments(parsed, duration, "gemini audio fallback")
        except Exception as exc:
            failure = exc
            # Pause between attempts, but not after the final one.
            if attempt < max_attempts - 1:
                time.sleep(1.0)
    raise failure or RuntimeError("Gemini audio transcription failed")
|
||
|
||
|
||
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe *wav*, trying three backends in order.

    Order: the remote ``ASR_MODEL`` transcription endpoint, the local
    ``mlx_whisper`` CLI, then the chat-model audio fallback. Every result goes
    through ``_validate_asr_segments``, which guards against fake timelines.

    Raises:
        TranscriptionUnavailable: when all three fail; the message joins the
            per-backend errors.
    """
    failures: list[str] = []
    duration = media_duration(wav)

    try:
        with wav.open("rb") as audio_file:
            resp = llm().audio.transcriptions.create(
                file=(wav.name, audio_file, "audio/wav"),
                model=ASR_MODEL,
                response_format="verbose_json",
                timestamp_granularities=["segment"],
                timeout=ASR_TIMEOUT_SECONDS,
            )
        raw = resp.model_dump() if hasattr(resp, "model_dump") else resp
        segments = raw.get("segments") or []
        # Fallback: if the gateway returns no segments, treat the full text as
        # a single segment.
        if not segments and raw.get("text"):
            segments = [{"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}]
        return _validate_asr_segments(segments, duration, ASR_MODEL)
    except Exception as exc:
        failures.append(f"{ASR_MODEL}: {exc}")

    fallbacks = (
        ("mlx_whisper", _transcribe_mlx_sync),
        (ASR_FALLBACK_MODEL, _transcribe_gemini_sync),
    )
    for label, backend in fallbacks:
        try:
            return backend(wav)
        except Exception as exc:
            failures.append(f"{label}: {exc}")

    raise TranscriptionUnavailable(";".join(failures))
|
||
|
||
|
||
def _translate_sync(segments: list[dict]) -> list[str]:
    """Translate all segment texts to Simplified Chinese in one LLM call.

    Returns one string per input segment, in order. The function is
    best-effort and never raises for a bad model answer: a failed request, an
    unparsable answer, or a missing or malformed item yields ``""`` for the
    affected segments.
    """
    # ``text`` may be present but null, so coerce it before stripping.
    payload = [{"i": i, "en": str(s.get("text") or "").strip()} for i, s in enumerate(segments)]
    prompt = (
        "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
        "严格返回 JSON 数组,不要任何 markdown 或多余文字,schema: "
        '[{"i": 0, "zh": "..."}, ...]\n\n输入:\n'
        + json.dumps(payload, ensure_ascii=False)
    )
    try:
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"},
            temperature=0.2,
        )
        content = resp.choices[0].message.content or "[]"
    except Exception:
        return ["" for _ in segments]
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        data = []
    if isinstance(data, dict):
        if "i" in data and "zh" in data:
            # With the json_object response format the model may answer with a
            # single bare item instead of an array.
            data = [data]
        else:
            for k in ("data", "items", "result", "translations"):
                if isinstance(data.get(k), list):
                    data = data[k]
                    break
    if not isinstance(data, list):
        data = []
    zh_by_idx: dict[int, str] = {}
    for it in data:
        if not isinstance(it, dict) or "i" not in it:
            continue
        try:
            idx = int(it["i"])
        except (TypeError, ValueError):
            # A malformed index affects only this item, not the whole batch.
            continue
        # ``or ""`` stops a JSON null from being rendered as the text "None".
        zh_by_idx[idx] = str(it.get("zh") or "")
    return [zh_by_idx.get(i, "") for i in range(len(segments))]
|
||
|
||
|
||
def _transcript_join(segments: list[TranscriptSegment], field: Literal["en", "zh"]) -> str:
|
||
lines: list[str] = []
|
||
for s in segments:
|
||
text = (s.zh if field == "zh" else s.en).strip()
|
||
if text:
|
||
lines.append(f"[{s.start:.1f}-{s.end:.1f}s] {text}")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _voiceover_target_words(target_seconds: float) -> tuple[int, int]:
|
||
seconds = max(4.0, min(float(target_seconds or 0) or 12.0, 45.0))
|
||
center = int(round(seconds * 2.35))
|
||
return max(10, int(center * 0.86)), min(110, max(14, int(center * 1.12)))
|
||
|
||
|
||
def _segment_duration(segments: list[TranscriptSegment]) -> float:
|
||
if not segments:
|
||
return 0.0
|
||
start = min((s.start for s in segments), default=0.0)
|
||
end = max((s.end for s in segments), default=0.0)
|
||
return max(0.0, end - start)
|
||
|
||
|
||
def _fallback_audio_script(segments: list[TranscriptSegment], target_seconds: float = 12.0) -> str:
    """Return a canned SKG voice-over sized to the expected audio length.

    The length is the largest of *target_seconds*, the span of *segments* and
    4 s. One of four fixed scripts is chosen by the thresholds 7 s, 13 s and
    22 s.
    """
    seconds = max(target_seconds, _segment_duration(segments), 4.0)
    shortest = "Meet SKG: warm massage, easy comfort, and a tiny reset for busy bodies."
    short = (
        "Meet SKG, your shortcut to a calmer body break. A little warmth, a steady massage rhythm, "
        "and suddenly your day feels less tight and more yours."
    )
    medium = (
        "This is SKG: smart massage for the moments your body asks for a pause. Warmth, rhythm, "
        "and a clean wearable feel turn neck, back, or everyday tension into a softer reset."
    )
    longest = (
        "Say hello to SKG, the small reset button your day keeps asking for. From neck and shoulder breaks "
        "to back, eye, knee, or foot comfort, SKG brings warm, rhythmic massage into everyday routines, "
        "so winding down feels simple, smart, and a little more fun."
    )
    for limit, script in ((7, shortest), (13, short), (22, medium)):
        if seconds <= limit:
            return script
    return longest
|
||
|
||
|
||
def _audio_delivery_profile(segments: list[TranscriptSegment], target_seconds: float, voice_id: str) -> tuple[str, str]:
    """Describe how the new voice-over should be delivered.

    Returns ``(speaker, rhythm)`` as user-facing Chinese text. The speaker line
    names *voice_id* when one is given. The rhythm line reports duration,
    segment count, words per minute and average segment length from the
    transcript, or a generic note when those figures are unavailable.
    """
    duration = max(float(target_seconds or 0), _segment_duration(segments), 0.0)
    word_total = 0
    for seg in segments:
        word_total += sum(1 for w in seg.en.replace("\n", " ").split(" ") if w.strip())
    spoken_segments = sum(1 for seg in segments if (seg.en or seg.zh).strip())
    # Durations under one second are treated as one second for the rate.
    wpm = int(round(word_total / max(duration, 1.0) * 60)) if word_total else 0
    avg_segment = duration / spoken_segments if spoken_segments else 0.0

    if voice_id:
        speaker = f"按原素材的短视频单人旁白处理;当前近似音色为 {voice_id},用于保持商业口播的亲近感和节奏。"
    else:
        speaker = "按原素材的短视频单人旁白处理;等待选择 TTS 音色。"

    if duration > 0 and spoken_segments:
        rhythm = (
            f"源音频约 {duration:.1f}s,{spoken_segments} 个语义段,语速约 {wpm} wpm,平均每段 {avg_segment:.1f}s;"
            "新配音按相同时长、短句停顿和信息密度改写。"
        )
    else:
        rhythm = "源音频节奏信息不足;新配音按 8-12 秒信息流广告口播节奏生成。"
    return speaker, rhythm
|
||
|
||
|
||
def _fallback_audio_profile(segments: list[TranscriptSegment], target_seconds: float = 0.0) -> tuple[str, str, str]:
    """Build a transcript-only audio profile for use without a model answer.

    Returns ``(speaker, rhythm, background)`` as user-facing Chinese text. Only
    the rhythm line is computed (duration, segment count, words per minute,
    average segment length); the other two are fixed placeholder texts.
    """
    duration = max(float(target_seconds or 0), _segment_duration(segments), 0.0)
    word_total = 0
    for seg in segments:
        word_total += sum(1 for w in seg.en.replace("\n", " ").split(" ") if w.strip())
    spoken_segments = sum(1 for seg in segments if (seg.en or seg.zh).strip())
    wpm = int(round(word_total / max(duration, 1.0) * 60)) if word_total else 0
    avg_segment = duration / spoken_segments if spoken_segments else 0.0

    speaker = "检测到短视频口播人声;当前仅能根据转写段落估算,未做声纹克隆。"
    if duration > 0 and spoken_segments:
        rhythm = f"音频约 {duration:.1f}s,{spoken_segments} 个文案段,语速约 {wpm} wpm,平均每段 {avg_segment:.1f}s。"
    else:
        rhythm = "音频节奏信息不足;等待模型返回更完整的语速和停顿分析。"
    background = "背景音待模型细分;当前已保留原音频文件,可继续用于音乐、人声和环境声判断。"
    return speaker, rhythm, background
|
||
|
||
|
||
def _audio_profile_model_sync(wav: Path, segments: list[TranscriptSegment], target_seconds: float = 0.0) -> tuple[str, str, str]:
    """Ask the audio-capable chat model to profile speaker, rhythm and background.

    Returns ``(speaker_profile, rhythm_profile, background_audio_profile)``.
    Any field the model leaves empty is filled from ``_fallback_audio_profile``.
    The whole fallback is returned when no API key is set, the file is missing
    or unreadable, or both attempts give nothing usable. Never raises.
    """
    fallback = _fallback_audio_profile(segments, target_seconds)
    if not LLM_API_KEY or not wav.exists():
        return fallback
    transcript = _transcript_join(segments, "en") or _transcript_join(segments, "zh") or "No reliable transcript."
    try:
        audio_b64 = base64.b64encode(wav.read_bytes()).decode("ascii")
    except Exception:
        return fallback
    prompt = (
        "Analyze this short-video audio for an ad recreation workflow. Return strict JSON only, no markdown.\n"
        "Fields:\n"
        "- speaker_profile: describe speaker count, likely gender/age range if audible, tone, energy, accent/language, confidence.\n"
        "- rhythm_profile: describe pacing, pauses, speech density, segment rhythm, and timing pattern.\n"
        "- background_audio_profile: describe music, background sound, ambience, SFX, loudness relationship to voice, and whether it should be recreated or replaced.\n"
        "Do not invent an exact identity. If uncertain, state uncertainty.\n\n"
        f"Known transcript/timestamps:\n{transcript[:5000]}"
    )
    messages = [{"role": "user", "content": [
        {"type": "text", "text": prompt},
        {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "wav"}},
    ]}]
    failure: Exception | None = None
    for attempt in range(2):
        try:
            resp = llm().chat.completions.create(
                model=ASR_FALLBACK_MODEL,
                messages=messages,
                response_format={"type": "json_object"},
                temperature=0.1,
                max_tokens=900,
                timeout=ASR_TIMEOUT_SECONDS,
            )
            answer = json.loads((resp.choices[0].message.content or "").strip())
            model_fields = tuple(
                str(answer.get(key) or "").strip()
                for key in ("speaker_profile", "rhythm_profile", "background_audio_profile")
            )
            if any(model_fields):
                # Fill each empty field from the transcript-based fallback.
                return tuple(got or default for got, default in zip(model_fields, fallback))
        except Exception as exc:
            failure = exc
            # Pause only before the second attempt.
            if attempt == 0:
                time.sleep(1.0)
    if failure:
        print(f"[audio profile fallback] {failure}", flush=True)
    return fallback
|
||
|
||
|
||
def _build_audio_intake_sync(job_id: str, wav: Path, segments: list[TranscriptSegment], target_seconds: float = 0.0) -> AudioScript:
    """Assemble the completed intake ``AudioScript`` for a transcribed clip.

    It combines the joined English and Chinese transcripts with the three
    profile strings from ``_audio_profile_model_sync``. *job_id* is accepted but
    not used in this function.
    """
    duration = max(float(target_seconds or 0), _segment_duration(segments), 0.0)
    speaker, rhythm, background = _audio_profile_model_sync(wav, segments, duration)
    return AudioScript(
        status="completed",
        source_text=_transcript_join(segments, "en"),
        source_zh=_transcript_join(segments, "zh"),
        speaker_profile=speaker,
        rhythm_profile=rhythm,
        background_audio_profile=background,
        product_brief=AUDIO_PRODUCT_BRIEF,
        rewrite_model=ASR_FALLBACK_MODEL,
        created_at=time.time(),
    )
|
||
|
||
|
||
def _rewrite_audio_script_sync(segments: list[TranscriptSegment], target_seconds: float = 12.0) -> tuple[str, str]:
    """Have the LLM write a fresh SKG voice-over timed to the source clip.

    Returns ``(script, note)``. ``note`` is empty when the model supplied the
    script. Otherwise the canned script from ``_fallback_audio_script`` is
    returned with a note explaining why (no API key, or the request or parsing
    failed). Never raises.
    """
    fallback = _fallback_audio_script(segments, target_seconds)
    if not LLM_API_KEY:
        return fallback, "LLM_API_KEY 未配置,使用本地 SKG 模板"
    source_text = _transcript_join(segments, "en")
    source_zh = _transcript_join(segments, "zh")
    min_words, max_words = _voiceover_target_words(target_seconds)
    prompt = (
        "You are an English short-video voice-over writer for SKG wellness massagers. "
        "Write a fresh product-introduction VO for SKG. Use the source transcript only as timing and pacing reference; "
        "do not summarize it unless it helps the rhythm.\n"
        "Rules:\n"
        f"1. Target audio length is about {target_seconds:.1f} seconds. Output {min_words}-{max_words} English words.\n"
        "2. Make it natural, warm, premium, and a little playful. It should sound like a real creator, not a stiff ad.\n"
        "3. Do not claim medical treatment, cure, pain elimination, or clinical effects.\n"
        "4. Do not copy the original brand, creator, price, platform language, or exact claims.\n"
        "5. Introduce SKG products directly: smart massage, warmth, rhythm, daily neck/back/eye/knee/foot relaxation.\n"
        "6. Keep it easy for TTS: short sentences, spoken phrasing, no hashtags, no stage directions, no quotation marks.\n"
        "7. If the source transcript is thin, ignore it and write a general SKG product intro.\n"
        'Return strict JSON only: {"rewritten_text":"..."}.\n\n'
        f"SKG product context: {AUDIO_PRODUCT_BRIEF}\n\n"
        f"English transcript:\n{source_text or 'None'}\n\n"
        f"Chinese translation for reference:\n{source_zh or 'None'}"
    )
    chat_messages = [
        {"role": "system", "content": "Return valid JSON only. No explanation. No markdown."},
        {"role": "user", "content": prompt},
    ]
    try:
        resp = llm().chat.completions.create(
            model=AUDIO_REWRITE_MODEL,
            messages=chat_messages,
            response_format={"type": "json_object"},
            temperature=0.72,
            max_tokens=600,
        )
        answer = (resp.choices[0].message.content or "").strip()
        if answer.startswith("```"):
            # Strip a Markdown fence by taking the outermost JSON object.
            fenced = re.search(r"\{[\s\S]*\}", answer)
            if fenced:
                answer = fenced.group(0)
        rewritten = str(json.loads(answer).get("rewritten_text", "")).strip()
        return (rewritten or fallback), ""
    except Exception as e:
        return fallback, f"改写失败,使用本地模板:{e}"
|
||
|
||
|
||
def _choose_azure_voice_id() -> str:
    """Pick a TTS voice at random from ``AZURE_TTS_VOICE_POOL``.

    Returns ``AZURE_TTS_VOICE_ID`` when the pool is empty.
    """
    if not AZURE_TTS_VOICE_POOL:
        return AZURE_TTS_VOICE_ID
    return random.choice(AZURE_TTS_VOICE_POOL)
|
||
|
||
|
||
def _choose_tts_voice_id() -> str:
    """Return the voice id for the next TTS request (delegates to the Azure chooser)."""
    voice_id = _choose_azure_voice_id()
    return voice_id
|
||
|
||
|
||
def _voice_speed_for(voice_id: str, target_seconds: float, text: str) -> float:
|
||
words = len([w for w in text.replace("\n", " ").split(" ") if w.strip()])
|
||
estimated_seconds = words / 2.35 if words else target_seconds
|
||
if target_seconds > 0 and estimated_seconds > target_seconds * 1.12:
|
||
return 1.06
|
||
if target_seconds > 0 and estimated_seconds < target_seconds * 0.82:
|
||
return 0.94
|
||
if voice_id == "English_MaturePartner":
|
||
return 0.96
|
||
if voice_id == "English_Upbeat_Woman":
|
||
return 1.02
|
||
return 0.99
|
||
|
||
|
||
def _azure_tts_url_for(path_value: str) -> str:
    """Join ``AZURE_OPENAI_BASE_URL`` with an endpoint path.

    A leading ``/`` is added to *path_value* when missing. If the base URL
    already ends with that path it is returned unchanged.
    """
    endpoint = path_value if path_value.startswith("/") else "/" + path_value
    already_included = AZURE_OPENAI_BASE_URL.endswith(endpoint)
    return AZURE_OPENAI_BASE_URL if already_included else AZURE_OPENAI_BASE_URL + endpoint
|
||
|
||
|
||
def _azure_tts_urls() -> list[str]:
    """Return the candidate TTS endpoint URLs, de-duplicated in order.

    Paths come from ``AZURE_TTS_PATHS``, or from ``AZURE_TTS_PATH`` alone when
    that list is empty.
    """
    paths = AZURE_TTS_PATHS or [AZURE_TTS_PATH]
    # dict.fromkeys keeps the first occurrence of each URL and its position.
    return list(dict.fromkeys(_azure_tts_url_for(path) for path in paths))
|
||
|
||
|
||
def _azure_openai_tts_sync(job_id: str, text: str, voice_id: str, target_seconds: float = 12.0) -> str:
    """Synthesize *text* to MP3 through the Azure OpenAI TTS endpoint.

    Candidate URLs are tried in order. A 404 or 405 response, or a transport
    error, moves on to the next URL; any other response ends the search. The
    audio is written to ``audio_script.mp3`` in the job directory.

    Returns:
        The URL path ``/jobs/<job_id>/audio-script.mp3``.

    Raises:
        RuntimeError: missing API key, blank text, no usable endpoint, an HTTP
            error status, an empty body, or a JSON body where audio was
            expected.
    """
    if not AZURE_OPENAI_API_KEY:
        raise RuntimeError("AZURE_OPENAI_API_KEY 或 LLM_API_KEY 未配置,未生成配音")
    spoken = text.strip()
    if not spoken:
        raise RuntimeError("改写文案为空,未生成配音")

    payload = {
        "model": AZURE_TTS_MODEL,
        "voice": voice_id,
        "input": spoken[:9500],
        "response_format": "mp3",
        "speed": _voice_speed_for(voice_id, target_seconds, text),
    }
    # The key is sent both as a bearer token and as an ``api-key`` header.
    headers = {
        "Authorization": f"Bearer {AZURE_OPENAI_API_KEY}",
        "api-key": AZURE_OPENAI_API_KEY,
        "Content-Type": "application/json",
    }

    resp: httpx.Response | None = None
    errors: list[str] = []
    with ai_http_client(timeout=120) as client:
        for url in _azure_tts_urls():
            try:
                current = client.post(url, headers=headers, json=payload)
            except Exception as e:
                errors.append(f"{url}: {type(e).__name__}: {e}")
                continue
            status = current.status_code
            if status >= 400:
                errors.append(f"{url}: HTTP {status}: {current.text[:180]}")
                if status in {404, 405}:
                    # Probably the wrong path for this gateway; try the next one.
                    continue
            resp = current
            break

    if resp is None:
        raise RuntimeError("Azure OpenAI TTS 不可用;已尝试 " + " | ".join(errors))
    if resp.status_code >= 400:
        detail = " | ".join(errors) or resp.text[:300]
        raise RuntimeError(f"Azure OpenAI TTS HTTP {resp.status_code}: {detail[:600]}")

    audio_bytes = resp.content
    if not audio_bytes:
        raise RuntimeError("Azure OpenAI TTS 未返回音频内容")
    if "application/json" in resp.headers.get("content-type", "").lower():
        # A JSON body under a success status is an error report, not audio.
        try:
            data = resp.json()
        except Exception:
            data = {"error": resp.text[:300]}
        raise RuntimeError(f"Azure OpenAI TTS 返回 JSON 而不是音频:{str(data)[:300]}")

    (job_dir(job_id) / "audio_script.mp3").write_bytes(audio_bytes)
    return f"/jobs/{job_id}/audio-script.mp3"
|
||
|
||
|
||
def _tts_sync(job_id: str, text: str, voice_id: str, target_seconds: float = 12.0) -> tuple[str, str, str]:
    """Synthesize *text* and return ``(voice_url, provider, model)``.

    The provider is always ``"azure_openai"`` and the model ``AZURE_TTS_MODEL``.
    """
    voice_url = _azure_openai_tts_sync(job_id, text, voice_id, target_seconds)
    return voice_url, "azure_openai", AZURE_TTS_MODEL
|
||
|
||
|
||
def _build_audio_script_sync(job_id: str, segments: list[TranscriptSegment], target_seconds: float = 12.0) -> AudioScript:
    """Rewrite the transcript into an SKG voice-over and synthesize it.

    The working length is the largest of *target_seconds*, the span of
    *segments* and 4 s. The returned ``AudioScript`` always has status
    ``"completed"``; a TTS failure is reported in its ``error`` field with an
    empty ``voice_url``.
    """
    duration = max(float(target_seconds or 0), _segment_duration(segments), 4.0)
    # A rewrite failure already falls back to the local SKG template, so its
    # note is not surfaced as a user-visible error; only a TTS failure is.
    rewritten, _rewrite_note = _rewrite_audio_script_sync(segments, duration)
    selected_voice_id = _choose_tts_voice_id()
    speaker_profile, rhythm_profile = _audio_delivery_profile(segments, duration, selected_voice_id)

    voice_url, voice_error = "", ""
    voice_provider, voice_model = "azure_openai", AZURE_TTS_MODEL
    try:
        voice_url, voice_provider, voice_model = _tts_sync(job_id, rewritten, selected_voice_id, duration)
    except Exception as e:
        voice_error = str(e)

    return AudioScript(
        status="completed",
        source_text=_transcript_join(segments, "en"),
        source_zh=_transcript_join(segments, "zh"),
        rewritten_text=rewritten,
        speaker_profile=speaker_profile,
        rhythm_profile=rhythm_profile,
        product_brief=AUDIO_PRODUCT_BRIEF,
        rewrite_model=AUDIO_REWRITE_MODEL,
        voice_provider=voice_provider,
        voice_model=voice_model,
        voice_id=selected_voice_id,
        voice_url=voice_url,
        error=voice_error,
        created_at=time.time(),
    )
|
||
|
||
|
||
def pipeline_transcribe(job_id: str, manage_job_status: bool = True) -> None:
    """Extract audio, transcribe, translate and analyze the audio of one job.

    Steps: make sure ``audio.wav`` exists (extracting it from ``source.mp4``
    with ffmpeg when missing), run ASR, translate the segments, then build the
    audio intake result. Intermediate results are written to the job via
    ``update`` so they become visible before the pipeline finishes.

    When ``manage_job_status`` is false only ``transcript`` / ``audio_script``
    (and ``source_audio_url``) are written; the job's status, progress and
    message are left untouched.

    Without ``LLM_API_KEY`` a fixed two-segment mock transcript is used.

    Never raises for pipeline errors: any exception is recorded on the job as a
    failed ``AudioScript`` (and, when managing status, ``status="failed"``).
    Note that ``JOBS[job_id]`` is looked up outside the ``try`` and raises
    ``KeyError`` for an unknown job.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    wav = d / "audio.wav"
    def progress(message: str, value: int) -> None:
        # Progress is only reported when this pipeline owns the job status.
        if manage_job_status:
            update(job, status="transcribing", message=message, progress=value, error="")

    try:
        if not wav.exists():
            mp4 = d / "source.mp4"
            if not mp4.exists():
                raise RuntimeError("source.mp4 不存在,视频导入完成后再提取音频")
            progress("ffmpeg 提取音频轨…", max(45, min(job.progress, 70)))
            # Mono, 16 kHz, 16-bit PCM WAV.
            run([
                "ffmpeg", "-y", "-i", str(mp4),
                "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
                str(wav),
            ])
        if not wav.exists():
            raise RuntimeError("音频提取完成但找不到 audio.wav")
        update(job, source_audio_url=f"/jobs/{job_id}/audio.wav")
        # Target duration: the longest of the wav length, the job duration and a 4 s floor.
        target_duration = max(media_duration(wav), float(job.duration or 0), 4.0)

        if not LLM_API_KEY:
            # No-key mode: use mock data.
            progress("ASR (mock) …", 75)
            time.sleep(1.0)
            mock = [
                TranscriptSegment(index=0, start=0.0, end=3.5,
                                  en="Welcome back, today we're testing something new.",
                                  zh="欢迎回来,今天我们要测试一些新东西。"),
                TranscriptSegment(index=1, start=3.5, end=7.2,
                                  en="This device looks really sleek and minimal.",
                                  zh="这个设备看起来非常时尚和简约。"),
            ]
            # Placeholder audio script shown while the audio analysis runs.
            update_kwargs = {
                "transcript": mock,
                "audio_script": AudioScript(
                    status="rewriting",
                    source_text=_transcript_join(mock, "en"),
                    source_zh=_transcript_join(mock, "zh"),
                    speaker_profile="正在分析原音频讲话人和口播节奏…",
                    rhythm_profile="正在按原音频时长、语速和停顿分析口播节奏…",
                    background_audio_profile="正在分析背景音乐、环境声和音效…",
                    product_brief=AUDIO_PRODUCT_BRIEF,
                    rewrite_model=ASR_FALLBACK_MODEL,
                ),
            }
            if manage_job_status:
                update_kwargs.update(message="ASR mock 完成,分析声音和背景音…", progress=92)
            update(job, **update_kwargs)
            audio_script = _build_audio_intake_sync(job_id, wav, mock, target_duration)
            if manage_job_status:
                update(job, transcript=mock, status="transcribed", progress=100,
                       audio_script=audio_script,
                       message="音频解析完成(MOCK · 未设 LLM_API_KEY)")
            else:
                update(job, transcript=mock, audio_script=audio_script)
            return

        # 1) Whisper ASR
        progress(f"{ASR_MODEL} 转录中…", 78)
        segments = _transcribe_sync(wav)
        if not segments:
            raise TranscriptionUnavailable("ASR 未返回可用字幕段")
        # The first segment may carry a "_source" marker naming the engine that produced it.
        asr_source = str(segments[0].get("_source") or ASR_MODEL)

        # Store the English-only segments on the job first so the UI can show
        # them early; the zh text is filled in after translation.
        en_only = [
            TranscriptSegment(
                index=i,
                start=float(s.get("start", 0)),
                end=float(s.get("end", 0)),
                en=str(s.get("text", "")).strip(),
                zh="",
            )
            for i, s in enumerate(segments)
        ]
        if manage_job_status:
            update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88)
        else:
            update(job, transcript=en_only)

        # 2) Gemini translation
        zh_list = _translate_sync(segments)
        # Pair translations by position; missing translations become "".
        full = [
            TranscriptSegment(
                index=seg.index, start=seg.start, end=seg.end, en=seg.en,
                zh=zh_list[i] if i < len(zh_list) else "",
            )
            for i, seg in enumerate(en_only)
        ]
        # Placeholder audio script shown while the audio analysis runs.
        update_kwargs = {
            "transcript": full,
            "audio_script": AudioScript(
                status="rewriting",
                source_text=_transcript_join(full, "en"),
                source_zh=_transcript_join(full, "zh"),
                speaker_profile="正在分析原音频讲话人和口播节奏…",
                rhythm_profile="正在按原音频时长、语速和停顿分析口播节奏…",
                background_audio_profile="正在分析背景音乐、环境声和音效…",
                product_brief=AUDIO_PRODUCT_BRIEF,
                rewrite_model=ASR_FALLBACK_MODEL,
            ),
        }
        if manage_job_status:
            update_kwargs.update(message="翻译完成,分析讲话人、节奏和背景音…", progress=94)
        update(job, **update_kwargs)
        audio_script = _build_audio_intake_sync(job_id, wav, full, target_duration)
        if manage_job_status:
            update(job, transcript=full, status="transcribed", progress=100,
                   audio_script=audio_script,
                   message=f"音频解析完成 · {len(full)} 段({asr_source} + {TRANSLATE_MODEL} + {ASR_FALLBACK_MODEL} 音频分析)")
        else:
            update(job, transcript=full, audio_script=audio_script)

    except Exception as e:
        # Top-level boundary of the worker: record the failure on the job
        # instead of letting it propagate.
        if manage_job_status:
            update(
                job,
                status="failed",
                audio_script=AudioScript(status="failed", error=str(e), created_at=time.time()),
                error=str(e),
                message="转录失败",
            )
        else:
            update(job, audio_script=AudioScript(status="failed", error=str(e), created_at=time.time()))
|
||
|
||
|
||
def _audio_processing_worker(job_id: str, manage_job_status: bool) -> None:
    """Thread entry point: run the transcription pipeline for one job.

    Whatever the outcome, the job id is removed from ``AUDIO_WORKERS_RUNNING``
    (under ``AUDIO_WORKERS_LOCK``) afterwards so the job can be processed again.
    """
    try:
        pipeline_transcribe(job_id, manage_job_status=manage_job_status)
    finally:
        # Release the "running" marker even if the pipeline raised.
        with AUDIO_WORKERS_LOCK:
            AUDIO_WORKERS_RUNNING.discard(job_id)
|
||
|
||
|
||
def start_audio_processing(job_id: str, manage_job_status: bool = True) -> bool:
    """Start the background audio worker thread for a job, at most once.

    Returns ``True`` when a new worker thread was started and ``False`` when
    nothing was started: the job is unknown, a worker for it is already
    registered in ``AUDIO_WORKERS_RUNNING``, or (only when
    ``manage_job_status`` is false) the job already has a transcript, a
    rewritten script, or an audio script in the ``"rewriting"`` state.

    Raises whatever ``threading.Thread.start`` raises; in that case the job is
    unregistered again so it does not look busy forever.
    """
    job = JOBS.get(job_id)
    if not job:
        return False
    if not manage_job_status:
        has_audio_output = bool(job.transcript) or bool(job.audio_script.rewritten_text)
        if has_audio_output or job.audio_script.status == "rewriting":
            return False
    # Check-and-register atomically so two callers cannot both start a worker.
    with AUDIO_WORKERS_LOCK:
        if job_id in AUDIO_WORKERS_RUNNING:
            return False
        AUDIO_WORKERS_RUNNING.add(job_id)
    worker = threading.Thread(
        target=_audio_processing_worker,
        args=(job_id, manage_job_status),
        daemon=True,
        name=f"audio-{job_id}",
    )
    try:
        worker.start()
    except Exception:
        # The worker's own ``finally`` never runs if the thread could not be
        # started, so undo the registration here before re-raising.
        with AUDIO_WORKERS_LOCK:
            AUDIO_WORKERS_RUNNING.discard(job_id)
        raise
    return True
|
||
|
||
|
||
def _image_is_capacity_error(status_code: int, body: str) -> bool:
    """Return whether an image API failure looks like an upstream capacity problem.

    HTTP 429 always counts. HTTP 500/502/503/504 count only when the response
    body mentions saturation, rate limiting, quota, capacity, overload or a
    timeout (English or Chinese keywords).

    "rate" is matched only when it is not embedded in a longer word, so bodies
    such as "failed to generate image" are no longer misclassified.
    """
    if status_code == 429:
        return True
    if status_code not in (500, 502, 503, 504):
        return False
    lower = body.lower()
    keywords = ("saturated", "ratelimit", "quota", "capacity", "overload", "timeout", "繁忙", "饱和", "过载")
    if any(token in lower for token in keywords):
        return True
    # Matches "rate limit", "rate-limited", "rate_limit_exceeded" but not
    # "generate" / "moderate" / "operated".
    return re.search(r"(?<![a-z])rate(?![a-z])", lower) is not None
|
||
|
||
|
||
def _image_retry_delay(attempt: int, status_code: int = 0, body: str = "", retry_after: str | None = None) -> float:
    """Return how many seconds to wait before the next image API attempt.

    A numeric ``retry_after`` value wins and is clamped to 1-60 seconds. A
    non-numeric one is ignored. Otherwise a fixed backoff schedule indexed by
    ``attempt`` (capped at the last entry) is used: a slower schedule for
    capacity errors, a faster one for everything else.
    """
    if retry_after:
        try:
            hinted = float(retry_after)
        except ValueError:
            # Not a plain number of seconds: fall back to the schedule.
            pass
        else:
            return max(1.0, min(60.0, hinted))

    if _image_is_capacity_error(status_code, body):
        schedule = (6.0, 14.0, 30.0, 45.0)
    else:
        schedule = (1.0, 2.0, 4.0, 8.0)
    return schedule[min(attempt, 3)]
|
||
|
||
|
||
def _image_is_transport_error(message: str) -> bool:
    """Return whether an error message looks like a network/DNS/connection failure.

    The check is a case-insensitive substring match against a fixed list of
    markers (httpx exception names and common OS-level socket/DNS messages).
    """
    markers = (
        "connecterror",
        "connecttimeout",
        "readtimeout",
        "timeout",
        "nodename nor servname",
        "name or service not known",
        "temporary failure in name resolution",
        "operation not permitted",
        "connection refused",
        "network is unreachable",
    )
    haystack = message.lower()
    for marker in markers:
        if marker in haystack:
            return True
    return False
|
||
|
||
|
||
def _image_failure_message(kind: str, attempts: int, last_err: str, capacity_seen: bool) -> str:
    """Build the final error text after all image generation attempts failed.

    Picks one of three messages: upstream capacity exhaustion (when
    ``capacity_seen``), a network/DNS hint (when ``last_err`` looks like a
    transport error), or a plain "failed after N attempts" message.
    """
    prefix = f"{kind} failed after {attempts} attempts: "
    if capacity_seen:
        return (
            prefix
            + "gpt-image-2 上游负载饱和,"
            + f"已自动退避重试仍失败,请稍后点重试。最后错误:{last_err}"
        )
    if not _image_is_transport_error(last_err):
        return prefix + f"{last_err}"
    return (
        prefix
        + "图片网关网络/DNS 连接失败,"
        + "请确认本机网络或在 api/.env 配置 AI_HTTP_PROXY / IMAGE_HTTP_PROXY 后重启后端。"
        + f"最后错误:{last_err}"
    )
|
||
|
||
|
||
def _image_error_status(error: Exception) -> int:
    """Map an image generation error to the HTTP status to return.

    Returns 503 for errors that mention upstream saturation or HTTP 429, or
    that look like transport failures; 500 for everything else.
    """
    text = str(error)
    if "上游负载饱和" in text or "HTTP 429" in text:
        return 503
    if "saturated" in text.lower() or _image_is_transport_error(text):
        return 503
    return 500
|
||
|
||
|
||
def _image_endpoint(path: str) -> str:
    """Join ``IMAGE_BASE_URL`` and ``path`` with exactly one slash between them.

    Raises ``RuntimeError`` when ``IMAGE_BASE_URL`` is empty or unset.
    """
    base = (IMAGE_BASE_URL or "").strip().rstrip("/")
    if base:
        return "/".join((base, path.lstrip("/")))
    raise RuntimeError("IMAGE_BASE_URL 或 LLM_BASE_URL 未配置")
|
||
|
||
|
||
def _prepare_image_edit_bytes(image_path: Path, max_side: int) -> bytes:
    """Return the reference image as JPEG bytes, downscaled to fit ``max_side``.

    The image is shrunk (keeping aspect ratio) only when its longer side
    exceeds ``max_side``, converted to RGB and encoded as JPEG (quality 88).

    Best effort: if the image cannot be opened or re-encoded, the failure is
    logged and the file's original bytes are returned unchanged.
    """
    import io as _io
    from PIL import Image as _PILImage
    try:
        # ``Image.open`` is lazy and keeps the file open; the context manager
        # guarantees the handle is closed once the JPEG has been produced.
        with _PILImage.open(image_path) as im:
            if max(im.size) > max_side:
                im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
            buf = _io.BytesIO()
            im.convert("RGB").save(buf, format="JPEG", quality=88)
        return buf.getvalue()
    except Exception as exc:
        print(f"[image edit] re-encode failed for {image_path.name}, using original bytes: {exc}", flush=True)
        return image_path.read_bytes()
|
||
|
||
|
||
def _image_edit_call(
    image_path: Path | list[Path],
    prompt: str,
    model: str | None = None,
    models: list[str] | None = None,
    fallback_text: bool = False,
    max_attempts: int = 3,
    max_side: int = 1024,
) -> tuple[bytes, str]:
    """Generic image-edit call with retries and an optional text-only fallback.

    Returns ``(image_bytes, effective_mode)`` where ``effective_mode`` is
    ``"edit"`` or ``"text"``. Raises ``RuntimeError`` on failure (missing API
    key, no usable reference image, HTTP 401/403, or all attempts exhausted).

    Reference images are prepared with ``_prepare_image_edit_bytes`` (resized
    to ``max_side``, default 1024) and uploaded as multipart; several
    references are sent as ``image[]``. At most 10 existing paths are used.

    The image model is always ``GPT_IMAGE_MODEL``; the ``model`` / ``models``
    parameters are kept only for compatibility with older callers.
    """
    import base64 as b64lib
    import time as _time
    import httpx
    if not IMAGE_API_KEY:
        raise RuntimeError("IMAGE_API_KEY 或 LLM_API_KEY 未配置")
    models_cycle = [GPT_IMAGE_MODEL]
    model = GPT_IMAGE_MODEL
    image_paths = image_path if isinstance(image_path, list) else [image_path]
    image_paths = [path for path in image_paths if path and path.exists()][:10]
    if not image_paths:
        raise RuntimeError("image edit reference image missing")
    img_bytes_list = [_prepare_image_edit_bytes(path, max_side) for path in image_paths]
    # Attempt plan: ``max_attempts`` edit calls, then optionally one text-only call.
    plan: list[str] = ["edit"] * max_attempts
    if fallback_text:
        plan.append("text")

    last_err = ""
    resp_data: dict = {}
    effective_mode = "edit"
    capacity_seen = False
    for attempt, current_mode in enumerate(plan):
        current_model = models_cycle[min(attempt, len(models_cycle) - 1)]
        # Reset per attempt; only the HTTP error branch fills these in.
        status_code = 0
        body = ""
        retry_after: str | None = None
        try:
            if current_mode == "edit":
                with ai_http_client(timeout=120) as client:
                    r = client.post(
                        _image_endpoint("/images/edits"),
                        headers={
                            "Authorization": f"Bearer {IMAGE_API_KEY}",
                        },
                        data={"model": current_model, "prompt": prompt, "n": "1"},
                        # One reference goes in the "image" field, several in repeated "image[]" fields.
                        files=(
                            {"image": ("reference.jpg", img_bytes_list[0], "image/jpeg")}
                            if len(img_bytes_list) == 1
                            else [
                                ("image[]", (f"reference_{idx + 1}.jpg", img_bytes, "image/jpeg"))
                                for idx, img_bytes in enumerate(img_bytes_list)
                            ]
                        ),
                    )
                    r.raise_for_status()
                    resp_data = r.json()
            else:
                resp = image_llm().images.generate(model=current_model, prompt=prompt, n=1)
                resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}
            if resp_data.get("data"):
                effective_mode = current_mode
                model = current_model  # remember the model that actually succeeded
                break
            # A 2xx response without data: keep the provider's error for the final message.
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]} · model={current_model}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            status_code = e.response.status_code
            retry_after = e.response.headers.get("retry-after")
            capacity_seen = capacity_seen or _image_is_capacity_error(status_code, body)
            # Auth failures will not succeed on retry, so abort immediately.
            fatal = status_code in (401, 403)
            last_err = f"HTTP {status_code}: {body[:200]} · model={current_model}"
            if fatal:
                raise RuntimeError(f"image edit HTTP {status_code}: {body[:300]}")
        except Exception as e:
            last_err = f"{type(e).__name__}: {e} · model={current_model}"

        # Back off before the next attempt (not after the last one).
        if attempt < len(plan) - 1:
            tag = f"retry {attempt + 1}/{len(plan)} → {GPT_IMAGE_MODEL}"
            delay = _image_retry_delay(attempt, status_code, body, retry_after)
            print(f"[image edit {tag}, sleep {delay:.0f}s] {last_err}", flush=True)
            _time.sleep(delay)

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise RuntimeError(_image_failure_message("image edit", len(plan), last_err, capacity_seen))
    item = data_arr[0]
    b64 = item.get("b64_json")
    # Some responses carry a download URL instead of inline base64 data.
    if not b64 and item.get("url"):
        with ai_http_client(timeout=120) as client:
            image_resp = client.get(item["url"])
            image_resp.raise_for_status()
            return image_resp.content, effective_mode
    if not b64:
        raise RuntimeError("image edit returned no b64_json")
    return b64lib.b64decode(b64), effective_mode
|
||
|
||
|
||
def _image_text_call(
    prompt: str,
    model: str | None = None,
    models: list[str] | None = None,
    max_attempts: int = 3,
) -> tuple[bytes, str]:
    """Generate an image from text only, retrying with backoff.

    The image model is always ``GPT_IMAGE_MODEL``; ``model`` / ``models`` are
    accepted but ignored. Returns ``(image_bytes, "text")`` on success and
    raises ``RuntimeError`` when the API key is missing or every attempt fails.
    """
    if not IMAGE_API_KEY:
        raise RuntimeError("IMAGE_API_KEY 或 LLM_API_KEY 未配置")

    last_err = ""
    capacity_seen = False
    final_attempt = max_attempts - 1
    for attempt in range(max_attempts):
        current_model = GPT_IMAGE_MODEL
        status_code, body = 0, ""
        try:
            resp = image_llm().images.generate(model=current_model, prompt=prompt, n=1)
            if hasattr(resp, "model_dump"):
                resp_data = resp.model_dump()
            else:
                resp_data = {"data": [{"b64_json": resp.data[0].b64_json}]}
            entries = resp_data.get("data")
            encoded = entries[0].get("b64_json") if entries else None
            if encoded:
                return base64.b64decode(encoded), "text"
            # No usable image in the response: remember the provider's error.
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]} · model={current_model}"
        except Exception as e:
            last_err = f"{type(e).__name__}: {e} · model={current_model}"
            body = str(e)
            # The SDK error carries no status here; infer a 429 from its text.
            looks_throttled = "429" in body or "saturated" in body.lower() or "饱和" in body
            status_code = 429 if looks_throttled else 0
            capacity_seen = capacity_seen or _image_is_capacity_error(status_code, body)
        if attempt < final_attempt:
            delay = _image_retry_delay(attempt, status_code, body)
            print(f"[image text retry {attempt + 1}/{max_attempts} → {GPT_IMAGE_MODEL}, sleep {delay:.0f}s] {last_err}", flush=True)
            time.sleep(delay)
    raise RuntimeError(_image_failure_message("image text", max_attempts, last_err, capacity_seen))
|
||
|
||
|
||
def _image_path_to_data_url(path: Path) -> str:
    """Read an image file and return it as a base64 ``data:`` URL.

    The media type is chosen from the file suffix (case-insensitive): PNG,
    WebP and GIF get their own type; anything else is labelled
    ``image/jpeg``, as before.
    """
    media_types = {
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    media_type = media_types.get(path.suffix.lower(), "image/jpeg")
    encoded = base64.b64encode(path.read_bytes()).decode("ascii")
    return f"data:{media_type};base64,{encoded}"
|
||
|
||
|
||
def _vision_brief_from_images(image_paths: list[Path], prompt: str, max_images: int = 8) -> str:
    """Ask the vision model for a JSON brief about the given images.

    Only existing files are used, at most ``max_images`` of them. Returns the
    model's ``brief`` string when present, otherwise a "key: value; ..."
    summary of the known descriptive keys; both are capped at 1800 characters.
    Returns ``""`` when there are no images, no ``LLM_API_KEY``, the call or
    JSON parsing fails, or nothing usable is returned.
    """
    usable = [candidate for candidate in image_paths if candidate.exists()][:max_images]
    if not usable or not LLM_API_KEY:
        return ""

    image_parts = [
        {"type": "image_url", "image_url": {"url": _image_path_to_data_url(candidate)}}
        for candidate in usable
    ]
    content: list[dict] = [{"type": "text", "text": prompt}, *image_parts]
    try:
        resp = llm().chat.completions.create(
            model=VISION_MODEL,
            messages=[{"role": "user", "content": content}],
            response_format={"type": "json_object"},
            temperature=0.1,
            max_tokens=1400,
        )
        raw = (resp.choices[0].message.content or "").strip()
        if not raw:
            # Some models put their answer in ``reasoning_content`` instead.
            raw = (getattr(resp.choices[0].message, "reasoning_content", "") or "").strip()
        # Keep only the outermost {...} span in case the model added extra text.
        found = re.search(r"\{[\s\S]*\}", raw)
        if found:
            raw = found.group(0)
        data = json.loads(raw)
    except Exception as e:
        print(f"[vision brief failed] {e}", flush=True)
        return ""

    if not isinstance(data, dict):
        return ""
    brief = data.get("brief")
    if isinstance(brief, str) and brief.strip():
        return brief.strip()[:1800]

    trait_keys = (
        "gender_presentation", "age_range", "body_proportion", "hair", "skin_tone",
        "wardrobe_style", "pose_language", "camera_visibility", "commercial_mood",
        "neck_shoulder_readiness", "style_constraints",
    )
    parts = [
        f"{key.replace('_', ' ')}: {value.strip()}"
        for key in trait_keys
        if isinstance(value := data.get(key), str) and value.strip()
    ]
    return "; ".join(parts)[:1800] if parts else ""
|
||
|
||
|
||
def _describe_source_subject(job_id: str, source_indices: list[int]) -> str:
    """Turn source keyframes into a non-identifying visual brief for similar-subject text generation.

    Resolves each index to its frame path and sends up to 8 frames to
    ``_vision_brief_from_images`` with a fixed instruction prompt.
    """
    instructions = (
        "You are preparing a non-identifying character brief for generating a NEW similar but non-identical ad subject. ",
        "Look at these source video keyframes as evidence of one role and style, not as a person to identify. ",
        "Do NOT identify the person, do NOT estimate exact age, do NOT describe biometric identity, and do NOT mention celebrity or real-person likeness. ",
        "Output strict JSON only. Use broad style traits suitable for text-to-image generation.\n",
        "Required keys: gender_presentation, age_range, body_proportion, hair, skin_tone, wardrobe_style, ",
        "pose_language, camera_visibility, commercial_mood, neck_shoulder_readiness, style_constraints, brief.\n",
        "The brief should be 80-140 words and should preserve category, role, energy, camera readability, and commercial atmosphere while explicitly allowing a new non-identical subject.",
    )
    frame_paths = [_source_frame_path(job_id, frame_index) for frame_index in source_indices]
    return _vision_brief_from_images(frame_paths, "".join(instructions), max_images=8)
|
||
|
||
|
||
def _describe_subject_template_from_images(name: str, subject_style: str, image_paths: list[Path], note: str = "") -> str:
    """Summarize a saved subject template's images into a reusable creative brief.

    Builds an instruction prompt from the template name, its style and the
    first 500 characters of the user note, then sends up to 10 images to
    ``_vision_brief_from_images``.
    """
    header = (
        f"You are summarizing a saved SKG subject template named '{name}' for future text-to-image generation. "
        f"Subject style: {subject_style}. User note: {note[:500]}. "
    )
    rules = (
        "Look at the subject views and describe the reusable creative direction without copying identity or pixels. "
        "Do NOT identify a person and do NOT describe exact facial identity. "
        "Output strict JSON only with keys: gender_presentation, age_range, body_proportion, material_or_skin, "
        "wardrobe_or_surface_style, pose_language, camera_readability, neck_shoulder_readiness, commercial_mood, brief. "
        "The brief should be 80-140 words and must be useful as a reference character brief for creating a new innovative variation."
    )
    return _vision_brief_from_images(image_paths, header + rules, max_images=10)
|
||
|
||
|
||
# ---------- API routes ----------
|
||
|
||
# Request body for POST /jobs.
class CreateJobReq(BaseModel):
    # Source video URL; the handler rejects a blank value with HTTP 400.
    url: str
|
||
|
||
|
||
# Request body for POST /translate.
class TranslateReq(BaseModel):
    # Text to translate; blank text yields an empty result.
    text: str
    # Target language: "en" (English, the default) or "zh" (Simplified Chinese).
    target: Literal["en", "zh"] = "en"
|
||
|
||
|
||
# One storyboard segment submitted for voice-over rewriting.
class ScriptRewriteSegmentReq(BaseModel):
    # Segment index; echoed back in the rewrite result so the caller can match items.
    index: int
    # Segment time range in seconds, sent to the model as "start-end".
    start: float = 0.0
    end: float = 0.0
    # Narrative role label; also selects the local fallback template.
    role: str = ""
    # Reference copy from the source video for this segment.
    source: str = ""
    # The segment's current voice-over text.
    current_text: str = ""
|
||
|
||
|
||
# Request body for POST /jobs/{job_id}/script/rewrite.
class RewriteStoryboardScriptReq(BaseModel):
    # "segment" rewrites only the given segments; "all" asks for a coherent whole-video rewrite.
    mode: Literal["segment", "all"] = "segment"
    # Free-form guidance from the author, passed into the rewrite prompt.
    author_intent: str = ""
    # Segments to rewrite; segments with neither source nor current text are ignored.
    segments: list[ScriptRewriteSegmentReq] = Field(default_factory=list)
|
||
|
||
|
||
@app.post("/translate")
def translate_text(req: TranslateReq) -> dict:
    """Translate a single piece of text (used to turn custom image-extraction element labels from zh into en).

    Returns ``{"text": ...}``; blank input yields an empty string. Responds
    503 when ``LLM_API_KEY`` is not configured and 500 when the model call fails.
    """
    text = req.text.strip()
    if not text:
        return {"text": ""}
    if not LLM_API_KEY:
        raise HTTPException(503, "LLM_API_KEY 未配置")

    if req.target == "en":
        target_label = "English"
    else:
        target_label = "Simplified Chinese"
    prompt = "".join((
        f"Translate the following text into concise {target_label}, suitable as an element ",
        "label in an image-generation prompt. Output only the translation itself — no quotes, ",
        "no punctuation, no explanation, no markdown.\n\n",
        f"Input: {text}",
    ))
    try:
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=200,
        )
        out = (resp.choices[0].message.content or "").strip()
        if not out:
            # Fall back to the last line of the model's reasoning output.
            reasoning = getattr(resp.choices[0].message, "reasoning_content", "") or ""
            if reasoning:
                out = reasoning.strip().splitlines()[-1].strip()
        # Strip wrapping quotes (ASCII and CJK corner brackets).
        cleaned = re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip()
        return {"text": cleaned}
    except Exception as e:
        raise HTTPException(500, f"translate failed: {e}")
|
||
|
||
|
||
def _fallback_script_rewrite_item(segment: ScriptRewriteSegmentReq, author_intent: str = "") -> dict:
    """Build a template-based rewrite for one segment without calling a model.

    The base sentence is chosen by ``segment.role`` (unknown roles use the
    generic one). A note about keeping the source rhythm is appended when the
    segment has source text and is neither the opening hook nor the closing
    call; the first 44 characters of ``author_intent`` are appended when
    given. Returns ``{"index": ..., "text": ...}`` with the text capped at 220
    characters.
    """
    generic_role = "节奏承接"
    templates = {
        "开场钩子": "你有没有发现,低头久了以后,脖子和肩膀会先替你喊累。",
        "痛点推进": "刷手机、坐电脑、赶通勤叠在一起,肩颈很容易一直绷着放不下来。",
        "利益证明": "SKG 这种挂脖按摩仪,重点就是贴住肩颈位置,把热敷感和揉按感带到真正紧的地方。",
        "方案过渡": "这一段可以直接拍拿起、戴上、贴合,让产品自然进入日常放松场景。",
        "转化收口": "如果你也想把肩颈放松变成每天的小习惯,可以从这台 SKG 开始。",
        generic_role: "顺着原片节奏,把这一句落到一个具体的肩颈使用场景里。",
    }
    role = segment.role or ""
    has_source = bool((segment.source or "").strip())
    intent = (author_intent or "").strip()

    pieces = [templates.get(role, templates[generic_role])]
    if has_source and role not in {"开场钩子", "转化收口"}:
        pieces.append("原片这一句的节奏可以保留,但内容换成 SKG 的佩戴和放松体验。")
    if intent:
        pieces.append(f"语气按作者想法处理:{intent[:44]}。")
    return {"index": segment.index, "text": " ".join(pieces)[:220]}
|
||
|
||
|
||
def _parse_script_rewrite_items(raw: str, requested: list[ScriptRewriteSegmentReq], author_intent: str = "") -> list[dict]:
    """Parse the model's rewrite reply into one ``{"index", "text"}`` per requested segment.

    Accepts ``{"items": [...]}`` or a bare JSON array of items, optionally
    wrapped in a Markdown code fence or surrounded by extra text. Each item
    needs an integer-like ``index`` and a non-empty ``text`` (or
    ``rewritten_text``); whitespace is collapsed and the text capped at 260
    characters. Segments with no usable item, and all segments when the reply
    cannot be parsed, get the local template fallback.
    """
    text = (raw or "").strip()
    text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.I).strip()
    text = re.sub(r"\s*```$", "", text).strip()

    # Try the whole reply first, then the outermost {...} and [...] spans.
    # Extracting {...} up front would cut a bare array of several items into
    # invalid JSON, making the array form unusable.
    candidates = [text]
    for pattern in (r"\{[\s\S]*\}", r"\[[\s\S]*\]"):
        match = re.search(pattern, text)
        if match and match.group(0) not in candidates:
            candidates.append(match.group(0))
    data = None
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except Exception:
            continue
        if isinstance(parsed, (dict, list)):
            data = parsed
            break
    if data is None:
        return [_fallback_script_rewrite_item(segment, author_intent) for segment in requested]

    raw_items = data.get("items") if isinstance(data, dict) else data
    if not isinstance(raw_items, list):
        raw_items = []
    by_index: dict[int, str] = {}
    for item in raw_items:
        if not isinstance(item, dict):
            continue
        try:
            idx = int(item.get("index"))
        except Exception:
            continue
        value = str(item.get("text") or item.get("rewritten_text") or "").strip()
        if value:
            by_index[idx] = re.sub(r"\s+", " ", value).strip()[:260]
    return [
        {"index": segment.index, "text": by_index.get(segment.index) or _fallback_script_rewrite_item(segment, author_intent)["text"]}
        for segment in requested
    ]
|
||
|
||
|
||
def _rewrite_storyboard_script_sync(req: RewriteStoryboardScriptReq) -> list[dict]:
    """Rewrite storyboard segments into new SKG voice-over copy.

    Segments with neither source nor current text are dropped; an empty list
    is returned when none remain. Without ``LLM_API_KEY`` the local templates
    are used. Otherwise the rewrite models are tried in order (duplicates and
    empty names skipped) and the first parsed result with any text is
    returned; if every model call fails the local templates are used.
    """
    segments = [seg for seg in req.segments if (seg.source or seg.current_text).strip()]
    if not segments:
        return []
    author_intent = (req.author_intent or "").strip()

    def local_fallback() -> list[dict]:
        return [_fallback_script_rewrite_item(seg, author_intent) for seg in segments]

    if not LLM_API_KEY:
        return local_fallback()

    payload = []
    for seg in segments:
        payload.append({
            "index": seg.index,
            "time": f"{seg.start:.1f}-{seg.end:.1f}s",
            "role": seg.role,
            "source_reference": seg.source,
            "current_voiceover": seg.current_text,
        })
    rules = (
        "你是信息流广告脚本文案改写师。任务:基于原参考文案的节奏和信息结构,把每段改写成 SKG 挂脖肩颈按摩仪的新口播文案。\n"
        "硬规则:\n"
        "1. 输出中文短视频口播,不要英文,不要舞台说明,不要引号。\n"
        "2. 不逐字翻译原文,不保留原品牌、价格、优惠码、平台话术;只参考节奏、钩子、痛点、转化结构。\n"
        "3. 产品固定为套在脖子上的 U 形肩颈按摩仪,表达肩颈紧绷、久坐低头、热敷感、揉按感、佩戴放松和日常使用场景。\n"
        "4. 避免医疗疗效、治疗、治愈、止痛等强功效承诺。\n"
        "5. 每段尽量短,适配该段时间;保持自然创作者口吻。\n"
        "6. mode=all 时,整片要前后连贯;mode=segment 时,只改给定段落但仍要贴合上下文风格。\n"
    )
    context = (
        f"作者想法:{author_intent or '没有额外想法,按原片节奏改成自然卖点口播。'}\n"
        f"改写模式:{req.mode}\n"
        f"SKG 产品背景:{AUDIO_PRODUCT_BRIEF}\n\n"
        "输入段落 JSON:\n"
    )
    output_contract = '\n\n只输出严格 JSON:{"items":[{"index":0,"text":"改写后的中文口播"}]}'
    prompt = rules + context + json.dumps(payload, ensure_ascii=False) + output_contract

    # Order-preserving de-duplication of the candidate models, skipping empty names.
    models = [name for name in dict.fromkeys([AUDIO_REWRITE_MODEL, ASR_FALLBACK_MODEL, TRANSLATE_MODEL]) if name]
    temperature = 0.68 if req.mode == "all" else 0.62
    max_tokens = max(900, min(5000, 180 * len(segments) + 500))
    for model in models:
        try:
            resp = llm().chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "只返回合法 JSON,不要 markdown,不要解释。"},
                    {"role": "user", "content": prompt},
                ],
                response_format={"type": "json_object"},
                temperature=temperature,
                max_tokens=max_tokens,
            )
            message = resp.choices[0].message
            raw = (message.content or getattr(message, "reasoning_content", "") or "").strip()
            items = _parse_script_rewrite_items(raw, segments, author_intent)
            if any((item.get("text") or "").strip() for item in items):
                return items
        except Exception as e:
            # Try the next model instead of failing the request.
            print(f"[script rewrite fallback] {model}: {e}", flush=True)
    return local_fallback()
|
||
|
||
|
||
@app.post("/jobs/{job_id}/script/rewrite")
def rewrite_storyboard_script(job_id: str, req: RewriteStoryboardScriptReq) -> dict:
    """Rewrite storyboard voice-over segments for an existing job.

    Responds 404 when the job is unknown; otherwise returns ``{"items": [...]}``.
    """
    if job_id in JOBS:
        rewritten = _rewrite_storyboard_script_sync(req)
        return {"items": rewritten}
    raise HTTPException(404, "job not found")
|
||
|
||
|
||
@app.get("/health")
def health() -> dict:
    """Report service liveness plus the configured endpoints and model names.

    Only whether keys are configured is exposed (as booleans), not the keys
    themselves.
    """
    image_base_url = IMAGE_BASE_URL or LLM_BASE_URL or "openai-default"
    models = {
        "asr": ASR_MODEL,
        "local_asr": LOCAL_ASR_MODEL,
        "asr_fallback": ASR_FALLBACK_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
        "audio_rewrite": AUDIO_REWRITE_MODEL,
        "vision": VISION_MODEL,
        "product_view": PRODUCT_VIEW_MODEL,
        "image": IMAGE_MODEL,
        "image_base_url": image_base_url,
        "ai_proxy_configured": bool(AI_HTTP_PROXY),
        "image_fallbacks": [GPT_IMAGE_MODEL],
        "subject_image": SUBJECT_ASSET_IMAGE_MODEL,
        "subject_image_fallbacks": SUBJECT_ASSET_IMAGE_MODELS,
        "voice_provider": VOICE_PROVIDER,
        "voice_base_url": AZURE_OPENAI_BASE_URL,
        "voice_tts": AZURE_TTS_MODEL,
        "voice_tts_paths": AZURE_TTS_PATHS,
        "voice_id": AZURE_TTS_VOICE_ID,
        "voice_pool": AZURE_TTS_VOICE_POOL,
        "voice_configured": bool(AZURE_OPENAI_API_KEY),
        "video": VIDEO_MODEL,
        "video_aliases": VIDEO_MODEL_ALIASES,
        "video_provider": video_provider_name(),
        "video_base_url": video_api_base(),
        "video_configured": bool(video_api_key()),
        "video_create_paths": VIDEO_CREATE_PATHS,
    }
    return {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "auth_configured": WEB_AUTH_CONFIGURED,
        "base_url": LLM_BASE_URL or "openai-default",
        "image_base_url": image_base_url,
        "voice_base_url": AZURE_OPENAI_BASE_URL,
        "models": models,
    }
|
||
|
||
|
||
# Compact per-job record returned by GET /jobs.
class JobSummary(BaseModel):
    id: str
    url: str
    status: JobStatus
    progress: int = 0
    message: str = ""
    duration: float = 0.0
    width: int = 0
    height: int = 0
    video_url: str = ""
    # Number of extracted key frames and generated videos on the job.
    frame_count: int = 0
    video_count: int = 0
    # URL of the first key frame image, or "" when the job has no frames.
    thumbnail: str = ""
    error: str = ""
    # Modification time of the job's state.json (0.0 when the file is missing); used for sorting.
    mtime: float = 0.0
|
||
|
||
|
||
@app.get("/jobs", response_model=list[JobSummary])
def list_jobs(limit: int | None = None) -> list[JobSummary]:
    """Return a compact list of all jobs, newest first.

    Jobs are ordered by the modification time of their on-disk ``state.json``
    (jobs without one sort last). The frontend uses this to back-fill history
    when no ``?job=`` is given. A positive ``limit`` truncates the list.
    """
    items: list[JobSummary] = []
    # Iterate over a snapshot: this handler runs in a worker thread while other
    # handlers add or remove jobs, and mutating a dict during iteration raises
    # RuntimeError.
    for job_id, job in list(JOBS.items()):
        try:
            mtime = (JOBS_DIR / job_id / "state.json").stat().st_mtime
        except OSError:
            # Missing file, or the job directory was deleted concurrently.
            mtime = 0.0
        thumb = f"/jobs/{job_id}/frames/{job.frames[0].index}.jpg" if job.frames else ""
        items.append(JobSummary(
            id=job.id,
            url=job.url,
            status=job.status,
            progress=job.progress,
            message=job.message,
            duration=job.duration,
            width=job.width,
            height=job.height,
            video_url=job.video_url,
            frame_count=len(job.frames),
            video_count=len(job.generated_videos),
            thumbnail=thumb,
            error=job.error,
            mtime=mtime,
        ))
    items.sort(key=lambda s: s.mtime, reverse=True)
    if limit is not None and limit > 0:
        items = items[:limit]
    return items
|
||
|
||
|
||
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Create a job for a video URL and schedule its download in the background.

    Responds 400 when the URL is blank. The job id is the first 12 hex
    characters of a random UUID.
    """
    url = req.url.strip()
    if not url:
        raise HTTPException(400, "url required")
    job_id = uuid.uuid4().hex[:12]
    job = Job(id=job_id, url=url)
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/download/retry", response_model=Job)
async def retry_job_download(job_id: str, bg: BackgroundTasks) -> Job:
    """Re-submit the download of a URL-based job.

    Responds 404 for an unknown job and 409 when the job came from an upload
    or is currently downloading, splitting or transcribing. An empty leftover
    ``source.mp4`` is removed before the download is scheduled again.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    is_upload = getattr(job, "source_kind", "") == "upload" or job.url.startswith("upload://")
    if is_upload:
        raise HTTPException(409, "uploaded videos cannot be redownloaded; upload the file again")
    busy_statuses = {"downloading", "splitting", "transcribing"}
    if job.status in busy_statuses:
        raise HTTPException(409, f"job is busy: {job.status}")

    # Drop a zero-byte file left behind by a failed download.
    source_file = job_dir(job_id) / "source.mp4"
    if source_file.exists() and source_file.stat().st_size == 0:
        source_file.unlink()

    reset_fields = {
        "status": "downloading",
        "progress": 1,
        "error": "",
        "message": "重新提交下载…",
        "video_url": "",
    }
    update(job, **reset_fields)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Create a job from an uploaded video file.

    Accepts .mp4/.mov/.webm/.mkv/.m4v (by file name suffix), streams the
    upload to ``source.mp4`` in the job directory in 1 MiB chunks, registers
    the job with an ``upload://<filename>`` URL and schedules
    ``pipeline_download`` for it.

    Responds 400 for a missing file name or unsupported suffix and 500 when
    the stored file is missing or empty. If storing fails for any reason
    (including a cancelled request) the partial file is removed so no orphaned
    job directory content is left behind.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")

    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    stored = False
    try:
        with mp4.open("wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
        if not mp4.exists() or mp4.stat().st_size == 0:
            raise HTTPException(500, "upload failed")
        stored = True
    finally:
        if not stored:
            # No job was registered for this id, so nothing else references
            # the file. Remove it, and the directory too if that leaves it empty.
            try:
                mp4.unlink(missing_ok=True)
                d.rmdir()
            except OSError:
                pass

    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(
    job_id: str,
    bg: BackgroundTasks,
    frames: int = KEYFRAME_COUNT,
    target: FrameExtractTarget = "transparent_human",
    mode: FrameExtractMode = "replace",
    quality: FrameExtractQuality = "auto",
) -> Job:
    """Queue key-frame extraction for a job.

    The request is appended to ``ANALYZE_QUEUE`` and the job is marked
    ``splitting``; the queue worker is scheduled as a background task only if
    it is not already running. Responds 404 for an unknown job and 409 when
    the job is not in a status from which extraction may start.
    """
    global ANALYZE_WORKER_RUNNING
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    allowed_statuses = ("downloaded", "frames_extracted", "transcribed", "transcribing", "failed")
    if job.status not in allowed_statuses:
        # List every accepted status; the old message named only three of the five.
        raise HTTPException(409, f"status must be {'/'.join(allowed_statuses)}, got {job.status}")
    ANALYZE_QUEUE.append((job_id, frames, target, mode, quality))
    position = len(ANALYZE_QUEUE)
    update(
        job,
        status="splitting",
        progress=30,
        error="",
        # "Queued" when another request is ahead of this one, otherwise "preparing".
        message="排队等待抽帧" if ANALYZE_WORKER_RUNNING or position > 1 else "准备抽帧…",
    )
    if not ANALYZE_WORKER_RUNNING:
        ANALYZE_WORKER_RUNNING = True
        bg.add_task(analyze_queue_worker)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames", response_model=Job)
def add_manual_frame(job_id: str, t: float) -> Job:
    """Manually extract one frame at timestamp ``t`` (seconds) and add it to ``job.frames``.

    Responds 404 for an unknown job; 400 when the video is not ready, the
    source file is missing, ``t`` is negative or not finite, or no frame could
    be extracted at ``t`` (for example past the end of the video); 500 when
    ffmpeg itself fails.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not job.video_url:
        raise HTTPException(400, "video not ready")
    # NaN fails both comparisons, so this rejects NaN, +/-inf and negatives.
    if not (0.0 <= t < float("inf")):
        raise HTTPException(400, "t must be a finite, non-negative number of seconds")
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    if not mp4.exists():
        raise HTTPException(400, "source.mp4 missing")
    frames_dir = d / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    # New index: max(existing) + 1. The list is kept sorted by timestamp, but
    # the file name uses the index so it stays stable.
    next_idx = max((f.index for f in job.frames), default=-1) + 1
    out = frames_dir / f"{next_idx:03d}.jpg"
    # Remove any stale file so the existence check below reflects this run only.
    out.unlink(missing_ok=True)
    try:
        run([
            "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
            "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3",
            str(out),
        ])
    except RuntimeError as e:
        raise HTTPException(500, f"ffmpeg failed: {e}") from e
    # ffmpeg can finish without writing a frame (e.g. seeking past the end);
    # do not register a frame whose image does not exist.
    if not out.exists() or out.stat().st_size == 0:
        out.unlink(missing_ok=True)
        raise HTTPException(400, f"no frame could be extracted at {t:.2f}s")

    new_frame = KeyFrame(
        index=next_idx,
        timestamp=round(float(t), 2),
        url=f"/jobs/{job_id}/frames/{next_idx}.jpg",
    )
    merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp)
    update(job, frames=merged, message=f"已手动加帧({t:.1f}s),共 {len(merged)} 张")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    """Return one job, passed through ``job_with_artifacts``; 404 when unknown."""
    job = JOBS.get(job_id)
    if job:
        return job_with_artifacts(job)
    raise HTTPException(404, "job not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}")
def delete_job(job_id: str) -> dict[str, bool | str]:
    """Remove a job from the in-memory registry and delete its directory.

    Args:
        job_id: Identifier of the job; also the name of its folder under ``JOBS_DIR``.

    Returns:
        ``{"ok": True, "id": job_id}`` once the job and/or its directory is gone.

    Raises:
        HTTPException: 400 if ``job_id`` resolves outside ``JOBS_DIR`` or onto the
            subject-template storage; 404 if neither a registered job nor a
            directory exists for it.
    """
    d = (JOBS_DIR / job_id).resolve()
    if JOBS_DIR not in d.parents:
        raise HTTPException(400, "invalid job id")
    # SUBJECT_TEMPLATE_DIR lives inside JOBS_DIR by default; without this guard a
    # request for that folder name would rmtree the shared template library.
    if d == SUBJECT_TEMPLATE_DIR or d in SUBJECT_TEMPLATE_DIR.parents:
        raise HTTPException(400, "invalid job id")
    job = JOBS.pop(job_id, None)
    if not job and not d.exists():
        raise HTTPException(404, "job not found")
    if d.exists():
        shutil.rmtree(d)
    return {"ok": True, "id": job_id}
|
||
|
||
|
||
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    """Kick off audio processing for a job whose video is already on disk.

    Responds 404 for an unknown job, 409 when the video is not ready yet, and 409
    when audio work is already in flight. While the job is in ``splitting`` the
    job-level status/progress is left untouched and only the audio payload is
    reset.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    source_video = job_dir(job_id) / "source.mp4"
    video_pending = job.status in {"created", "downloading"}
    if video_pending or not source_video.exists():
        raise HTTPException(409, f"video not ready, got {job.status}")

    if (
        job.status == "transcribing"
        or job.audio_script.status == "rewriting"
        or job_id in AUDIO_WORKERS_RUNNING
    ):
        raise HTTPException(409, f"job is busy, got {job.status}")

    # A job that is currently splitting keeps its own status/progress.
    owns_status = job.status != "splitting"
    placeholder = AudioScript(
        status="rewriting",
        speaker_profile="正在分析原音频讲话人和口播节奏…",
        rhythm_profile="正在按原音频时长、语速和停顿分析口播节奏…",
        background_audio_profile="正在分析背景音乐、环境声和音效…",
        product_brief=AUDIO_PRODUCT_BRIEF,
        rewrite_model=ASR_FALLBACK_MODEL,
    )
    if owns_status:
        changes = {
            "status": "transcribing",
            "progress": max(45, min(job.progress, 70)),
            "error": "",
            "message": "准备提取音频…",
            "audio_script": placeholder,
        }
    else:
        changes = {"error": "", "audio_script": placeholder}
    update(job, **changes)

    started = start_audio_processing(job_id, manage_job_status=owns_status)
    if not started:
        update(job, message="音频已在处理中")
    return job_with_artifacts(job)
|
||
|
||
|
||
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the job's ``source.mp4`` as ``video/mp4``; 404 when the file is absent."""
    video_file = job_dir(job_id) / "source.mp4"
    if video_file.exists():
        return FileResponse(video_file, media_type="video/mp4")
    raise HTTPException(404, "video not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/audio.wav")
def get_source_audio(job_id: str):
    """Serve the job's ``audio.wav`` as ``audio/wav``; 404 when the file is absent."""
    wav_file = job_dir(job_id) / "audio.wav"
    if wav_file.exists():
        return FileResponse(wav_file, media_type="audio/wav")
    raise HTTPException(404, "audio not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/audio-script.mp3")
def get_audio_script(job_id: str):
    """Serve the job's ``audio_script.mp3`` as ``audio/mpeg``; 404 when the file is absent."""
    mp3_file = job_dir(job_id) / "audio_script.mp3"
    if mp3_file.exists():
        return FileResponse(mp3_file, media_type="audio/mpeg")
    raise HTTPException(404, "audio script not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve key frame ``idx`` (stored under a zero-padded 3-digit name) as JPEG; 404 if absent."""
    file_name = f"{idx:03d}.jpg"
    frame_file = job_dir(job_id) / "frames" / file_name
    if frame_file.exists():
        return FileResponse(frame_file, media_type="image/jpeg")
    raise HTTPException(404, "frame not found")
|
||
|
||
|
||
class GenerateReq(BaseModel):
    # Request body for POST /jobs/{job_id}/frames/{idx}/generate.

    # Base prompt text.
    prompt: str
    # Elements that should appear (positive guidance).
    extra_prompt: str = ""
    # Elements that should not appear (negative guidance).
    negative_prompt: str = ""
    # Kept for compatibility with older front ends; the server forces gpt-image-2.
    model: str = ""
    # "edit" sends a reference image, "text" is prompt-only.
    mode: str = "edit"
    # When True, prefer the frame's selected generated image as the reference
    # (iterative editing); otherwise the original key frame is used.
    from_selected: bool = False
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job)
def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
    """Generate a new image for a key frame (image-to-image or text-to-image).

    In ``edit`` mode the reference image is posted to the image-edit endpoint up
    to three times; if every attempt fails, one prompt-only attempt follows. Any
    other mode performs a single prompt-only attempt. The result is saved under
    the job's ``gen`` folder and appended to ``frame.generated_images``.

    Args:
        job_id: Identifier of the job.
        idx: Index of the key frame to generate for.
        req: Prompt parts, mode and reference selection.

    Returns:
        The job with the new ``GeneratedImage`` attached to the frame.

    Raises:
        HTTPException: 404 for an unknown job/frame or missing frame file; 400 for
            an empty prompt; 503 when no image API key is configured or only
            capacity errors were seen; 500 for non-transient upstream errors, when
            every attempt fails, or when the returned image cannot be read.
    """
    import base64 as b64lib
    import time as _time
    import httpx

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Pick the i2i reference: with from_selected=True and an existing selected
    # generated image, use that image (iteration); otherwise the original key frame.
    reference_path = frame_path
    reference_source = "keyframe"
    if req.from_selected:
        sel = next((g for g in frame.generated_images if g.selected), None)
        if sel:
            sel_path = job_dir(job_id) / "gen" / f"{idx:03d}_{sel.id}.jpg"
            if sel_path.exists():
                reference_path = sel_path
                reference_source = f"gen:{sel.id[:6]}"

    full_prompt = req.prompt.strip()
    if req.extra_prompt.strip():
        full_prompt = f"{full_prompt}. Include: {req.extra_prompt.strip()}"
    if req.negative_prompt.strip():
        full_prompt = f"{full_prompt}. Avoid: {req.negative_prompt.strip()}"
    if not full_prompt:
        raise HTTPException(400, "prompt required")
    if not IMAGE_API_KEY:
        raise HTTPException(503, "IMAGE_API_KEY 或 LLM_API_KEY 未配置")

    model = GPT_IMAGE_MODEL
    gen_id = uuid.uuid4().hex[:12]

    img_bytes_in: bytes | None = None
    if req.mode == "edit":
        img_bytes_in = reference_path.read_bytes()

    # Try i2i up to 3 times; if all fail, degrade to text-only for one more attempt.
    plan: list[str] = ([req.mode] * 3) if req.mode == "edit" else [req.mode]
    if req.mode == "edit":
        plan.append("text")  # automatic fallback once every i2i attempt failed
    resp_data: dict = {}
    last_err = ""
    effective_mode = req.mode
    capacity_seen = False
    for attempt, current_mode in enumerate(plan):
        status_code = 0
        body = ""
        retry_after: str | None = None
        try:
            if current_mode == "edit":
                if img_bytes_in is None:
                    raise RuntimeError("edit mode reference image missing")
                with ai_http_client(timeout=120) as client:
                    r = client.post(
                        _image_endpoint("/images/edits"),
                        headers={
                            "Authorization": f"Bearer {IMAGE_API_KEY}",
                        },
                        data={"model": model, "prompt": full_prompt, "n": "1"},
                        files={"image": ("reference.jpg", img_bytes_in, "image/jpeg")},
                    )
                    r.raise_for_status()
                    payload = r.json()
            else:
                # text-only
                resp = image_llm().images.generate(model=model, prompt=full_prompt, n=1)
                payload = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}

            # Only ever keep a dict in resp_data: the code after the loop calls
            # .get() on it, so a list/str body must count as a failed attempt.
            if not isinstance(payload, dict):
                raise RuntimeError(f"unexpected response type {type(payload).__name__}")
            resp_data = payload

            if resp_data.get("data"):
                effective_mode = current_mode
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            status_code = e.response.status_code
            retry_after = e.response.headers.get("retry-after")
            is_capacity = _image_is_capacity_error(status_code, body)
            capacity_seen = capacity_seen or is_capacity
            transient = (
                status_code == 429
                or status_code >= 500
                or "incomplete_generation" in body
                or "rate_limit" in body
                or "timeout" in body.lower()
                or is_capacity
            )
            last_err = f"HTTP {status_code}: {body[:200]}"
            if not transient:
                raise HTTPException(500, f"image gen HTTP {status_code}: {body[:300]}") from e
        except Exception as e:
            # Any other failure (network, SDK, malformed payload) is retried.
            last_err = f"{type(e).__name__}: {e}"

        if attempt < len(plan) - 1:
            next_mode = plan[attempt + 1]
            tag = f"fallback → {next_mode}" if next_mode != current_mode else f"retry {attempt + 1}/{len(plan)}"
            print(f"[image gen {tag}] {last_err}", flush=True)
            _time.sleep(_image_retry_delay(attempt, status_code, body, retry_after))

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise HTTPException(503 if capacity_seen else 500, _image_failure_message("image gen", len(plan), last_err, capacity_seen))

    item = data_arr[0]
    b64 = item.get("b64_json")
    result_url = item.get("url")
    if not b64 and not result_url:
        raise HTTPException(500, "image gen returned no b64_json")
    try:
        if b64:
            out_bytes = b64lib.b64decode(b64)
        else:
            with ai_http_client(timeout=120) as client:
                image_resp = client.get(result_url)
                image_resp.raise_for_status()
                out_bytes = image_resp.content
    except (ValueError, httpx.HTTPError) as e:
        # binascii.Error is a ValueError; httpx.HTTPError covers status and
        # transport failures. Report them as a regular API error.
        raise HTTPException(500, f"image gen result unreadable: {type(e).__name__}: {e}") from e

    # Save locally in the job's gen folder as {idx:03d}_{gen_id}.jpg.
    gen_dir = job_dir(job_id) / "gen"
    gen_dir.mkdir(parents=True, exist_ok=True)
    out_path = gen_dir / f"{idx:03d}_{gen_id}.jpg"
    out_path.write_bytes(out_bytes)

    new_gen = GeneratedImage(
        id=gen_id,
        prompt=full_prompt,
        model=model,
        mode=effective_mode,
        url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg",
        selected=False,
        created_at=_time.time(),
    )

    # Write back into job.frames.
    for f in job.frames:
        if f.index == idx:
            f.generated_images = f.generated_images + [new_gen]
    update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve generated image ``gen_id`` of frame ``idx`` as JPEG; 404 when the file is absent."""
    file_name = f"{idx:03d}_{gen_id}.jpg"
    image_file = job_dir(job_id) / "gen" / file_name
    if image_file.exists():
        return FileResponse(image_file, media_type="image/jpeg")
    raise HTTPException(404, "generated image not found")
|
||
|
||
|
||
class SelectGenReq(BaseModel):
    # Request body for the generated-image select endpoint.

    # Desired selection state of the targeted generated image.
    selected: bool
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job)
def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job:
    """Set the selection flag of one generated image; a frame has at most one selected.

    Args:
        job_id: Identifier of the job.
        idx: Index of the key frame that owns the generated image.
        gen_id: Identifier of the generated image to (de)select.
        req: Carries the desired ``selected`` state.

    Returns:
        The job with updated selection flags.

    Raises:
        HTTPException: 404 if the job, the frame, or the generated image is unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    # Validate before mutating: an unknown gen_id used to silently clear the
    # frame's current selection and still answer 200.
    if not any(g.id == gen_id for g in frame.generated_images):
        raise HTTPException(404, "generated image not found")
    for g in frame.generated_images:
        # Single choice: every image other than the target is deselected.
        g.selected = req.selected if g.id == gen_id else False
    update(job, frames=job.frames)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job)
def describe_frame(job_id: str, idx: int) -> Job:
    """Ask the vision model to describe a key frame and store the structured result.

    The model is asked for a strict JSON object. Up to three attempts are made;
    empty content, undecodable JSON, non-object JSON and call errors all count as
    a failed attempt.

    Args:
        job_id: Identifier of the job.
        idx: Index of the key frame to describe.

    Returns:
        The job with ``frame.description`` set to the parsed JSON object.

    Raises:
        HTTPException: 404 for an unknown job/frame or missing frame file; 500
            when no attempt produced a usable JSON object.
    """
    import base64 as b64lib
    import re as _re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame file not found")

    img_b64 = b64lib.b64encode(p.read_bytes()).decode("ascii")

    prompt = (
        "请识别这张图,输出严格 JSON(不要 markdown 不要解释,不要思考):\n"
        '{\n'
        ' "scene": "一句话描述场景",\n'
        ' "objects": [{"name": "物体名(中文)", "position": "在画面哪里", "color": "颜色", "extract_prompt": "用于提取该元素的英文 prompt"}],\n'
        ' "style": "整体风格 / 打光 / 色调(一句话)",\n'
        ' "suggested_prompt": "适合用作下游生图的完整英文 prompt",\n'
        ' "transparent_human_assessment": {"transparent_body_score": 0, "skeleton_visible_score": 0, "human_prominence_score": 0, "clarity_score": 0, "commercial_style_score": 0, "product_usefulness_score": 0, "qualified": false, "reject_reason": "如果不合格说明原因"}\n'
        '}\n'
        "要求:objects 列出 3-8 个画面里**可独立提取**的主要元素,extract_prompt 用于后续 image edit 模型。"
        "transparent_human_assessment 按透明骨架人标准评分:"
        + TRANSPARENT_HUMAN_POSITIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_QUALIFIED_STANDARD
    )

    last_err = ""
    data = None
    for attempt in range(3):
        try:
            resp = llm().chat.completions.create(
                model=VISION_MODEL,
                messages=[{"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                ]}],
                response_format={"type": "json_object"},
                temperature=0.3,
                max_tokens=3000,
            )
            content = (resp.choices[0].message.content or "").strip()
            if not content:
                # Thinking models may return empty content; try to dig a JSON
                # object out of reasoning_content instead.
                rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
                m = _re.search(r"\{[\s\S]*\}", rc)
                content = m.group(0) if m else ""
            # Strip a ```json ... ``` wrapper.
            content = _re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()
            if not content:
                last_err = f"empty content (attempt {attempt + 1})"
                continue
            parsed = json.loads(content)
            # The description must be a JSON object; a list/str/number decodes
            # fine but does not have the requested structure, so retry.
            if not isinstance(parsed, dict):
                last_err = f"non-object JSON (attempt {attempt + 1}): {type(parsed).__name__}"
                print(f"[vision retry] {last_err}", flush=True)
                continue
            data = parsed
            break
        except json.JSONDecodeError as e:
            last_err = f"json decode (attempt {attempt + 1}): {e} · raw[:200]={content[:200]}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        except Exception as e:
            last_err = f"vision call (attempt {attempt + 1}): {e}"
            print(f"[vision retry] {last_err}", flush=True)
            continue

    if data is None:
        raise HTTPException(500, last_err or "vision failed after 3 retries")

    # Write back into the job.
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.description = data
        new_frames.append(f)
    update(job, frames=new_frames, message=f"识别完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
# ---------- 清洗水印 / 元素提取(关键帧二阶段加工) ----------
|
||
|
||
class CleanupReq(BaseModel):
    # Request body for the frame cleanup endpoint.

    # Rectangles in relative 0-1 coordinates, each {"x", "y", "w", "h"}, that limit
    # where cleanup applies; empty or None means clean the whole image.
    regions: list[dict] | None = None
|
||
|
||
|
||
def _region_to_phrase(r: dict) -> str:
|
||
"""把相对坐标矩形转成简短方位描述给 prompt 用(避免百分号 / 括号触发模型异常)"""
|
||
x = max(0.0, min(1.0, float(r.get("x", 0))))
|
||
y = max(0.0, min(1.0, float(r.get("y", 0))))
|
||
w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
|
||
h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
|
||
if w <= 0 or h <= 0:
|
||
return ""
|
||
cx, cy = x + w / 2, y + h / 2
|
||
hpos = "left" if cx < 0.4 else "right" if cx > 0.6 else "middle"
|
||
vpos = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
|
||
if hpos == "middle" and vpos == "middle":
|
||
return "center"
|
||
if hpos == "middle":
|
||
return vpos
|
||
if vpos == "middle":
|
||
return hpos
|
||
return f"{vpos} {hpos}"
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
    """Run a gpt-image-2 image edit that produces a clean copy of a key frame.

    Targets watermarks, @usernames, captions and platform logos. The clean version
    is written to the job's ``cleaned`` folder under the frame's zero-padded name
    and exposed through ``frame.cleaned_url``. Optional ``regions`` restrict the
    cleanup to the described areas.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Describe each usable region in words, dropping duplicates but keeping order.
    described = (_region_to_phrase(r) for r in req.regions) if req and req.regions else ()
    region_phrases: list[str] = list(dict.fromkeys(phrase for phrase in described if phrase))

    # The prompt is worded as "recreate a copy" rather than "erase / remove only X"
    # to keep Gemini off its mask/inpainting function-call path (observed to hit
    # incomplete_generation 100% of the time on the SKG gateway).
    if not region_phrases:
        prompt = (
            "Recreate this image as a clean version without watermarks, captions, "
            "hashtags, usernames, or platform logos. Keep the composition and style."
        )
    else:
        zones = (
            f"the {region_phrases[0]} area"
            if len(region_phrases) == 1
            else ", ".join(region_phrases) + " areas"
        )
        prompt = (
            f"Recreate this image as a clean version: remove the text and graphics in {zones}, "
            "keep the rest of the scene identical."
        )

    models = [GPT_IMAGE_MODEL]
    try:
        img_bytes, _mode = _image_edit_call(
            frame_path, prompt, models=models, fallback_text=False, max_attempts=3,
        )
    except RuntimeError as e:
        raise HTTPException(500, f"cleanup failed: {e}")

    cleaned_dir = job_dir(job_id) / "cleaned"
    cleaned_dir.mkdir(parents=True, exist_ok=True)
    (cleaned_dir / f"{idx:03d}.jpg").write_bytes(img_bytes)

    for f in job.frames:
        if f.index != idx:
            continue
        # The timestamp query string busts client caches of the previous version.
        f.cleaned_url = f"/jobs/{job_id}/frames/{idx}/cleaned.jpg?t={int(_time.time())}"
        f.cleaned_applied = False  # a fresh cleanup resets the "applied" state
    update(job, frames=list(job.frames), message=f"清洗完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve the pending cleaned version of frame ``idx`` as JPEG; 404 when the file is absent."""
    file_name = f"{idx:03d}.jpg"
    cleaned_file = job_dir(job_id) / "cleaned" / file_name
    if cleaned_file.exists():
        return FileResponse(cleaned_file, media_type="image/jpeg")
    raise HTTPException(404, "cleaned frame not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def discard_cleaned(job_id: str, idx: int) -> Job:
    """Discard the pending cleaned version of a frame (an already applied one is untouched)."""
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    pending_file = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if pending_file.exists():
        # Best effort: a file that cannot be removed does not fail the request.
        try:
            pending_file.unlink()
        except OSError:
            pass

    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
    update(job, frames=list(job.frames), message=f"丢弃清洗版 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup/apply", response_model=Job)
def apply_cleaned(job_id: str, idx: int) -> Job:
    """Replace the original key frame with its cleaned version.

    The frame file in ``frames`` is physically overwritten by the one in
    ``cleaned``. The original is backed up to ``orig`` on the first replacement
    only; later replacements keep that first backup. Afterwards
    ``frame.cleaned_url`` is cleared because no pending cleaned version remains.
    """
    import shutil as _shutil

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    file_name = f"{idx:03d}.jpg"
    cleaned_path = job_dir(job_id) / "cleaned" / file_name
    if not cleaned_path.exists():
        raise HTTPException(404, "no cleaned version to apply")
    frame_path = job_dir(job_id) / "frames" / file_name

    # First replacement: keep a copy of the original in the orig folder.
    backup_dir = job_dir(job_id) / "orig"
    backup_dir.mkdir(parents=True, exist_ok=True)
    backup_path = backup_dir / file_name
    needs_backup = not backup_path.exists() and frame_path.exists()
    if needs_backup:
        _shutil.copy2(frame_path, backup_path)

    # Overwrite the frame with the cleaned image.
    _shutil.copy2(cleaned_path, frame_path)
    # The cleaned file is now "applied" and no longer a separate candidate.
    try:
        cleaned_path.unlink()
    except OSError:
        pass

    for f in job.frames:
        if f.index != idx:
            continue
        f.cleaned_url = None
        f.cleaned_applied = True
    update(job, frames=list(job.frames), message=f"已替换分镜 {idx + 1} 为清洗版")
    return job
|
||
|
||
|
||
class AddElementReq(BaseModel):
    # Request body for adding an element to a frame.

    # Chinese label; required (the endpoint rejects a blank value).
    name_zh: str
    # English label; when blank the endpoint tries to translate name_zh.
    name_en: str = ""
    # Free-text position of the element.
    position: str = ""
    # How the element was created.
    source: Literal["auto", "manual", "region"] = "manual"
    # Optional rectangle for the element, stored on the element as given.
    region: dict | None = None
|
||
|
||
|
||
class UpdateElementReq(BaseModel):
    # Partial update of an element; a field left as None is not changed.

    name_zh: str | None = None
    name_en: str | None = None
    position: str | None = None
|
||
|
||
|
||
class GenerateSceneAssetReq(BaseModel):
    # Request body for the scene-asset endpoint (scene plate or first/last frame).

    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    # How the scene relates to the source frame.
    scene_mode: SceneMode = "remove_subject"
    # Visual style applied to the scene.
    scene_style: SceneStyle = "source"
    # Which kind of asset to produce.
    asset_role: SceneAssetRole = "scene"
    # Optional free-text scene direction from the user.
    prompt: str = ""
    # Frame indices used as references; None means only the target frame.
    source_frame_indices: list[int] | None = None
    # References to generated subject images.
    subject_images: list[dict] = Field(default_factory=list)
    # References to product images.
    product_images: list[dict] = Field(default_factory=list)
|
||
|
||
|
||
class GenerateSubjectAssetsReq(BaseModel):
    # Request options for generating subject assets.

    subject_kind: SubjectKind = "object"
    background: AssetBackground = "white"
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    # Frame indices used as references; None leaves the choice to the endpoint.
    source_frame_indices: list[int] | None = None
    # Requested views; None leaves the choice to the endpoint.
    views: list[str] | None = None
    character_id: str = ""
    subject_template_id: str = ""
    subject_style: Literal["transparent_human", "source_actor"] = "transparent_human"
    reconstruction_mode: Literal["same", "similar"] = "same"
    # Optional free-text direction from the user.
    prompt: str = ""
    replace_views: bool = False
|
||
|
||
|
||
class UpdateProductRefsReq(BaseModel):
    # Request body for PUT /jobs/{job_id}/product-refs.

    # Product reference entries; the endpoint keeps only dicts that carry a dict "ref".
    items: list[dict] = Field(default_factory=list)
|
||
|
||
|
||
@app.put("/jobs/{job_id}/product-refs", response_model=Job)
def update_product_refs(job_id: str, req: UpdateProductRefsReq) -> Job:
    """Replace the job's product references with the valid entries of the request.

    Only the first 300 submitted items are considered, and an item is kept only
    when it is a dict whose ``ref`` value is itself a dict.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    accepted: list[dict] = [
        entry
        for entry in req.items[:300]
        if isinstance(entry, dict) and isinstance(entry.get("ref"), dict)
    ]
    update(job, product_refs=accepted)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements", response_model=Job)
def add_element(job_id: str, idx: int, req: AddElementReq) -> Job:
    """Append an element to a frame, translating zh to en when ``name_en`` is missing.

    Translation is attempted only when an LLM API key is configured; any failure
    is logged and leaves the English label empty.
    """
    import time as _time
    import re as _re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    name_zh = req.name_zh.strip()
    if not name_zh:
        raise HTTPException(400, "name_zh required")

    name_en = req.name_en.strip()
    needs_translation = not name_en and LLM_API_KEY
    if needs_translation:
        try:
            instruction = (
                "Translate the following text into concise English, suitable as an element label "
                "in an image-generation prompt. Output only the translation — no quotes, no punctuation, "
                f"no explanation.\n\nInput: {name_zh}"
            )
            resp = llm().chat.completions.create(
                model=TRANSLATE_MODEL,
                messages=[{"role": "user", "content": instruction}],
                temperature=0.2,
                max_tokens=200,
            )
            message = resp.choices[0].message
            translated = (message.content or "").strip()
            if not translated:
                # Fall back to the last line of reasoning_content when content is empty.
                reasoning = getattr(message, "reasoning_content", "") or ""
                if reasoning:
                    translated = reasoning.strip().splitlines()[-1].strip()
            # Drop wrapping quote characters (ASCII and CJK corner brackets).
            name_en = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", translated).strip()
        except Exception as e:
            print(f"[add_element translate failed] {e}", flush=True)
            name_en = ""

    element = KeyElement(
        id=uuid.uuid4().hex[:8],
        name_zh=name_zh,
        name_en=name_en,
        position=req.position.strip(),
        source=req.source,
        region=req.region,
        created_at=_time.time(),
    )
    for f in job.frames:
        if f.index == idx:
            f.elements = f.elements + [element]
    update(job, frames=list(job.frames), message=f"加入元素 · 分镜 {idx + 1} · {name_zh}")
    return job
|
||
|
||
|
||
@app.patch("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def update_element(job_id: str, idx: int, element_id: str, req: UpdateElementReq) -> Job:
    """Edit an element's labels / English hint in place.

    Lets the user correct an inaccurate extraction without recreating the element.
    Only fields that are not None in the request are applied; a blank ``name_zh``
    is rejected with 400.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    changed_name = ""
    found = False
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for element in f.elements:
                if element.id != element_id:
                    continue
                found = True
                if req.name_zh is not None:
                    cleaned_zh = req.name_zh.strip()
                    if not cleaned_zh:
                        raise HTTPException(400, "name_zh required")
                    element.name_zh = cleaned_zh
                    changed_name = cleaned_zh
                if req.name_en is not None:
                    element.name_en = req.name_en.strip()
                if req.position is not None:
                    element.position = req.position.strip()
        new_frames.append(f)

    if not found:
        raise HTTPException(404, "element not found")
    shown = changed_name or element_id
    update(job, frames=new_frames, message=f"更新元素 · 分镜 {idx + 1} · {shown}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def delete_element(job_id: str, idx: int, element_id: str) -> Job:
    """Remove an element from a frame together with its extracted image files.

    Responds 404 when the job is unknown or no element with that id was removed.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    removed = False
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            remaining = [e for e in f.elements if e.id != element_id]
            removed = len(remaining) < len(f.elements)
            f.elements = remaining
            if removed:
                # Also delete extracted images, including multi-version ones.
                elements_dir = job_dir(job_id) / "elements"
                if elements_dir.exists():
                    stem = f"{idx:03d}_{element_id}"
                    for pattern in (f"{stem}.jpg", f"{stem}.png", f"{stem}_*.jpg"):
                        for image_file in elements_dir.glob(pattern):
                            try:
                                image_file.unlink()
                            except OSError:
                                pass
        new_frames.append(f)

    if not removed:
        raise HTTPException(404, "element not found")
    update(job, frames=new_frames, message=f"删除元素 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/scene-asset", response_model=Job)
def generate_scene_asset(job_id: str, idx: int, req: GenerateSceneAssetReq) -> Job:
    """Generate one asset image for a key frame.

    ``scene`` produces a background plate with the subject removed;
    ``first_frame`` / ``last_frame`` produce a video endpoint frame from text
    direction, where reference images only inform the shared character identity.
    The result is normalised into the job's ``assets`` folder and appended to
    ``frame.scene_assets``.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    _find_frame(job, idx)  # result unused; presumably validates the frame — TODO confirm
    src = _source_frame_path(job_id, idx)
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    is_scene = req.asset_role == "scene"

    # Reference frame indices: integers only, de-duplicated in order, at most 8.
    requested = req.source_frame_indices or [idx]
    source_indices = [int(x) for x in requested if isinstance(x, int) or str(x).isdigit()]
    if not source_indices:
        source_indices = [idx]
    source_indices = list(dict.fromkeys(source_indices))[:8]

    model_src = src
    sheet_tmp: Path | None = None
    asset_sheet_tmp: Path | None = None
    if len(source_indices) > 1:
        sheet_tmp = job_dir(job_id) / "tmp" / f"scene_refs_{idx:03d}_{uuid.uuid4().hex[:6]}.jpg"
        sheet = _make_reference_contact_sheet(job_id, source_indices, sheet_tmp)
        if sheet:
            model_src = sheet

    def _existing_refs(refs: list[dict]) -> list:
        # Resolve storyboard references and keep those whose file exists.
        resolved = (storyboard_ref_path(job_id, ref) for ref in refs)
        return [path for path in resolved if path and path.exists()]

    subject_ref_paths = _existing_refs(req.subject_images[:8])
    product_ref_paths = _existing_refs(req.product_images[:6])
    asset_ref_paths = [*subject_ref_paths, *product_ref_paths]
    use_asset_sheet = not is_scene and bool(asset_ref_paths)
    if use_asset_sheet:
        asset_sheet_tmp = job_dir(job_id) / "tmp" / f"endpoint_refs_{idx:03d}_{uuid.uuid4().hex[:6]}.jpg"
        asset_sheet = _make_paths_contact_sheet(asset_ref_paths, asset_sheet_tmp, max_items=10)
        if asset_sheet:
            model_src = asset_sheet

    # Subjects to remove: elements that already have subject assets, otherwise
    # the first named elements; de-duplicated and capped at 3.
    all_elements = [e for ref_frame in job.frames for e in (ref_frame.elements or [])]
    confirmed_subjects = [(e.name_en or e.name_zh).strip() for e in all_elements if (e.subject_assets or [])]
    if not confirmed_subjects:
        confirmed_subjects = [
            (e.name_en or e.name_zh).strip()
            for e in all_elements
            if (e.name_en or e.name_zh).strip()
        ][:3]
    confirmed_subjects = list(dict.fromkeys([x for x in confirmed_subjects if x]))[:3]

    if confirmed_subjects:
        subject_clause = "Confirmed foreground subject(s) to remove: " + ", ".join(confirmed_subjects) + ". "
    else:
        subject_clause = "Remove the main foreground subject from the frame if present. "

    if subject_ref_paths:
        identity_clause = (
            f"Use the generated subject asset references as the primary character identity lock ({len(subject_ref_paths)} image(s)); preserve the subject type, material, proportions, style, age/gender presentation, pose vocabulary, and ad-friendly identity exactly as shown in those selected views. "
        )
    else:
        identity_clause = (
            "No generated subject reference was provided for this endpoint. Do not add a main character unless the user scene direction explicitly asks for one. "
        )

    mode_clauses = {
        "remove_subject": (
            "Keep the original environment, camera angle, perspective, composition, lighting direction, color mood, and spatial layout. "
            "The result should be an empty clean scene/background plate with the subject removed and the occluded background reconstructed."
        ),
        "similar": (
            "Create a similar but not identical scene/background plate: keep the same camera angle, rough spatial layout, lighting direction, and usage context, "
            "but vary props, surface details, textures, and small environmental details so it is not a duplicate of the source."
        ),
        "style": (
            "Create a scene/background plate with the same camera angle and spatial layout, but reinterpret the environment in the selected visual style. "
            "Keep it believable and useful for image-to-video generation."
        ),
    }
    mode_clause = mode_clauses[req.scene_mode]

    style_clauses = {
        "source": "Follow the original source style.",
        "premium_product": "Use a premium product-advertising style: polished, high-end, clean commercial lighting, refined materials.",
        "clean_studio": "Use a clean studio style: simple surfaces, controlled lighting, minimal distractions.",
        "warm_lifestyle": "Use a warm lifestyle style: realistic lived-in details, soft natural light, approachable atmosphere.",
        "cinematic": "Use a cinematic style: dramatic but natural lighting, richer depth, filmic contrast, not fantasy.",
    }
    style_clause = style_clauses[req.scene_style]

    user_prompt = req.prompt.strip()
    # User direction is capped at 1200 characters.
    user_prompt_clause = "User scene direction: " + user_prompt[:1200] + " " if user_prompt else ""

    if use_asset_sheet:
        reference_clause = (
            f"Use the provided asset contact sheet as the primary visual reference: {len(subject_ref_paths)} generated subject image(s) and {len(product_ref_paths)} SKG product image(s). "
            "Do not use the original keyframe as the first/last-frame truth; it is only a storage anchor for this row. "
        )
    elif len(source_indices) > 1:
        reference_clause = (
            f"Use the selected reference frame contact sheet as visual evidence for location, composition, lighting, materials, and atmosphere. Reference frame indices: {', '.join(str(i + 1) for i in source_indices)}. "
        )
    else:
        reference_clause = "Use the provided frame as the primary visual reference. "

    if product_ref_paths:
        product_asset_clause = (
            "Use the provided SKG product references as the rigid product truth when the user prompt asks for product presence: a white U-shaped neck-and-shoulder wearable massage device worn around the neck/shoulders, not headphones, a collar pillow, skincare, food, or a medical prop. Keep product scale believable, preserve left/right asymmetry, side thickness, inner contact pads, buttons, white material, and real wearable placement. "
        )
    else:
        product_asset_clause = "Do not invent a random product. Only include an SKG product if the user prompt explicitly asks for it. "

    if subject_ref_paths:
        subject_asset_clause = (
            TRANSPARENT_HUMAN_POSITIVE_PROMPT + " "
            + TRANSPARENT_HUMAN_NEGATIVE_PROMPT + " "
            + "If the selected subject references are transparent humanoid assets, keep the same friendly transparent or translucent human character: glass/acrylic/vinyl-like transparent outer body, visible clean white skeleton inside, clean commercial wellness style, non-horror. "
            + "If the selected subject references are normal actor assets, keep them as a normal believable commercial actor and do not convert them into a transparent skeleton. "
            + "Use the selected subject views only to understand identity, proportions, material, pose vocabulary, camera language, and lighting; do not copy watermarks, subtitles, platform UI, logos, or accidental artifacts. "
        )
    else:
        subject_asset_clause = (
            "No main character should be generated unless the user scene direction explicitly requires one; product-only and environment-only frames should stay product-only or scene-only. "
        )

    if is_scene:
        prompt = "".join([
            "Create one clean high-definition scene/background reference image from this frame. ",
            subject_clause,
            "Do not include the removed subject, duplicate people, animals, products, text, watermark, platform UI, captions, usernames, hashtags, logos, or overlay graphics. ",
            reference_clause,
            user_prompt_clause,
            mode_clause + " ",
            style_clause + " ",
            "Enhance clarity and texture while avoiding over-smoothing, warped geometry, or changing important perspective details. ",
            "Do not create multiple views. Do not isolate objects.",
        ])
    else:
        if req.asset_role == "first_frame":
            role_clause = "This is the FIRST frame for an image-to-video clip: create a clear beginning pose and composition. "
        else:
            role_clause = "This is the LAST frame for an image-to-video clip: create a clear ending pose that can naturally follow the first frame, not a duplicate. "
        prompt = "".join([
            "Create one premium 9:16 high-definition video endpoint frame from text direction. ",
            role_clause,
            identity_clause,
            reference_clause,
            user_prompt_clause,
            style_clause + " ",
            product_asset_clause,
            subject_asset_clause,
            "Do not create a plain background plate. Do not include SKG product unless the user prompt explicitly asks for it. ",
            "The output should be ready as a first/last frame for Seedance video generation, with stable composition, believable perspective, clear subject, no text, no watermark, no gore, no medical surgery imagery.",
        ])

    models = [GPT_IMAGE_MODEL]
    try:
        if is_scene:
            img_bytes, _mode = _image_edit_call(model_src, prompt, models=models, fallback_text=False, max_attempts=3, max_side=1280)
        elif asset_ref_paths:
            img_bytes, _mode = _image_edit_call(model_src, prompt, models=models, fallback_text=False, max_attempts=3, max_side=1600)
        else:
            img_bytes, _mode = _image_text_call(prompt, models=models, max_attempts=3)
    except RuntimeError as e:
        raise HTTPException(500, f"{req.asset_role} asset failed: {e}")
    finally:
        # Temporary contact sheets are removed whether or not the call succeeded.
        for tmp_file in (sheet_tmp, asset_sheet_tmp):
            if tmp_file and tmp_file.exists():
                try:
                    tmp_file.unlink()
                except OSError:
                    pass

    asset_id = f"scene_{idx:03d}_{uuid.uuid4().hex[:8]}"
    out_path = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
    width, height = _normalize_asset_image(img_bytes, out_path, src, req.size, "white", square=False)
    report = _image_quality_report(out_path)

    if is_scene:
        asset_label = "场景图"
    elif req.asset_role == "first_frame":
        asset_label = "首帧"
    else:
        asset_label = "尾帧"

    scene = SceneAsset(
        id=asset_id,
        label=f"分镜 {idx + 1} {asset_label}",
        url=_asset_url(job_id, asset_id),
        width=width,
        height=height,
        quality=req.quality,
        size=req.size,
        scene_mode=req.scene_mode,
        scene_style=req.scene_style,
        asset_role=req.asset_role,
        quality_report=report,
        created_at=_time.time(),
    )

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.quality_report = _image_quality_report(src)
            f.scene_assets = (f.scene_assets or []) + [scene]
        new_frames.append(f)
    update(job, frames=new_frames, message=f"{asset_label}生成完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
    """Extract one element of a frame with the image model; each call adds a new image.

    The image model is asked to render a complete, sharp standalone picture of
    the element, reconstructing parts that are cropped or occluded in the
    source frame. When the element carries a ``region``, the region grown to
    1.6x its size is cropped out first and sent as the focus image.

    Args:
        job_id: Key of the job in ``JOBS``.
        idx: Index of the frame that owns the element.
        element_id: Id of the element to extract.

    Returns:
        The job, after the new cutout id has been appended to the element.

    Raises:
        HTTPException: 404 when the job, frame, element or source frame file
            is missing; 500 when the image call raises ``RuntimeError``.
    """
    from PIL import Image as _PILImage
    import tempfile as _tempfile

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    # Prefer the cleaned frame when one exists, otherwise use the raw frame.
    cleaned_path = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    src = cleaned_path if cleaned_path.exists() else job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    out_dir = job_dir(job_id) / "elements"
    out_dir.mkdir(parents=True, exist_ok=True)
    new_cutout_id = uuid.uuid4().hex[:8]
    out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg"

    # Region elements: crop the padded region so the model focuses on it.
    tmp_focus: Path | None = None
    model_src = src
    try:
        if el.region:
            try:
                # Context manager closes the source file handle once decoded.
                with _PILImage.open(src) as opened:
                    im = opened.convert("RGB")
                W, H = im.size
                r = el.region
                x = max(0.0, min(1.0, float(r.get("x", 0))))
                y = max(0.0, min(1.0, float(r.get("y", 0))))
                w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
                h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
                cx, cy = x + w / 2, y + h / 2
                # Grow the box to 1.6x (30% extra per side) so context around
                # the element survives and can hint the reconstruction.
                ew, eh = w * 1.6, h * 1.6
                x0 = max(0.0, cx - ew / 2)
                y0 = max(0.0, cy - eh / 2)
                x1 = min(1.0, cx + ew / 2)
                y1 = min(1.0, cy + eh / 2)
                left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
                if right - left > 8 and bottom - top > 8:
                    cropped = im.crop((left, top, right, bottom))
                    tmp = _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
                    # Close our handle before PIL reopens the path by name,
                    # and record the path first so the finally block removes
                    # it even when the save fails.
                    tmp.close()
                    tmp_focus = Path(tmp.name)
                    cropped.save(tmp_focus, format="JPEG", quality=92)
                    model_src = tmp_focus
            except Exception as e:
                # Best effort: any crop problem falls back to the full frame.
                model_src = src
                print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True)

        target = (el.name_en or el.name_zh).strip()
        prompt = (
            f"Identify the {target} in this image. "
            f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. "
            f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), "
            "intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. "
            "Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. "
            "Preserve the element's original color palette, style, lighting character, and proportions. "
            "Output must be a clean, high-quality asset image suitable for downstream composition."
        )
        models = [GPT_IMAGE_MODEL]
        try:
            img_bytes, _mode = _image_edit_call(
                model_src, prompt, models=models, fallback_text=False, max_attempts=3,
            )
        except RuntimeError as e:
            raise HTTPException(500, f"extract failed: {e}") from e
    finally:
        if tmp_focus and tmp_focus.exists():
            try:
                tmp_focus.unlink()
            except OSError:
                pass
    out_path.write_bytes(img_bytes)

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for e in f.elements:
                if e.id == element_id:
                    e.cutouts = (e.cutouts or []) + [new_cutout_id]
                    # The first cutout also fills the single-id field.
                    if not e.cutout_id:
                        e.cutout_id = new_cutout_id
        new_frames.append(f)
    update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/subject-assets", response_model=Job)
def generate_subject_assets(job_id: str, idx: int, element_id: str, req: GenerateSubjectAssetsReq) -> Job:
    """Generate a multi-view asset pack for one subject element.

    One image is requested per view returned by ``_subject_view_labels``. In
    ``similar`` reconstruction mode the images come from a text-only call
    driven by written briefs; otherwise reference images (saved template or
    built-in character images plus source frames, at most 10) are sent to the
    edit call.

    Args:
        job_id: Key of the job in ``JOBS``.
        idx: Index of the frame that owns the element.
        element_id: Id of the subject element.
        req: Generation options (views, background, size, mode, template ids).

    Returns:
        The job, after the generated assets have been attached to the element.

    Raises:
        HTTPException: 404 when the job or element is missing; the status from
            ``_image_error_status`` when an image call raises ``RuntimeError``.
            On any failure the files already written for this request are
            deleted, because they would otherwise never be attached to the job.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = _find_frame(job, idx)
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    # Numeric indices only, current frame always included, de-duplicated
    # in order and capped at 12.
    source_indices = [int(x) for x in (req.source_frame_indices or [idx]) if isinstance(x, int) or str(x).isdigit()]
    if idx not in source_indices:
        source_indices = [idx] + source_indices
    source_indices = list(dict.fromkeys(source_indices))[:12]

    similar_mode = req.reconstruction_mode == "similar"
    character_reference_paths: list[Path] = []
    template_brief_clause = ""
    character_label = ""
    subject_template_id = (req.subject_template_id or "").strip()
    character_id = (req.character_id or "").strip()
    if subject_template_id:
        template = find_subject_template_item(subject_template_id)
        character_label = template.name
        template_paths = [subject_template_image_file(image.filename) for image in template.images[:10]]
        character_reference_paths.extend(template_paths)
        brief = template.prompt_brief.strip() or template.note.strip() or template.description.strip()
        if similar_mode and not brief:
            brief = _describe_subject_template_from_images(template.name, template.subject_style, template_paths, template.note)
        template_brief_clause = (
            f"Reference character brief from saved database template '{template.name}': {brief}. "
            "Use this as a high-quality creative direction and identity bible only; do not copy a face, exact pose, pixels, file artifacts, labels, or accidental defects. "
            "Create a new innovative variation that keeps the same broad subject type, transparent wellness character language, camera readability, shoulder/neck product compatibility, and commercial role. "
            if brief else
            f"Selected reusable subject template from database: {template.name}. Create a new innovative variation, not a duplicate. "
        )
    elif character_id:
        character = find_character_library_item(character_id)
        character_label = character.name
        character_reference_paths.extend(character_library_file(image.filename) for image in character.images[:7])
        brief = character.prompt_brief.strip() or character.description.strip()
        template_brief_clause = (
            f"Reference character brief from built-in creative character '{character.name}': {brief}. "
            "Use this planned character brief as a high-quality creative direction and anatomy/style bible only; "
            "do not copy the exact face, exact pose, exact silhouette, pixels, or make a duplicate. "
            "Create a new innovative variation that keeps the same broad role, transparent wellness character language, camera readability, and shoulder/neck product compatibility. "
        )

    frame_reference_paths = [p for p in (_source_frame_path(job_id, i) for i in source_indices) if p.exists()]
    source_subject_brief = _describe_source_subject(job_id, source_indices) if similar_mode else ""
    source_subject_clause = (
        f"Source video role brief from selected keyframes: {source_subject_brief}. "
        "Use this brief to preserve role category, creator-ad energy, camera readability, and broad styling, while creating a new non-identical subject. "
        if source_subject_brief else
        "Source video role brief unavailable; create a new non-identical ad subject guided by the user direction, template brief, and requested view. "
    )

    try:
        with Image.open(_source_frame_path(job_id, idx)) as src_im:
            source_is_portrait = src_im.height > src_im.width
    except Exception:
        # An unreadable source frame only affects the canvas hint.
        source_is_portrait = False
    canvas_clause = (
        "Canvas and aspect ratio: the reference video frame is vertical, so output a vertical portrait 9:16-style image, not a square canvas and not a horizontal layout. "
        if source_is_portrait
        else "Canvas and aspect ratio: keep a single clean reference-image canvas with the same broad orientation as the source evidence. "
    )

    target = (el.name_en or el.name_zh).strip()
    bg_phrase = "pure white" if req.background == "white" else "pure black"
    similar_actor = req.subject_kind == "living" and req.subject_style == "source_actor" and req.reconstruction_mode == "similar"
    kind_phrase = "human actor or living character" if req.subject_kind == "living" else "object or product-like subject"
    transparent_character_clause = (
        TRANSPARENT_HUMAN_POSITIVE_PROMPT
        + " The generated living character must be a friendly transparent humanoid with transparent or translucent outer body and clean white skeleton visible inside the same body. "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT
        + " Do not render a normal human, ordinary skeleton-only character, horror skeleton, medical anatomy, organs, veins, blood, corpse, zombie, hospital, surgery, or autopsy visual. "
        if req.subject_kind == "living" and req.subject_style == "transparent_human"
        else ""
    )
    actor_style_clause = (
        "Generate a believable normal commercial video actor, not a transparent or skeleton character. "
        "Use the text briefs to understand the source video's casting direction, age range, gender presentation, body proportion, wardrobe category, gesture vocabulary, framing, energy, lighting, and creator-ad style. "
        "Do not recreate the exact person's face, biometric identity, unique likeness, tattoos, scars, logos, watermarks, captions, or platform UI. "
        "The output must be a newly designed similar actor that could play the same role in a new ad, with consistent identity across all views. "
        if similar_actor
        else ""
    )
    identity_clause = (
        "Create a similar but non-identical original subject: match the performance role, silhouette category, styling direction, camera-readability, and commercial mood, while changing exact identity and unique personal features. "
        if req.reconstruction_mode == "similar"
        else "Preserve identity, proportions, silhouette, material, colors, styling, and distinctive details across all generated views. "
    )
    prompt_extra = req.prompt.strip()
    prompt_extra_clause = f"User direction: {prompt_extra[:1200]} " if prompt_extra else ""
    identity_lock_clause = (
        "Identity lock: these API calls generate one high-definition multi-view pack for ONE single subject, but each individual output file must show only its one requested view. "
        "Before rendering, infer one consistent character bible from the supplied text brief and generation instructions: gender presentation, age range, body proportions, head shape, face direction cues, material, silhouette, wardrobe/material style, and commercial mood. "
        "Keep that same character bible unchanged across every generated view in separate files. "
        "If user direction requests a gender, age, or style change, apply that one change uniformly to all views; never mix male/female, young/old, or multiple style identities inside the same pack. "
        "For transparent humanoids, keep the same transparent skin shell, skeleton proportions, visible spine/rib cage/pelvis/limb bones, and non-horror wellness character style in every view. "
    )
    neck_product_clause = (
        "This subject pack is for SKG neck-and-shoulder wearable massage device videos. "
        "Make the neck, collarbone, shoulder line, upper back, side neck, and shoulder slope clear and product-ready. "
        "Avoid bulky collars, scarves, hair, hoods, props, or poses that hide the neck/shoulder placement area. "
        "For back and close-up views, prioritize the cervical spine, shoulder blades, upper trapezius, and clean wearable-device contact area. "
    )
    # Identical for every view, so it is built once instead of per iteration.
    reference_strategy_clause = (
        "Text-only generation mode: no source image is attached to this image request. Use only the written source/video/template briefs below as creative constraints. "
        "This is intentionally NOT image editing and NOT identity replication. "
        + source_subject_clause
        + template_brief_clause
        if similar_mode else
        "Use the reference image(s) only as visual evidence; do not crop, cut out, paste, trace, or extract pixels from the source. "
    )
    models = [GPT_IMAGE_MODEL]
    generated: list[SubjectAsset] = []

    tmp_focus: Path | None = None
    model_src: Path | list[Path] | None = None
    try:
        # The focus temp file is created inside the try so that the finally
        # block below removes it on every exit path.
        if not similar_mode:
            model_src, tmp_focus = _focus_source_for_element(job_id, idx, el)
            if character_reference_paths:
                # Template/character images first, then frames, 10 in total.
                remaining = max(0, 10 - len(character_reference_paths))
                model_src = character_reference_paths + frame_reference_paths[:remaining]
            elif len(frame_reference_paths) > 1:
                model_src = frame_reference_paths[:10]

        for view, view_label in _subject_view_labels(req.subject_kind, req.views):
            closeup_view = view in {"bust", "back_detail", "bust_front", "bust_left_45", "bust_right_45", "back_neck_detail"} or "detail" in view
            if req.subject_kind == "living":
                if closeup_view:
                    view_prompt = f"upper-body shoulder-and-neck close-up character reference, {view_label}"
                elif view.startswith("expression_"):
                    emotion = view_label.replace("表情", "")
                    view_prompt = f"full-body upright standing character reference with a clear {emotion} facial expression"
                elif view.startswith("action_") or view == "side_walk":
                    view_prompt = f"full-body upright standing character reference, {view_label}, consistent actor proportions"
                else:
                    view_prompt = f"full-body upright standing character reference, {view_label}"
            else:
                view_prompt = f"complete object/product reference, {view_label} view"
            view_name = view.replace("_", " ")
            single_view_clause = (
                f"Single-image output rule: this output file is ONLY for the {view_label} view ({view_name}). "
                "Render exactly one subject, one time, in one pose and one camera angle. "
                "Do not create a multi-view sheet, contact sheet, grid, storyboard, lineup, comparison layout, before/after layout, mirrored pair, duplicate subjects, thumbnails, labels, captions, arrows, view names, panel borders, or multiple versions in the same image. "
                "Do not include any other views in this image. "
            )
            framing_clause = (
                "For this close-up view, intentionally crop as an upper-body asset from head/neck to chest or upper back; the neck, shoulders, collarbone or upper spine area must be large, clear, and useful for placing a neck-and-shoulder massage device. "
                "Do not force full-body framing for close-ups. "
                if closeup_view and req.subject_kind == "living"
                else "The subject must be complete, centered, full body or full object, head-to-feet visible when applicable, not cropped by the canvas. Make the subject large and readable: it should occupy about 85-95% of the image height with only small margins. "
            )
            prompt = (
                reference_strategy_clause
                + f"Generate one newly rendered {view_prompt} for {target}. "
                f"The subject is a {kind_phrase}. Treat all source evidence as one role and one consistent subject bible, not multiple subjects. "
                + single_view_clause
                + identity_clause
                + identity_lock_clause
                + neck_product_clause
                + canvas_clause
                + prompt_extra_clause
                + actor_style_clause
                + framing_clause
                + f"Create a high-definition standalone asset on a solid {bg_phrase} background. "
                "No extra objects, no props, no additional products, no background elements, no original scene fragments, no shadows from the original scene, no text, no watermark, no UI. "
                "If the source is incomplete, partially visible, occluded, or low resolution, reconstruct the missing parts by redrawing a clean complete subject while staying consistent with the reference. "
                "For living standard full-body views, keep a normal upright standing pose; do not create sitting, walking, medical, horror, or distorted anatomy unless explicitly requested by the view label. "
                + transparent_character_clause
            )
            try:
                if similar_mode:
                    print(
                        f"[subject assets] reconstruction_mode=similar endpoint=/images/generations view={view} image_refs=0 model={GPT_IMAGE_MODEL}",
                        flush=True,
                    )
                    img_bytes, _mode = _image_text_call(prompt, models=models, max_attempts=3)
                else:
                    if model_src is None:
                        raise RuntimeError("subject asset edit reference image missing")
                    img_bytes, _mode = _image_edit_call(model_src, prompt, models=models, fallback_text=False, max_attempts=3, max_side=1280)
            except RuntimeError as e:
                raise HTTPException(_image_error_status(e), f"subject asset {view} failed: {e}") from e

            asset_id = f"subject_{idx:03d}_{element_id}_{view}_{uuid.uuid4().hex[:8]}"
            out_path = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
            width, height = _normalize_asset_image(img_bytes, out_path, _source_frame_path(job_id, idx), req.size, req.background, square=False, fill_subject=True)
            generated.append(SubjectAsset(
                id=asset_id,
                view=view,
                label=f"{el.name_zh} · {view_label}" + (f" · {character_label}" if character_label else ""),
                url=_asset_url(job_id, asset_id),
                width=width,
                height=height,
                background=req.background,
                quality=req.quality,
                size=req.size,
                source_frame_indices=source_indices,
                created_at=_time.time(),
            ))
    except Exception:
        # The request is failing as a whole: the views rendered so far will
        # never be attached to the job, so remove their files.
        for orphan in generated:
            _delete_subject_asset_file(job_id, orphan.id)
        raise
    finally:
        if tmp_focus and tmp_focus.exists():
            try:
                tmp_focus.unlink()
            except OSError:
                pass

    src = _source_frame_path(job_id, idx)
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.quality_report = _image_quality_report(src, el.region)
            for e in f.elements:
                if e.id == element_id:
                    e.subject_kind = req.subject_kind
                    e.cutout_background = req.background
                    current_assets = e.subject_assets or []
                    if req.replace_views:
                        # Regenerated views replace older assets of the same
                        # view, including their files.
                        replaced_views = {asset.view for asset in generated}
                        for old_asset in current_assets:
                            if old_asset.view in replaced_views:
                                _delete_subject_asset_file(job_id, old_asset.id)
                        current_assets = [asset for asset in current_assets if asset.view not in replaced_views]
                    e.subject_assets = current_assets + generated
        new_frames.append(f)
    update(job, frames=new_frames, message=f"主体资产包生成完成 · {el.name_zh} · {len(generated)} 张")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}/subject-assets/{asset_id}", response_model=Job)
def delete_subject_asset(job_id: str, idx: int, element_id: str, asset_id: str) -> Job:
    """Delete one generated subject view from an element.

    The asset must be listed on the element; only then is its file removed
    and the entry dropped from ``subject_assets``.

    Raises:
        HTTPException: 404 when the job, element or asset is not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = _find_frame(job, idx)
    el = next((candidate for candidate in frame.elements if candidate.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")
    known_ids = [asset.id for asset in (el.subject_assets or [])]
    if asset_id not in known_ids:
        raise HTTPException(404, "subject asset not found")

    _delete_subject_asset_file(job_id, asset_id)

    frames_after = list(job.frames)
    for current in frames_after:
        if current.index != idx:
            continue
        for candidate in current.elements:
            if candidate.id != element_id:
                continue
            kept = [asset for asset in (candidate.subject_assets or []) if asset.id != asset_id]
            candidate.subject_assets = kept
    update(job, frames=frames_after, message=f"主体视图已删除 · {el.name_zh}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}", response_model=Job)
def delete_cutout(job_id: str, idx: int, element_id: str, cutout_id: str) -> Job:
    """Delete one extracted cutout image of an element.

    The cutout is first removed from the element's ``cutouts`` list; the file
    on disk is deleted only after that succeeded, so a request that ends in
    404 leaves the file in place.

    Raises:
        HTTPException: 404 when the job is missing or the cutout id is not
            listed on the element.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    removed = False
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for e in f.elements:
                if e.id == element_id:
                    if cutout_id in (e.cutouts or []):
                        e.cutouts = [c for c in e.cutouts if c != cutout_id]
                        removed = True
                    # cutout_id is a compatibility field: when it pointed at
                    # the deleted cutout, move it to the first remaining one
                    # or clear it.
                    if e.cutout_id == cutout_id:
                        e.cutout_id = e.cutouts[0] if e.cutouts else None
        new_frames.append(f)
    if not removed:
        raise HTTPException(404, "cutout not found in element")

    # Remove the file only once the cutout is confirmed to belong here.
    p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    if p.exists():
        try:
            p.unlink()
        except OSError:
            pass

    update(job, frames=new_frames, message="删除提取图")
    return job
|
||
|
||
|
||
class UpdateStoryboardReq(BaseModel):
    """Request body carrying the storyboard fields of one frame.

    Every field has a default, so a client may send any subset.
    """

    duration: float = 0
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    subject_images: list[dict] = Field(default_factory=list)
    product_fusion_shots: list[dict] = Field(default_factory=list)
    visual_mode: Literal["person_only", "person_product", "product_only", "environment"] = "person_product"
    needs_product: bool = True
    needs_subject: bool = True
    first_frame_plan: str = ""
    last_frame_plan: str = ""
    product_placement: str = ""
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # v1 fields (the frontend may omit them)
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    # Same default style as the other list fields of this model.
    reference_ids: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class GenerateStoryboardVideoReq(BaseModel):
    """Request body for generating a video from one storyboard frame."""

    # Text prompt; the only field without a default.
    prompt: str
    duration: float = 4

    # Image references, each an optional dict or a list of dicts.
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    subject_image: dict | None = None
    subject_images: list[dict] = Field(default_factory=list)
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None

    # Optional source reference plus model and output size selection.
    source_ref: VideoSourceRef | None = None
    model: str = ""
    size: str = "720x1280"
|
||
|
||
|
||
class ProductFusionDescriptionReq(BaseModel):
    """Request body holding a list of product-fusion shots (empty by default)."""

    shots: list[ProductFusionShot] = Field(default_factory=list)
|
||
|
||
|
||
def video_seconds(duration: float) -> str:
    """Map a requested duration to the seconds string sent to the video API.

    When ``video_uses_ark()`` is true the duration is rounded and clamped to
    4..15, with "5" for a non-positive value. Otherwise it is snapped to one
    of "4", "8" or "12".
    """
    if not video_uses_ark():
        if duration <= 6:
            return "4"
        return "8" if duration <= 10 else "12"

    if duration <= 0:
        return "5"
    capped = min(15, round(duration))
    return str(max(4, capped))
|
||
|
||
|
||
def resolve_video_model(raw: str | None) -> str:
    """Resolve a requested video model name to the name sent to the provider.

    Falls back to ``VIDEO_MODEL`` and then ``"seedance"`` when ``raw`` is
    empty. A whitespace-only ``raw`` counts as empty, so it no longer
    resolves to an empty model name. Known aliases are translated through
    ``VIDEO_MODEL_ALIASES``; unknown names pass through unchanged.

    Raises:
        HTTPException: 400 when a Sora model is requested.
    """
    # Strip before choosing the fallback so blank input reaches the default.
    requested = (raw or "").strip() or (VIDEO_MODEL or "").strip() or "seedance"
    lowered = requested.lower()
    if lowered in {"sora", "sora-2", "sora_2"}:
        raise HTTPException(400, "Sora 已停用,请选择 Seedance / Kling / Veo 3")
    return VIDEO_MODEL_ALIASES.get(lowered, requested)
|
||
|
||
|
||
def normalize_video_status(status: str | None) -> Literal["queued", "in_progress", "completed", "failed"]:
|
||
s = (status or "queued").lower()
|
||
if s in {"completed", "complete", "succeeded", "success", "done"}:
|
||
return "completed"
|
||
if s in {"failed", "failure", "error", "cancelled", "canceled", "expired"}:
|
||
return "failed"
|
||
if s in {"running", "processing", "in_progress", "generating", "started"}:
|
||
return "in_progress"
|
||
return "queued"
|
||
|
||
|
||
def video_progress(data: dict, fallback: int) -> int:
    """Read a progress percentage from a provider payload, clamped to 0..100.

    The keys ``progress``, ``percentage`` and ``percent`` are tried in that
    order and the first one holding a non-null value wins. Numeric strings,
    with or without a trailing ``%``, are accepted. ``fallback`` is used when
    no key has a value or the value cannot be parsed.
    """
    raw = fallback
    for key in ("progress", "percentage", "percent"):
        candidate = data.get(key)
        # A null value must not hide a later key that does carry a number.
        if candidate is not None:
            raw = candidate
            break
    if isinstance(raw, str):
        raw = raw.strip().rstrip("%")
    try:
        value = int(float(raw))
    except (TypeError, ValueError, OverflowError):
        # Covers unparsable text, NaN and infinity.
        value = fallback
    return max(0, min(100, value))
|
||
|
||
|
||
def video_url_from_response(data: dict) -> str:
    """Find the video URL in a provider response, or return ``""``.

    Locations are probed in this order, and the first non-empty string wins:
    the top level, the first item of a ``data`` list, the ``output`` dict,
    then the ``content`` dict. Each location has its own set of accepted
    keys. A payload that is not a dict yields ``""``.
    """
    if not isinstance(data, dict):
        return ""

    def first_url(node: object, keys: tuple[str, ...]) -> str:
        """Return the first non-empty string stored under ``keys`` in ``node``."""
        if not isinstance(node, dict):
            return ""
        for key in keys:
            value = node.get(key)
            if isinstance(value, str) and value:
                return value
        return ""

    common_keys = ("url", "video_url", "output_url", "download_url")
    items = data.get("data")
    first_item = items[0] if isinstance(items, list) and items else None
    probes = (
        (data, common_keys),
        (first_item, common_keys),
        (data.get("output"), ("url", "video_url", "download_url")),
        (data.get("content"), ("video_url", "url", "download_url", "file_url")),
    )
    for node, keys in probes:
        found = first_url(node, keys)
        if found:
            return found
    return ""
|
||
|
||
|
||
def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path) -> None:
    """Download a finished video and write it to ``out_mp4``.

    Args:
        client: HTTP client exposing ``get(url, headers=...)``.
        base: Base URL of the video API.
        headers: Auth headers for the video API.
        provider_id: Provider-side id, used for the content path when no
            direct URL is available.
        direct_url: Absolute URL, or a path relative to ``base``; may be empty.
        out_mp4: Destination file.

    Raises:
        Whatever ``raise_for_status()`` raises on an error response.
    """
    from urllib.parse import urlsplit

    def origin(u: str) -> tuple[str, str]:
        parts = urlsplit(u)
        return parts.scheme.lower(), parts.netloc.lower()

    if direct_url:
        url = direct_url if direct_url.startswith("http") else f"{base}{direct_url if direct_url.startswith('/') else '/' + direct_url}"
        # A bare prefix test also matches look-alike hosts such as
        # "https://api.example.com.evil.test", so the scheme and host must
        # match as well before the credentials are attached.
        same_api = url.startswith(base) and origin(url) == origin(base)
        r = client.get(url, headers=headers if same_api else None)
    else:
        r = client.get(f"{base}{video_path(VIDEO_CONTENT_PATH, id=provider_id)}", headers=headers)
    r.raise_for_status()
    out_mp4.write_bytes(r.content)
|
||
|
||
|
||
def size_to_video_ratio(size: str) -> str:
    """Map a ``"WxH"`` size string to the closest supported aspect ratio.

    The separator may be ``x``, ``X``, ``×`` or ``*`` and spaces are ignored.
    Unparsable or non-positive sizes yield ``"9:16"``.
    """
    text = str(size or "").lower().replace(" ", "")
    for alt in ("×", "*"):
        text = text.replace(alt, "x")
    try:
        w, h = (int(part) for part in text.split("x", 1))
    except ValueError:
        # Raised both for non-numeric parts and for a missing separator.
        return "9:16"
    if w <= 0 or h <= 0:
        return "9:16"
    ratio = w / h
    known = {
        "16:9": 16 / 9,
        "9:16": 9 / 16,
        "1:1": 1,
        "4:3": 4 / 3,
        "3:4": 3 / 4,
        "21:9": 21 / 9,
    }
    return min(known, key=lambda key: abs(known[key] - ratio))
|
||
|
||
|
||
def ark_reference_data_url(ref_img: Path) -> str:
    """Encode an image file as a base64 ``data:`` URL.

    The MIME type is chosen from the file suffix (case-insensitive). PNG,
    WebP, GIF and BMP are labelled as such; any other suffix is labelled
    ``image/jpeg``.
    """
    mime_by_suffix = {
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
        ".bmp": "image/bmp",
    }
    mime = mime_by_suffix.get(ref_img.suffix.lower(), "image/jpeg")
    return f"data:{mime};base64,{base64.b64encode(ref_img.read_bytes()).decode('ascii')}"
|
||
|
||
|
||
def submit_video_create(
    client,
    url: str,
    headers: dict,
    ref_img: Path,
    payload: dict,
    source_ref: VideoSourceRef | None = None,
    last_img: Path | None = None,
    product_imgs: list[Path] | None = None,
    primary_role: str = "first_frame",
):
    """Send one video-creation request in the format of the active provider.

    * ``video_uses_ark()``: a JSON body whose ``content`` list holds the
      prompt, an optional reference video, the primary image, an optional
      last frame and up to six product reference images as data URLs.
    * ``video_uses_poe()``: a JSON body with the primary image as base64 in
      ``input_image``.
    * otherwise: a multipart form with the image as ``input_reference``.

    Returns:
        The response object returned by ``client.post``.
    """
    if video_uses_ark():
        def image_part(path: Path, role: str) -> dict:
            """Build one image entry of the ``content`` list."""
            return {
                "type": "image_url",
                "image_url": {"url": ark_reference_data_url(path)},
                "role": role,
            }

        parts = [{"type": "text", "text": payload["prompt"]}]
        if source_ref and source_ref.kind == "source_video" and source_ref.url:
            parts.append(
                {
                    "type": "video_url",
                    "video_url": {"url": source_ref.url},
                    "role": "reference_video",
                }
            )
        parts.append(image_part(ref_img, primary_role))
        if last_img and last_img.exists():
            parts.append(image_part(last_img, "last_frame"))
        # Only the first six product images are considered, and only those
        # that exist on disk.
        parts.extend(
            image_part(candidate, "reference_image")
            for candidate in (product_imgs or [])[:6]
            if candidate.exists()
        )
        body = {
            "model": payload["model"],
            "content": parts,
            "ratio": size_to_video_ratio(str(payload.get("size", ""))),
            "duration": int(float(str(payload.get(VIDEO_DURATION_FIELD, 5)))),
            "watermark": False,
            "resolution": "720p",
        }
        json_headers = {**headers, "Content-Type": "application/json"}
        return client.post(url, headers=json_headers, json=body)

    if video_uses_poe():
        body = dict(payload)
        # The duration field is sent as an integer here.
        body[VIDEO_DURATION_FIELD] = int(float(str(body.get(VIDEO_DURATION_FIELD, 4))))
        body["input_image"] = base64.b64encode(ref_img.read_bytes()).decode("ascii")
        return client.post(url, headers=headers, json=body)

    with ref_img.open("rb") as handle:
        upload = {"input_reference": ("reference.jpg", handle, "image/jpeg")}
        return client.post(url, headers=headers, data=payload, files=upload)
|
||
|
||
|
||
def render_storyboard_video(
    job_id: str,
    local_id: str,
    provider_id: str,
    ref_path: Path,
    prompt: str,
    model: str,
    seconds: str,
    size: str,
    source_ref: VideoSourceRef | None = None,
    last_ref_path: Path | None = None,
    product_ref_paths: list[Path] | None = None,
    primary_role: str = "first_frame",
) -> None:
    """Create, poll and download one storyboard video, recording progress.

    The function never raises. Every failure is written to the generated
    video record through ``update_generated_video(..., status="failed")``.

    Steps:
        1. Prepare the reference images under
           ``storyboard_videos/<local_id>/``.
        2. Try each path in ``VIDEO_CREATE_PATHS``. On Ark, a 400/422 answer
           triggers retries that drop, in turn, the reference video, the
           last frame, the product references, and finally both the last
           frame and the product references.
        3. Poll the status path every 8 seconds until the status is final or
           ``VIDEO_POLL_TIMEOUT_SECONDS`` has elapsed.
        4. Download the result to ``video.mp4`` and mark the record completed.

    Args:
        job_id: Job that owns the video.
        local_id: Local id of the generated-video record.
        provider_id: Provider id used when the create response has no ``id``.
        ref_path: Primary reference image.
        prompt: Text prompt.
        model: Provider model name.
        seconds: Duration value stored under ``VIDEO_DURATION_FIELD``.
        size: Output size string.
        source_ref: Optional source reference forwarded to the create call.
        last_ref_path: Optional last-frame image.
        product_ref_paths: Optional product reference images (first six used).
        primary_role: Role assigned to the primary image.
    """
    import httpx

    out_dir = job_dir(job_id) / "storyboard_videos" / local_id
    ref_img = out_dir / "reference.jpg"
    last_img = out_dir / "last_reference.jpg"
    out_mp4 = out_dir / "video.mp4"
    base = video_api_base()
    headers = {"Authorization": f"Bearer {video_api_key()}"}

    try:
        prepare_video_reference(ref_path, ref_img)
        prepared_last_img: Path | None = None
        if last_ref_path and last_ref_path.exists():
            prepare_video_reference(last_ref_path, last_img)
            prepared_last_img = last_img
        prepared_product_imgs: list[Path] = []
        for i, product_ref_path in enumerate((product_ref_paths or [])[:6], start=1):
            if product_ref_path.exists():
                product_img = out_dir / f"product_reference_{i}.jpg"
                prepare_video_reference(product_ref_path, product_img)
                prepared_product_imgs.append(product_img)
        update_generated_video(job_id, local_id, status="in_progress", progress=5)
        with httpx.Client(timeout=120) as client:
            payload = {"model": model, "prompt": prompt, "size": size}
            payload[VIDEO_DURATION_FIELD] = seconds
            create = None
            create_errors: list[str] = []
            rejected = {400, 422}

            def submit(create_path, src, last, products):
                """Send one create request with the given optional references."""
                return submit_video_create(
                    client, f"{base}{video_path(create_path)}", headers, ref_img, payload,
                    src, last, products, primary_role,
                )

            for create_path in VIDEO_CREATE_PATHS:
                label = video_path(create_path)
                resp = submit(create_path, source_ref, prepared_last_img, prepared_product_imgs)
                if video_uses_ark() and source_ref and resp.status_code in rejected:
                    create_errors.append(f"{label} + reference_video -> HTTP {resp.status_code}: {resp.text[:160]}")
                    resp = submit(create_path, None, prepared_last_img, prepared_product_imgs)
                if video_uses_ark() and prepared_last_img and resp.status_code in rejected:
                    create_errors.append(f"{label} + last_frame -> HTTP {resp.status_code}: {resp.text[:160]}")
                    resp = submit(create_path, None, None, prepared_product_imgs)
                if video_uses_ark() and prepared_product_imgs and resp.status_code in rejected:
                    create_errors.append(f"{label} + product_reference -> HTTP {resp.status_code}: {resp.text[:160]}")
                    # Drop the products but give the last frame another chance.
                    resp = submit(create_path, None, prepared_last_img, None)
                    if prepared_last_img and resp.status_code in rejected:
                        # Both extras were rejected: fall back to the bare
                        # request, which had not been tried yet.
                        create_errors.append(f"{label} + last_frame (no product_reference) -> HTTP {resp.status_code}: {resp.text[:160]}")
                        resp = submit(create_path, None, None, None)
                if resp.status_code < 400:
                    create = resp
                    break
                create_errors.append(f"{label} -> HTTP {resp.status_code}: {resp.text[:160]}")
                # 400/404/405 move on to the next create path; any other
                # error aborts immediately.
                if resp.status_code not in {400, 404, 405}:
                    resp.raise_for_status()
            if create is None:
                raise RuntimeError("视频模型已选择,但当前网关视频生成入口不可用;已尝试 " + " | ".join(create_errors))
            data = create.json()
            video_api_id = data.get("id") or provider_id or local_id
            status = normalize_video_status(data.get("status"))
            progress = video_progress(data, 5)
            direct_url = video_url_from_response(data)
            update_generated_video(job_id, local_id, provider_id=video_api_id, status=status, progress=progress)

            deadline = time.time() + VIDEO_POLL_TIMEOUT_SECONDS
            while status in {"queued", "in_progress"} and time.time() < deadline:
                time.sleep(8)
                poll = client.get(f"{base}{video_path(VIDEO_STATUS_PATH, id=video_api_id)}", headers=headers)
                poll.raise_for_status()
                pdata = poll.json()
                status = normalize_video_status(pdata.get("status"))
                progress = video_progress(pdata, progress)
                # Keep an earlier URL when a later poll response omits it.
                direct_url = video_url_from_response(pdata) or direct_url
                update_generated_video(job_id, local_id, status=status, progress=progress)

            if status != "completed":
                error = f"video status: {status}"
                if status in {"queued", "in_progress"}:
                    # The loop ended on the deadline, not on a final status.
                    error += f" (polling timed out after {VIDEO_POLL_TIMEOUT_SECONDS}s)"
                update_generated_video(job_id, local_id, status="failed", error=error, progress=progress)
                return

            download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4)
            update_generated_video(
                job_id,
                local_id,
                status="completed",
                progress=100,
                url=f"/jobs/{job_id}/storyboard-videos/{local_id}.mp4",
                error="",
            )
    except Exception as e:
        # Record the failure on the video item instead of propagating it.
        update_generated_video(job_id, local_id, status="failed", error=str(e)[:500])
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/storyboard/video", response_model=Job)
def generate_storyboard_video(job_id: str, idx: int, req: GenerateStoryboardVideoReq, bg: BackgroundTasks) -> Job:
    """Queue a background video render for one storyboard frame.

    Validates the job, the frame and the prompt, resolves the primary
    reference image plus the optional last-frame, product and subject
    references, prepends a queued ``GeneratedVideo`` to the job and schedules
    ``render_storyboard_video`` as a background task.

    Raises:
        HTTPException: 404 for an unknown job or frame, or when the primary
            reference image file does not exist; 400 when the prompt is blank.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not next((candidate for candidate in job.frames if candidate.index == idx), None):
        raise HTTPException(404, "frame not found")
    ensure_video_api_configured()
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(400, "prompt required")

    def resolve_all(raw_refs) -> list:
        # Keep only the references that resolve to a path, in input order.
        resolved = []
        for raw_ref in raw_refs:
            resolved_path = storyboard_ref_path(job_id, raw_ref)
            if resolved_path:
                resolved.append(resolved_path)
        return resolved

    # The first non-empty image wins as the primary reference.
    ref = req.first_image or req.subject_image or req.product_image or req.scene_image or req.action_image
    primary_role = "reference_image"
    if req.first_image:
        primary_role = "first_frame"
    ref_path = storyboard_ref_path(job_id, ref)
    if not ref_path:
        # Fall back to the frame's own extracted image.
        ref_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not ref_path.exists():
        raise HTTPException(404, "reference image missing")
    poster = storyboard_ref_url(job_id, ref) or f"/jobs/{job_id}/frames/{idx}.jpg"
    last_ref_path = storyboard_ref_path(job_id, req.last_image)

    if req.product_images:
        raw_product_refs = req.product_images[:6]
    elif req.product_image:
        raw_product_refs = [req.product_image]
    else:
        raw_product_refs = []
    product_ref_paths = resolve_all(raw_product_refs)
    subject_ref_paths = resolve_all(req.subject_images[:8])

    # Product fusion is sensitive to object drift. Send product references before
    # extra character references so the rigid SKG device keeps its real shape.
    reference_ref_paths = []
    seen_ref_paths: set[str] = {str(ref_path)}
    for extra_path in product_ref_paths + subject_ref_paths:
        path_key = str(extra_path)
        if path_key in seen_ref_paths:
            continue
        reference_ref_paths.append(extra_path)
        seen_ref_paths.add(path_key)

    local_id = uuid.uuid4().hex[:12]
    model = resolve_video_model(req.model)
    seconds = video_seconds(float(req.duration or 4))
    queued_video = GeneratedVideo(
        id=local_id,
        provider_id="",
        frame_idx=idx,
        prompt=prompt,
        model=model,
        status="queued",
        url="",
        poster_url=poster,
        duration=float(seconds),
        progress=0,
        created_at=time.time(),
    )
    update(job, generated_videos=[queued_video] + job.generated_videos, message=f"视频生成已提交 · 分镜 {idx + 1}")

    # A source-video reference without a URL is dropped before rendering.
    source_ref = req.source_ref
    if source_ref and source_ref.kind == "source_video" and not source_ref.url:
        source_ref = None
    bg.add_task(
        render_storyboard_video,
        job_id,
        local_id,
        "",
        ref_path,
        prompt,
        model,
        seconds,
        req.size,
        source_ref,
        last_ref_path,
        reference_ref_paths,
        primary_role,
    )
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4")
def get_storyboard_video(job_id: str, video_id: str):
    """Serve the rendered MP4 of one storyboard video, or 404 if it is absent."""
    video_file = job_dir(job_id) / "storyboard_videos" / video_id / "video.mp4"
    if video_file.exists():
        return FileResponse(video_file, media_type="video/mp4")
    raise HTTPException(404, "storyboard video not found")
|
||
|
||
|
||
# Request body for copying one product-library image into a job's assets.
class CopyProductLibraryAssetReq(BaseModel):
    # Identifier handed to find_product_library_item() by the endpoint.
    product_id: str
|
||
|
||
|
||
# Request body for copying all images of one library character into a job's assets.
class CopyCharacterLibraryAssetReq(BaseModel):
    # Identifier handed to find_character_library_item() by the endpoint.
    character_id: str
|
||
|
||
|
||
# Request body for generating a missing product view from existing reference images.
class GenerateProductAngleAssetReq(BaseModel):
    # Primary reference; it is resolved first, ahead of source_refs.
    source_ref: dict
    # Additional references of the same product; the endpoint uses at most six
    # distinct existing files overall.
    source_refs: list[dict] = Field(default_factory=list)
    # Free-text notes about the references; the endpoint whitespace-normalises
    # them and forwards the first six into the generation prompt.
    source_notes: list[str] = Field(default_factory=list)
    # Description of the view to generate; interpolated into the prompt and the label.
    target_view: str
    # Optional extra operator instruction appended to the prompt.
    note: str = ""
|
||
|
||
|
||
# Request body for classifying uploaded product reference images by view.
class AnalyzeProductViewsReq(BaseModel):
    # References to analyse; the position in this list is the "index" of each result.
    refs: list[dict] = Field(default_factory=list)
|
||
|
||
|
||
# Request body for saving a job element's subject assets as a reusable template.
class SaveSubjectTemplateReq(BaseModel):
    # Template name; the endpoint rejects it when blank after stripping.
    name: str
    # Optional note; stored as both description and note of the template.
    note: str = ""
    # Index of the frame that owns the element.
    frame_idx: int
    # Id of the element whose subject_assets are saved.
    element_id: str
    # Optional subset (and order) of asset ids; empty means all subject assets.
    asset_ids: list[str] = Field(default_factory=list)
    # Subject rendering style recorded on the template.
    subject_style: Literal["transparent_human", "source_actor"] = "transparent_human"
|
||
|
||
|
||
@app.get("/product-library/skg", response_model=list[ProductLibraryItem])
def list_skg_product_library() -> list[ProductLibraryItem]:
    """Built-in SKG white-background product image library.

    The items come from the locally curated product-image manifest, as
    returned by ``load_product_library_items()``.
    """
    return load_product_library_items()
|
||
|
||
|
||
@app.get("/product-library/skg/images/{filename}")
def get_skg_product_library_image(filename: str):
    """Serve the product-library image whose file name equals ``filename``.

    Only the base name of each library item's ``filename`` is compared.
    Raises a 404 ``HTTPException`` when no item matches.
    """
    for candidate in load_product_library_items():
        if Path(candidate.filename).name == filename:
            return FileResponse(product_library_file(candidate), media_type="image/jpeg")
    raise HTTPException(404, "product library image not found")
|
||
|
||
|
||
@app.get("/character-library/skg", response_model=list[CharacterLibraryItem])
def list_skg_character_library() -> list[CharacterLibraryItem]:
    """Built-in transparent-skeleton character library.

    Sourced from the five character reference groups generated on the
    desktop, as returned by ``load_character_library_items()``.
    """
    return load_character_library_items()
|
||
|
||
|
||
@app.get("/character-library/skg/images/{filename:path}")
def get_skg_character_library_image(filename: str):
    """Serve one character-library image, as PNG or JPEG depending on its suffix."""
    image_path = character_library_file(filename)
    if image_path.suffix.lower() == ".png":
        media_type = "image/png"
    else:
        media_type = "image/jpeg"
    return FileResponse(image_path, media_type=media_type)
|
||
|
||
|
||
@app.get("/subject-templates", response_model=list[SubjectTemplateItem])
def list_subject_templates() -> list[SubjectTemplateItem]:
    """Persisted subject template library.

    Saved look-alike subjects can be reused by later jobs as creative
    references.
    """
    return load_subject_template_items()
|
||
|
||
|
||
@app.get("/subject-templates/images/{filename:path}")
def get_subject_template_image(filename: str):
    """Serve one subject-template image as JPEG.

    ``filename`` may contain slashes (``:path`` converter); resolving it to a
    file is delegated to ``subject_template_image_file``.
    """
    p = subject_template_image_file(filename)
    return FileResponse(p, media_type="image/jpeg")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/subject-templates", response_model=SubjectTemplateItem)
def save_subject_template(job_id: str, req: SaveSubjectTemplateReq) -> SubjectTemplateItem:
    """Copy the confirmed look-alike subject views of a job into the template library.

    The subject assets of the requested element are selected (optionally
    restricted and ordered by ``req.asset_ids``), their JPEG files are copied
    under ``SUBJECT_TEMPLATE_IMAGE_DIR/<template_id>/`` and a new
    ``SubjectTemplateItem`` is stored at the front of the template list.

    Raises:
        HTTPException: 404 for an unknown job or element, or when none of the
            selected asset files exist; 400 for a blank name or when the
            element has no matching subject assets.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    name = req.name.strip()
    if not name:
        raise HTTPException(400, "template name required")
    frame = _find_frame(job, req.frame_idx)
    element = next((e for e in frame.elements if e.id == req.element_id), None)
    if not element:
        raise HTTPException(404, "element not found")

    requested_ids = [x.strip() for x in req.asset_ids if x.strip()]
    selected_assets = [
        asset
        for asset in (element.subject_assets or [])
        if not requested_ids or asset.id in requested_ids
    ]
    if requested_ids:
        # Rank by first occurrence in the request, matching list.index semantics
        # without rescanning the list for every asset.
        rank: dict[str, int] = {}
        for position, asset_id in enumerate(requested_ids):
            rank.setdefault(asset_id, position)
        selected_assets.sort(key=lambda asset: rank.get(asset.id, 999))
    else:
        selected_assets.sort(key=lambda asset: asset.created_at, reverse=True)
    if not selected_assets:
        raise HTTPException(400, "no subject assets to save")

    template_id = f"subject-template-{uuid.uuid4().hex[:10]}"
    template_dir = SUBJECT_TEMPLATE_IMAGE_DIR / template_id
    now = time.time()
    images: list[SubjectTemplateImage] = []
    saved_image_paths: list[Path] = []
    for asset in selected_assets:
        src = job_dir(job_id) / "assets" / f"{asset.id}.jpg"
        if not src.exists():
            continue
        image_id = f"{asset.view}_{uuid.uuid4().hex[:8]}"
        filename = f"{template_id}/{image_id}.jpg"
        dst = SUBJECT_TEMPLATE_IMAGE_DIR / filename
        # Create the directory only once there is a file to put in it, so a
        # request whose asset files are all missing leaves no empty directory.
        template_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        saved_image_paths.append(dst)
        images.append(SubjectTemplateImage(
            id=image_id,
            view=asset.view,
            label=asset.label or asset.view,
            filename=filename,
            url=f"/subject-templates/images/{filename}",
            width=asset.width,
            height=asset.height,
            background=asset.background,
            quality=asset.quality,
            size=asset.size,
            source_asset_id=asset.id,
            source_frame_indices=asset.source_frame_indices,
            created_at=asset.created_at or now,
        ))
    if not images:
        raise HTTPException(404, "subject asset files missing")

    # Prefer the front view as the primary image, else the first saved one.
    primary = next((image.id for image in images if image.view == "front"), images[0].id)
    note = req.note.strip()
    prompt_brief = _describe_subject_template_from_images(
        name,
        req.subject_style,
        saved_image_paths,
        note,
    ) or note
    item = SubjectTemplateItem(
        id=template_id,
        name=name,
        description=note,
        note=note,
        prompt_brief=prompt_brief,
        source_job_id=job_id,
        source_frame_idx=frame.index,
        source_element_id=element.id,
        subject_style=req.subject_style,
        primary_image=primary,
        images=images,
        created_at=now,
        updated_at=now,
    )
    items = [item] + [existing for existing in load_subject_template_items() if existing.id != item.id]
    save_subject_template_items(items)
    return item
|
||
|
||
|
||
def normalize_product_asset_image(src: Path, out: Path) -> dict:
    """Write a normalised JPEG working copy of a product image and describe it.

    The image at ``src`` is EXIF-transposed, flattened onto white when it has
    transparency (or converted to RGB otherwise), resized so its longest side
    lies between ``PRODUCT_ASSET_MIN_LONG_SIDE`` and ``PRODUCT_ASSET_MAX_SIDE``
    and saved to ``out`` as JPEG.

    Returns:
        A dict with original and working dimensions and byte sizes, the
        configured limits, the list of applied ``actions``, user-facing
        ``warnings`` and a ``normalized`` flag that is true when either list
        is non-empty.
    """
    original_bytes = src.stat().st_size if src.exists() else 0
    actions: list[str] = []
    warnings: list[str] = []
    with Image.open(src) as opened:
        img = ImageOps.exif_transpose(opened)
        # "Original" dimensions are measured after the EXIF rotation.
        original_width, original_height = img.size
        if img.mode in {"RGBA", "LA"} or ("transparency" in img.info):
            # JPEG has no alpha channel: composite onto a white canvas.
            rgba = img.convert("RGBA")
            base = Image.new("RGB", img.size, (255, 255, 255))
            base.paste(rgba, mask=rgba.getchannel("A"))
            img = base
            actions.append("透明背景已铺白")
        elif img.mode != "RGB":
            img = img.convert("RGB")
            actions.append("已转 RGB/JPEG")

        max_side = max(img.size)
        if max_side > PRODUCT_ASSET_MAX_SIDE:
            # Downscale proportionally so the longest side equals the maximum.
            ratio = PRODUCT_ASSET_MAX_SIDE / max_side
            next_size = (max(1, round(img.width * ratio)), max(1, round(img.height * ratio)))
            img = img.resize(next_size, Image.Resampling.LANCZOS)
            actions.append(f"最长边压缩到 {PRODUCT_ASSET_MAX_SIDE}px")
            if max(original_width, original_height) >= 2400:
                warnings.append("原图过大已自动压缩;超高清不会提升识别稳定性")
        elif max_side < PRODUCT_ASSET_MIN_LONG_SIDE:
            # Upscale small images to the minimum working size.
            ratio = PRODUCT_ASSET_MIN_LONG_SIDE / max_side
            next_size = (max(1, round(img.width * ratio)), max(1, round(img.height * ratio)))
            img = img.resize(next_size, Image.Resampling.LANCZOS)
            actions.append(f"低分辨率图已放大到最长边 {PRODUCT_ASSET_MIN_LONG_SIDE}px")
            warnings.append("原始分辨率偏低,已放大为工作图,但真实细节不会增加")

        # These checks look at the working image and the original file size.
        if min(img.size) < PRODUCT_ASSET_MIN_SHORT_SIDE:
            warnings.append(f"短边低于 {PRODUCT_ASSET_MIN_SHORT_SIDE}px,细节/比例识别可能不稳")
        if original_bytes >= 5 * 1024 * 1024:
            warnings.append("原文件较大,已生成轻量 AI 工作副本")

        out.parent.mkdir(parents=True, exist_ok=True)
        img.save(out, "JPEG", quality=PRODUCT_ASSET_JPEG_QUALITY, optimize=True, progressive=True, subsampling=0)
        work_width, work_height = img.size

    return {
        "standard": f"AI工作副本:最长边≤{PRODUCT_ASSET_MAX_SIDE}px,建议长边≥{PRODUCT_ASSET_MIN_LONG_SIDE}px,短边≥{PRODUCT_ASSET_MIN_SHORT_SIDE}px,JPEG q{PRODUCT_ASSET_JPEG_QUALITY}",
        "original_width": original_width,
        "original_height": original_height,
        "width": work_width,
        "height": work_height,
        "original_bytes": original_bytes,
        "work_bytes": out.stat().st_size if out.exists() else 0,
        "max_side": PRODUCT_ASSET_MAX_SIDE,
        "min_long_side": PRODUCT_ASSET_MIN_LONG_SIDE,
        "min_short_side": PRODUCT_ASSET_MIN_SHORT_SIDE,
        "quality": PRODUCT_ASSET_JPEG_QUALITY,
        "actions": actions,
        "warnings": warnings,
        "normalized": bool(actions or warnings),
    }
|
||
|
||
|
||
def _remove_file_quietly(path: Path) -> None:
    """Best-effort delete of ``path``; filesystem errors are ignored."""
    try:
        path.unlink(missing_ok=True)
    except OSError:
        pass


@app.post("/jobs/{job_id}/assets")
async def upload_storyboard_asset(job_id: str, file: UploadFile = File(...)) -> dict:
    """Store an uploaded product image as a normalised JPEG asset of the job.

    The upload is written to a temporary ``<asset_id>.upload`` file,
    normalised into ``<asset_id>.jpg`` by ``normalize_product_asset_image``
    and the temporary file is always removed afterwards.

    Returns:
        An asset reference dict (``kind``, ``frame_idx``, ``element_id``,
        ``cutout_id``, ``label``) plus the normalisation ``asset_meta``.

    Raises:
        HTTPException: 404 for an unknown job, 400 when storing or
            normalising the upload fails.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    asset_id = uuid.uuid4().hex[:12]
    out_dir = job_dir(job_id) / "assets"
    out_dir.mkdir(parents=True, exist_ok=True)
    tmp = out_dir / f"{asset_id}.upload"
    out = out_dir / f"{asset_id}.jpg"
    try:
        tmp.write_bytes(await file.read())
        asset_meta = normalize_product_asset_image(tmp, out)
    except Exception as e:
        # A failed save may leave a truncated JPEG behind; do not keep it.
        _remove_file_quietly(out)
        raise HTTPException(400, f"product image upload failed: {e}") from e
    finally:
        _remove_file_quietly(tmp)
    return {
        "kind": "asset",
        "frame_idx": -1,
        "element_id": asset_id,
        "cutout_id": asset_id,
        "label": file.filename or "SKG 产品图",
        "asset_meta": asset_meta,
    }
|
||
|
||
|
||
# Closed set of product view identifiers; the order is also the fallback
# assignment order used by fallback_product_view().
PRODUCT_VIEW_VALUES = ["front", "left_45", "right_45", "side_thickness", "inner_contacts", "back_bottom"]
# Images per recognition request, clamped to 1..12 (env override, default 8).
# NOTE(review): a non-numeric env value raises ValueError at import time.
PRODUCT_VIEW_BATCH_SIZE = max(1, min(12, int(os.getenv("PRODUCT_VIEW_BATCH_SIZE", "8"))))

# Human-readable (Chinese) label per view identifier.
PRODUCT_VIEW_LABELS = {
    "front": "正面/外侧主外观",
    "left_45": "佩戴者左 45",
    "right_45": "佩戴者右 45",
    "side_thickness": "侧面厚度",
    "inner_contacts": "贴颈内侧/触点",
    "back_bottom": "背面/底部",
}

# Accepted background classifications; anything else is normalised to "unknown".
PRODUCT_BACKGROUND_VALUES = ["white", "black", "simple", "complex", "unknown"]
# Accepted usage tags; normalize_product_use_tags() drops anything not listed here.
PRODUCT_USE_TAG_VALUES = [
    "hero_packshot",
    "wearing_scale",
    "inner_contact",
    "side_thickness",
    "asymmetry",
    "button_detail",
    "back_bottom",
    "material_texture",
]
|
||
|
||
|
||
def default_product_use_tags(view: str) -> list[str]:
    """Return the default usage tags for ``view``.

    Unknown views get ``["hero_packshot"]``. A fresh list is returned on
    every call.
    """
    angled_tags = ["hero_packshot", "asymmetry", "button_detail"]
    tags_by_view = {
        "front": ["hero_packshot", "asymmetry"],
        "left_45": list(angled_tags),
        "right_45": list(angled_tags),
        "side_thickness": ["side_thickness", "wearing_scale"],
        "inner_contacts": ["inner_contact", "wearing_scale"],
        "back_bottom": ["back_bottom", "material_texture"],
    }
    try:
        return tags_by_view[view]
    except KeyError:
        return ["hero_packshot"]
|
||
|
||
|
||
def normalize_product_use_tags(tags: object, view: str) -> list[str]:
    """Return up to four valid, de-duplicated usage tags.

    ``tags`` may be a delimiter-separated string or a list; any other type
    contributes nothing. The view's default tags are appended after the
    supplied ones, and only values from ``PRODUCT_USE_TAG_VALUES`` survive.
    """
    if isinstance(tags, str):
        candidates = re.split(r"[,,/、\s]+", tags)
    elif isinstance(tags, list):
        candidates = list(map(str, tags))
    else:
        candidates = []
    candidates = candidates + default_product_use_tags(view)

    kept: list[str] = []
    for candidate in candidates:
        cleaned = str(candidate).strip()
        if cleaned not in PRODUCT_USE_TAG_VALUES or cleaned in kept:
            continue
        kept.append(cleaned)
    return kept[:4]
|
||
|
||
|
||
def fallback_product_view(index: int) -> dict:
    """Build a low-confidence view record chosen purely by upload position.

    Positions beyond the number of known views all map to the last view.
    """
    last_position = len(PRODUCT_VIEW_VALUES) - 1
    view = PRODUCT_VIEW_VALUES[min(index, last_position)]
    label = PRODUCT_VIEW_LABELS.get(view, view)
    record = {"view": view, "background": "unknown"}
    record["use_tags"] = default_product_use_tags(view)
    record["orientation"] = default_product_orientation(view)
    record["landmarks"] = default_product_landmarks(view)
    record["note"] = f"{label}参考;模型识别不可用时按上传顺序自动标注,请重点复核佩戴者左/右、上/下和贴颈内侧。"
    record["risk"] = "模型识别不可用,按上传顺序兜底"
    record["confidence"] = 0.25
    return record
|
||
|
||
|
||
# Keys of the per-image "orientation" mapping; normalisation and the tolerant
# response parser only read these keys.
PRODUCT_ORIENTATION_KEYS = [
    "product_left",
    "product_right",
    "top",
    "bottom",
    "inner_side",
    "outer_side",
    "opening_direction",
]
|
||
|
||
|
||
def default_product_orientation(view: str) -> dict:
    """Return the default orientation hints, with one hint specialised per view.

    ``inner_contacts`` overrides ``inner_side``, ``side_thickness`` overrides
    ``outer_side`` and the two 45-degree views override
    ``opening_direction``; every other view gets the generic hints.
    """
    orientation = {
        "product_left": "佩戴者左侧;需人工复核图中位置",
        "product_right": "佩戴者右侧;需人工复核图中位置",
        "top": "靠近下巴/脸/颈部上沿",
        "bottom": "靠近锁骨/肩部下沿",
        "inner_side": "贴近脖子皮肤的一侧,通常可见按摩触点",
        "outer_side": "外壳展示面,通常可见按键/Logo/材质",
        "opening_direction": "U 形开口方向需结合图片复核",
    }
    angled_hint = ("opening_direction", "注意不要把图片左右直接当成产品佩戴者左右")
    overrides = {
        "inner_contacts": ("inner_side", "本图重点:贴颈内侧/按摩触点"),
        "side_thickness": ("outer_side", "本图重点:侧厚、边缘和机身厚度"),
        "left_45": angled_hint,
        "right_45": angled_hint,
    }
    override = overrides.get(view)
    if override is not None:
        field, hint = override
        orientation[field] = hint
    return orientation
|
||
|
||
|
||
def default_product_landmarks(view: str) -> list[str]:
    """Return the default landmark phrases for ``view``.

    Unknown views get a single generic U-shape landmark. A fresh list is
    returned on every call.
    """
    landmarks_by_view = {
        "front": ["U形开口", "外壳主轮廓", "左右臂"],
        "left_45": ["佩戴者左侧臂", "侧边弧度", "按键/结构差异"],
        "right_45": ["佩戴者右侧臂", "侧边弧度", "按键/结构差异"],
        "side_thickness": ["机身厚度", "侧边轮廓", "佩戴比例"],
        "inner_contacts": ["贴颈内侧", "按摩触点", "皮肤接触面"],
        "back_bottom": ["背面/底部", "接口/底面", "材质细节"],
    }
    try:
        return landmarks_by_view[view]
    except KeyError:
        return ["U形挂脖轮廓"]
|
||
|
||
|
||
def normalize_product_orientation(value: object, view: str) -> dict:
    """Overlay supplied orientation hints on the defaults for ``view``.

    Only keys from ``PRODUCT_ORIENTATION_KEYS`` are read. Each supplied value
    is whitespace-collapsed, stripped of quotes and punctuation and cut to 80
    characters; empty results keep the default. A non-dict ``value`` yields
    the defaults unchanged.
    """
    orientation = default_product_orientation(view)
    if not isinstance(value, dict):
        return orientation
    for field in PRODUCT_ORIENTATION_KEYS:
        supplied = value.get(field)
        if supplied is None:
            continue
        cleaned = re.sub(r"\s+", " ", str(supplied)).strip().strip('"\' ,,。')
        if cleaned:
            orientation[field] = cleaned[:80]
    return orientation
|
||
|
||
|
||
def normalize_product_landmarks(value: object, view: str) -> list[str]:
    """Return up to eight cleaned, de-duplicated landmark phrases.

    ``value`` may be a delimiter-separated string or a list; any other type
    contributes nothing. The view's default landmarks are appended after the
    supplied ones. Each phrase is whitespace-collapsed, stripped of quotes
    and punctuation and cut to 24 characters.
    """
    if isinstance(value, str):
        raw_items = re.split(r"[,,/、\n]+", value)
    elif isinstance(value, list):
        raw_items = [str(item) for item in value]
    else:
        raw_items = []
    result: list[str] = []
    for item in raw_items + default_product_landmarks(view):
        text = re.sub(r"\s+", " ", str(item)).strip().strip('"\' ,,。')
        # Truncate before the duplicate check: comparing the untruncated text
        # against stored 24-character values let duplicates slip through.
        text = text[:24].rstrip()
        if text and text not in result:
            result.append(text)
    return result[:8]
|
||
|
||
|
||
def normalize_product_view_data(data: dict, index: int) -> dict:
    """Validate and clean one recognised view record.

    An unknown ``view`` yields ``fallback_product_view(index)``. Otherwise the
    background, tags, orientation, landmarks, note, risk and confidence are
    normalised; confidence is clamped to 0..1 and defaults to 0.5 when it
    cannot be parsed.
    """
    def cleaned_text(key: str, default: str, junk: str) -> str:
        # Stringify, trim whitespace, then trim quotes/punctuation.
        return str(data.get(key) or default).strip().strip(junk)

    view = cleaned_text("view", "", '"\' ,。')
    if view not in PRODUCT_VIEW_VALUES:
        return fallback_product_view(index)

    background = cleaned_text("background", "unknown", '"\' ,。')
    if background not in PRODUCT_BACKGROUND_VALUES:
        background = "unknown"

    use_tags = normalize_product_use_tags(data.get("use_tags"), view)
    orientation = normalize_product_orientation(data.get("orientation"), view)
    landmarks = normalize_product_landmarks(data.get("landmarks"), view)

    note = re.sub(r"\s+", " ", cleaned_text("note", "", '"\' ,,。'))[:320]
    if not note:
        note = f"{PRODUCT_VIEW_LABELS.get(view, view)}参考"
    risk = re.sub(r"\s+", " ", cleaned_text("risk", "", '"\' ,,。'))[:160]

    try:
        confidence = float(data.get("confidence", 0.5))
        confidence = max(0.0, min(1.0, confidence))
    except Exception:
        confidence = 0.5
    # A zero score with landmarks and no stated risk is raised to 0.65.
    if confidence <= 0 and not risk and landmarks:
        confidence = 0.65

    return dict(
        view=view,
        background=background,
        use_tags=use_tags,
        orientation=orientation,
        landmarks=landmarks,
        note=note,
        risk=risk,
        confidence=confidence,
    )
|
||
|
||
|
||
def _tolerant_product_view_fields(text: str) -> dict:
    """Extract view fields from malformed JSON-like text with regexes."""
    view_match = re.search(r'["\']?view["\']?\s*[::]\s*["\']?([a-z0-9_]+)', text, flags=re.I)
    note_match = re.search(
        r'["\']?note["\']?\s*[::]\s*["\']?([\s\S]*?)(?:["\']?\s*,\s*["\']?confidence|["\']?\s*[,}]\s*$)',
        text,
        flags=re.I,
    )
    confidence_match = re.search(r'["\']?confidence["\']?\s*[::]\s*["\']?([0-9.]+)', text, flags=re.I)
    background_match = re.search(r'["\']?background["\']?\s*[::]\s*["\']?([a-z0-9_]+)', text, flags=re.I)
    tags_match = re.search(r'["\']?use_tags["\']?\s*[::]\s*\[([\s\S]*?)\]', text, flags=re.I)
    landmarks_match = re.search(r'["\']?landmarks["\']?\s*[::]\s*\[([\s\S]*?)(?:\]|\}\s*$)', text, flags=re.I)
    risk_match = re.search(
        r'["\']?risk["\']?\s*[::]\s*["\']?([\s\S]*?)(?:["\']?\s*[,}]\s*$)',
        text,
        flags=re.I,
    )
    orientation = {}
    for key in PRODUCT_ORIENTATION_KEYS:
        orientation_match = re.search(
            rf'["\']?{key}["\']?\s*[::]\s*["\']?([^"\',,}}\]]+)',
            text,
            flags=re.I,
        )
        if orientation_match:
            orientation[key] = orientation_match.group(1)
    return {
        "view": view_match.group(1) if view_match else "",
        "background": background_match.group(1) if background_match else "unknown",
        "use_tags": re.findall(r"[a-z_]+", tags_match.group(1)) if tags_match else [],
        "orientation": orientation,
        "landmarks": re.findall(r"[\u4e00-\u9fffA-Za-z0-9/_-]+", landmarks_match.group(1)) if landmarks_match else [],
        "note": note_match.group(1) if note_match else "",
        "risk": risk_match.group(1) if risk_match else "",
        "confidence": confidence_match.group(1) if confidence_match else 0.45,
    }


def parse_product_view_response(raw: str, index: int) -> dict:
    """Parse a single-image recognition response into a normalised view record.

    Markdown code fences are removed, the outermost ``{...}`` span is decoded
    as JSON and the result is passed to ``normalize_product_view_data``. When
    decoding fails, or succeeds but does not produce an object (for example
    ``null`` or a bare number), the fields are recovered with tolerant regexes
    instead of failing on ``.get``.
    """
    text = (raw or "").strip()
    text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.I).strip()
    text = re.sub(r"\s*```$", "", text).strip()
    match = re.search(r"\{[\s\S]*\}", text)
    json_text = match.group(0) if match else text
    try:
        data = json.loads(json_text)
    except Exception:
        data = None
    if not isinstance(data, dict):
        data = _tolerant_product_view_fields(text)
    return normalize_product_view_data(data, index)
|
||
|
||
|
||
def parse_product_view_batch_response(raw: str, indices: list[int]) -> dict[int, dict]:
    """Parse a batch recognition response into ``{index: view record}``.

    Strict path: the response decodes as JSON holding an ``items`` list (or is
    itself a list); each object item is normalised under its ``index``,
    falling back to the positional index when ``index`` is absent or not an
    integer. Items whose index is not in ``indices`` are ignored.

    Tolerant path (JSON decoding failed): the text is cut at each
    ``index: N`` marker and every slice is parsed as a single response.
    Indices that cannot be located are absent from the result.

    Raises:
        ValueError: the JSON decoded but contains no item list.
    """
    text = (raw or "").strip()
    text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.I).strip()
    text = re.sub(r"\s*```$", "", text).strip()
    match = re.search(r"\{[\s\S]*\}", text)
    json_text = match.group(0) if match else text
    try:
        data = json.loads(json_text)
    except Exception:
        starts: list[tuple[int, int]] = []
        for index in indices:
            # (?!\d) keeps index 1 from matching the prefix of 10, 11, ...
            found = re.search(rf'["\']?index["\']?\s*[::]\s*["\']?{index}(?!\d)["\']?', text)
            if found:
                starts.append((index, found.start()))
        if not starts and len(indices) == 1:
            return {indices[0]: parse_product_view_response(text, indices[0])}
        starts.sort(key=lambda item: item[1])
        tolerant: dict[int, dict] = {}
        for offset, (index, start_pos) in enumerate(starts):
            # Each fragment runs up to the next located marker or the end.
            end_pos = starts[offset + 1][1] if offset + 1 < len(starts) else len(text)
            tolerant[index] = parse_product_view_response(text[start_pos:end_pos], index)
        return tolerant
    raw_items = data.get("items") if isinstance(data, dict) else data
    if not isinstance(raw_items, list):
        raise ValueError("product view batch response missing items[]")
    allowed = set(indices)
    results: dict[int, dict] = {}
    for offset, item in enumerate(raw_items):
        if not isinstance(item, dict):
            continue
        positional = indices[offset] if offset < len(indices) else -1
        try:
            item_index = int(item.get("index", positional))
        except Exception:
            item_index = positional
        if item_index not in allowed:
            continue
        results[item_index] = normalize_product_view_data(item, item_index)
    return results
|
||
|
||
|
||
def product_view_batch_prompt(indices: list[int]) -> str:
    """Build the recognition prompt for one batch of product images.

    Only the number of ``indices`` is used; it is interpolated into the
    sentence telling the model how many items to return.
    """
    image_count = len(indices)
    sections = [
        "你在识别同一款 SKG 挂脖肩颈按摩仪的产品参考图。所有图片都是同一产品,不要判断是不是不同产品,也不要把它当耳机、头戴设备或护颈枕;它是套在脖子上、外置佩戴在肩颈位置的 U 形/围脖式按摩仪,可能有内侧按摩触点、外壳按键、厚度、底部接口和左右不对称结构。\n",
        "先建立产品坐标系,再逐图识别:product_left=产品戴在真人脖子上时佩戴者左肩那一侧;product_right=佩戴者右肩那一侧;top=靠近下巴/脸/颈部上沿;bottom=靠近锁骨/肩部下沿;inner_side=贴近脖子皮肤/按摩触点的一侧;outer_side=外壳/按键/Logo/材质展示面。不要把图片左侧直接等同于产品左侧,必须在 orientation 里说明产品左/右/上/下分别对应图中的哪一边;不确定就写不确定并在 risk 里提醒。\n",
        "每张图的 view 必须从 enum 选一个:front(正面/外侧主外观), left_45(佩戴者左侧45度), right_45(佩戴者右侧45度), side_thickness(侧面厚度), inner_contacts(贴颈内侧/按摩触点), back_bottom(背面/底部/接口)。left_45/right_45 指佩戴者身体左右,不是画面左右。\n",
        "background enum:white, black, simple, complex, unknown。use_tags 只能从 enum 选:hero_packshot, wearing_scale, inner_contact, side_thickness, asymmetry, button_detail, back_bottom, material_texture。\n",
        "landmarks 用中文短词列出可见结构,例如:佩戴者左侧臂、佩戴者右侧臂、U形开口、贴颈内侧、按摩触点、侧边厚度、按键、充电口、底部、外壳材质、局部细节。note 必须用中文写给生视频模型,重点说明这张图适合约束什么,尤其要写清楚左/右/上/下、内/外侧、触点或局部细节。risk 只在可能误导生视频时写中文,如局部裁切、无法判断产品左右、上下颠倒风险、反光、遮挡、分辨率低、背景干扰;否则为空。\n",
        f"本次共有 {image_count} 张图片,图片前的 Image index 就是输出 index。必须输出同样数量的 items,且 index 不要改。只输出一行严格 JSON,不要 markdown,不要换行。\n",
        "{\"items\":[{\"index\":0,\"view\":\"front|left_45|right_45|side_thickness|inner_contacts|back_bottom\",\"background\":\"white|black|simple|complex|unknown\",\"use_tags\":[\"hero_packshot\"],\"orientation\":{\"product_left\":\"图中哪一侧/不可见/不确定\",\"product_right\":\"图中哪一侧/不可见/不确定\",\"top\":\"图中哪一侧/不可见/不确定\",\"bottom\":\"图中哪一侧/不可见/不确定\",\"inner_side\":\"图中哪一侧/是否可见\",\"outer_side\":\"图中哪一侧/是否可见\",\"opening_direction\":\"U形开口朝图中哪一侧/不可见/不确定\"},\"landmarks\":[\"U形开口\"],\"note\":\"中文备注\",\"risk\":\"\",\"confidence\":0.86}]}",
    ]
    return "".join(sections)
|
||
|
||
|
||
def analyze_product_view(ref_path: Path, index: int) -> dict:
    """Classify one product reference image with the vision model.

    Returns the positional fallback record when the relevant API key is not
    configured. Any error raised by the model call or by response parsing is
    converted into the fallback record, with the error text appended to its
    note.
    """
    api_key = IMAGE_API_KEY if PRODUCT_VIEW_MODEL == GPT_IMAGE_MODEL else LLM_API_KEY
    if not api_key:
        return fallback_product_view(index)

    encoded_image = base64.b64encode(ref_path.read_bytes()).decode("ascii")
    prompt = "".join([
        "你在识别同一款 SKG 挂脖肩颈按摩仪的一张产品参考图。它是套在脖子上的 U 形/围脖式按摩仪,不是耳机、头戴设备或护颈枕;所有上传图都属于同一产品,不要判断不同产品身份。 ",
        "必须使用产品坐标系:product_left=戴在真人脖子上时佩戴者左肩一侧,product_right=佩戴者右肩一侧,top=靠近下巴/脸/颈部上沿,bottom=靠近锁骨/肩部下沿,inner_side=贴颈皮肤/按摩触点,outer_side=外壳/按键/Logo。不要把图片左侧直接当产品左侧;在 orientation 里写清楚产品左/右/上/下对应图中哪边,不确定就说明不确定并写 risk。 ",
        "view 从 enum 选一个:front, left_45, right_45, side_thickness, inner_contacts, back_bottom。left_45/right_45 指佩戴者身体左右,不是画面左右。 ",
        "background 从 enum 选:white, black, simple, complex, unknown。use_tags 只能从 enum 选:hero_packshot, wearing_scale, inner_contact, side_thickness, asymmetry, button_detail, back_bottom, material_texture。 ",
        "landmarks 用中文短词列出可见结构,例如佩戴者左侧臂、佩戴者右侧臂、U形开口、贴颈内侧、按摩触点、侧边厚度、按键、充电口、底部、外壳材质、局部细节。note 用中文写给生视频模型,重点说明左/右/上/下、内/外侧、触点或局部细节。risk 只在可能误导生视频时写中文,否则为空。 ",
        "Output one-line strict JSON only. Do not use markdown or line breaks. ",
        "{\"view\":\"front|left_45|right_45|side_thickness|inner_contacts|back_bottom\",\"background\":\"white|black|simple|complex|unknown\",\"use_tags\":[\"hero_packshot\"],\"orientation\":{\"product_left\":\"图中哪一侧/不可见/不确定\",\"product_right\":\"图中哪一侧/不可见/不确定\",\"top\":\"图中哪一侧/不可见/不确定\",\"bottom\":\"图中哪一侧/不可见/不确定\",\"inner_side\":\"图中哪一侧/是否可见\",\"outer_side\":\"图中哪一侧/是否可见\",\"opening_direction\":\"U形开口朝图中哪一侧/不可见/不确定\"},\"landmarks\":[\"U形开口\"],\"note\":\"中文备注\",\"risk\":\"\",\"confidence\":0.86}.",
    ])
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
    ]
    try:
        resp = product_view_llm().chat.completions.create(
            model=PRODUCT_VIEW_MODEL,
            messages=[{"role": "user", "content": user_content}],
            response_format={"type": "json_object"},
            temperature=0.1,
            max_tokens=1600,
        )
        answer = (resp.choices[0].message.content or "").strip()
        if not answer:
            # Fall back to the message's reasoning_content attribute when
            # the regular content is empty.
            answer = (getattr(resp.choices[0].message, "reasoning_content", "") or "").strip()
        return parse_product_view_response(answer, index)
    except Exception as e:
        degraded = fallback_product_view(index)
        degraded["note"] = f"{degraded['note']} 识别失败:{str(e)[:80]}"
        return degraded
|
||
|
||
|
||
def analyze_product_views_batch(paths_by_index: list[tuple[int, Path]]) -> dict[int, dict]:
    """Classify product images in chunks of ``PRODUCT_VIEW_BATCH_SIZE``.

    Each chunk is sent as one multi-image request. Indices missing from a
    parsed batch answer are analysed individually. If the batch request or
    its parsing fails, every image of the chunk is analysed individually and
    the failure is recorded in the record's ``risk``. Without a configured
    API key every index gets the positional fallback.
    """
    api_key = IMAGE_API_KEY if PRODUCT_VIEW_MODEL == GPT_IMAGE_MODEL else LLM_API_KEY
    if not api_key:
        return {index: fallback_product_view(index) for index, _path in paths_by_index}

    results: dict[int, dict] = {}
    for chunk_start in range(0, len(paths_by_index), PRODUCT_VIEW_BATCH_SIZE):
        chunk = paths_by_index[chunk_start:chunk_start + PRODUCT_VIEW_BATCH_SIZE]
        indices = [index for index, _path in chunk]
        # First path seen for each index, for the per-image retry below.
        first_path: dict[int, Path] = {}
        for index, path in chunk:
            first_path.setdefault(index, path)

        content: list[dict] = [{"type": "text", "text": product_view_batch_prompt(indices)}]
        for index, path in chunk:
            encoded_image = base64.b64encode(path.read_bytes()).decode("ascii")
            content.extend((
                {"type": "text", "text": f"Image index {index}"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
            ))
        try:
            resp = product_view_llm().chat.completions.create(
                model=PRODUCT_VIEW_MODEL,
                messages=[{"role": "user", "content": content}],
                response_format={"type": "json_object"},
                temperature=0.05,
                max_tokens=max(2400, min(7000, 1200 * len(chunk))),
            )
            answer = (resp.choices[0].message.content or "").strip()
            if not answer:
                answer = (getattr(resp.choices[0].message, "reasoning_content", "") or "").strip()
            parsed = parse_product_view_batch_response(answer, indices)
            for index in indices:
                results[index] = parsed.get(index) or analyze_product_view(first_path[index], index)
        except Exception as e:
            for index, path in chunk:
                try:
                    single = analyze_product_view(path, index)
                except Exception:
                    single = fallback_product_view(index)
                existing_risk = single.get("risk")
                if existing_risk:
                    single["risk"] = f"{existing_risk};批量识别失败后单图兜底"
                else:
                    single["risk"] = f"批量识别失败后单图兜底:{str(e)[:60]}"
                results[index] = single
    return results
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets/product-views/analyze")
def analyze_product_views(job_id: str, req: AnalyzeProductViewsReq) -> dict:
    """Classify the given product references and report uncovered views.

    References that do not resolve to an existing file get the positional
    fallback record; the rest are analysed in batches.

    Returns:
        ``{"items": [...], "missing_views": [...]}`` with one item per input
        reference (in input order) and the views from ``PRODUCT_VIEW_VALUES``
        that no item was assigned.

    Raises:
        HTTPException: 404 for an unknown job.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")

    resolvable: list[tuple[int, Path]] = []
    unresolved: dict[int, dict] = {}
    for position, ref in enumerate(req.refs):
        candidate = storyboard_ref_path(job_id, ref)
        if candidate and candidate.exists():
            resolvable.append((position, candidate))
        else:
            unresolved[position] = fallback_product_view(position)
    recognised = analyze_product_views_batch(resolvable) if resolvable else {}

    items = []
    for position, _ref in enumerate(req.refs):
        record = recognised.get(position) or unresolved.get(position) or fallback_product_view(position)
        items.append({
            "index": position,
            "view": record["view"],
            "background": record.get("background", "unknown"),
            "use_tags": record.get("use_tags", default_product_use_tags(record["view"])),
            "orientation": record.get("orientation", default_product_orientation(record["view"])),
            "landmarks": record.get("landmarks", default_product_landmarks(record["view"])),
            "note": record["note"],
            "risk": record.get("risk", ""),
            "confidence": record["confidence"],
        })

    covered_views = {entry["view"] for entry in items}
    missing_views = [view for view in PRODUCT_VIEW_VALUES if view not in covered_views]
    return {"items": items, "missing_views": missing_views}
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets/product-angle")
def generate_product_angle_asset(job_id: str, req: GenerateProductAngleAssetReq) -> dict:
    """Generate a missing product view from existing reference images.

    Up to six distinct, existing reference files (``source_ref`` first) are
    passed to the image-edit model with a prompt describing the target view;
    the result is normalised into ``assets/<asset_id>.jpg`` of the job.

    Returns:
        An asset reference dict for the generated image.

    Raises:
        HTTPException: 404 for an unknown job or when no reference file can
            be found; the status from ``_image_error_status`` when generation
            fails with a ``RuntimeError``.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")

    source_paths: list[Path] = []
    seen_paths: set[str] = set()
    for ref in [req.source_ref] + list(req.source_refs or []):
        candidate = storyboard_ref_path(job_id, ref)
        if not candidate or not candidate.exists():
            continue
        path_key = str(candidate)
        if path_key in seen_paths:
            continue
        seen_paths.add(path_key)
        source_paths.append(candidate)
        if len(source_paths) >= 6:
            break
    if not source_paths:
        raise HTTPException(404, "source product image not found")
    source_path = source_paths[0]

    target_view = (req.target_view or "目标视角").strip()
    note = (req.note or "").strip()
    source_notes = []
    for raw_note in req.source_notes or []:
        if str(raw_note).strip():
            source_notes.append(re.sub(r"\s+", " ", str(raw_note)).strip()[:180])

    prompt_parts = [
        "Use all provided reference images as evidence for the same SKG neck-and-shoulder wearable massage product. ",
        "Each input image is one uploaded view of the same product; do not output a board, collage, or multiple products. ",
        f"Generate a clean product-only white-background reference image in this missing view: {target_view}. ",
    ]
    if source_notes:
        prompt_parts.append(
            "Uploaded reference notes from the operator/view recognizer: " + " | ".join(source_notes[:6]) + ". "
        )
    prompt_parts += [
        "Preserve the exact product identity: white U-shaped wearable neck and shoulder massager that sits around the neck, asymmetric wearer-left and wearer-right details, side buttons, inner metal massage contacts, opening width, material, thickness, curvature, and real shoulder-neck wearing scale. ",
        "Use product coordinates: wearer-left/right are the user's body left/right when worn, top is near chin/upper neck, bottom is near collarbone/shoulders, inner side touches skin, outer side is the shell/buttons. ",
        "Do not mirror both sides into identical shapes; keep visible left/right asymmetry and believable shoulder-neck wearable proportions. ",
        "The product should be complete, centered, isolated on pure white, large enough to inspect, with no hands, people, packaging, text, UI, watermark, extra accessories, or scene background. ",
        "If the target view is not fully visible in the source, infer the missing surfaces conservatively from the same product design without inventing a new model. ",
    ]
    if note:
        prompt_parts.append(f"Additional operator note: {note}. ")
    prompt = "".join(prompt_parts)

    try:
        img_bytes, _mode = _image_edit_call(
            source_paths,
            prompt,
            models=[GPT_IMAGE_MODEL],
            fallback_text=False,
            max_attempts=5,
            max_side=1600,
        )
    except RuntimeError as e:
        raise HTTPException(_image_error_status(e), f"product angle generation failed: {e}")

    asset_id = f"product_angle_{uuid.uuid4().hex[:10]}"
    out_path = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
    _normalize_asset_image(img_bytes, out_path, source_path, "1024", "white", square=True, fill_subject=True)
    return dict(
        kind="asset",
        frame_idx=-1,
        element_id=asset_id,
        cutout_id=asset_id,
        label=f"AI 补角度 · {target_view}",
    )
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets/product-library")
def copy_product_library_asset(job_id: str, req: CopyProductLibraryAssetReq) -> dict:
    """Copy one product-library image into the job's ``assets`` folder.

    The library entry is looked up from ``req.product_id``; its file and a
    freshly named ``<asset_id>.jpg`` target are handed to
    ``normalize_product_asset_image`` and whatever that helper returns is
    passed back as ``asset_meta``.

    Returns:
        A storyboard reference dict (``kind="asset"``, ``frame_idx=-1``) whose
        ``element_id`` and ``cutout_id`` are both the new asset id.

    Raises:
        HTTPException: 404 when the job is unknown, 400 when the
            normalization helper raises.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")

    library_item = find_product_library_item(req.product_id)
    source_file = product_library_file(library_item)

    new_id = uuid.uuid4().hex[:12]
    assets_dir = job_dir(job_id) / "assets"
    assets_dir.mkdir(parents=True, exist_ok=True)
    target_file = assets_dir / f"{new_id}.jpg"

    try:
        meta = normalize_product_asset_image(source_file, target_file)
    except Exception as e:
        raise HTTPException(400, f"product library copy failed: {e}")

    reference = dict(kind="asset", frame_idx=-1, element_id=new_id, cutout_id=new_id)
    reference["label"] = f"产品融合 · {library_item.title} #{library_item.image_index}"
    reference["asset_meta"] = meta
    return reference
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets/character-library")
def copy_character_library_assets(job_id: str, req: CopyCharacterLibraryAssetReq) -> dict:
    """Copy every image of a library character into the job's ``assets`` folder.

    Each image is converted to RGB, shrunk in place to fit within
    1600x1600 and written as ``<asset_id>.jpg`` (JPEG quality 94).

    Returns:
        ``{"character_id", "character_name", "images"}`` where ``images`` is a
        list of storyboard reference dicts (``kind="asset"``,
        ``frame_idx=-1``), one per copied image.

    Raises:
        HTTPException: 404 when the job is unknown, 400 when an image cannot
            be opened, converted or saved.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    character = find_character_library_item(req.character_id)
    out_dir = job_dir(job_id) / "assets"
    out_dir.mkdir(parents=True, exist_ok=True)

    refs = []
    for image in character.images:
        src = character_library_file(image.filename)
        asset_id = uuid.uuid4().hex[:12]
        out = out_dir / f"{asset_id}.jpg"
        try:
            # Image.open keeps the source file handle open until it is closed
            # explicitly; convert() returns an independent in-memory copy, so
            # the handle can be released before resizing and saving.
            with Image.open(src) as opened:
                img = opened.convert("RGB")
            img.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
            img.save(out, "JPEG", quality=94)
        except Exception as e:
            raise HTTPException(400, f"character library copy failed: {e}") from e
        refs.append({
            "kind": "asset",
            "frame_idx": -1,
            "element_id": asset_id,
            "cutout_id": asset_id,
            "label": f"角色 · {character.name} · {image.label}",
        })

    return {
        "character_id": character.id,
        "character_name": character.name,
        "images": refs,
    }
|
||
|
||
|
||
def product_image_alpha(img: Image.Image) -> Image.Image:
    """Return an RGBA copy of ``img`` whose alpha is derived from distance to white.

    The image is compared against a pure-white canvas; the greyscale
    difference is mapped to alpha (see ``alpha_for``), lightly blurred, and
    installed as the alpha channel of the RGBA copy.
    """

    def alpha_for(level):
        # Differences below 18 count as white background (alpha 0); above
        # that the alpha ramps at 2.4x the difference, capped at 255.
        if level < 18:
            return 0
        return min(255, int(level * 2.4))

    result = img.convert("RGBA")
    flattened = result.convert("RGB")
    white_canvas = Image.new("RGB", flattened.size, (255, 255, 255))
    distance = ImageChops.difference(flattened, white_canvas).convert("L")
    alpha = distance.point(alpha_for)
    alpha = alpha.filter(ImageFilter.GaussianBlur(0.7))
    result.putalpha(alpha)
    return result
|
||
|
||
|
||
@app.post("/jobs/{job_id}/product-fusion/guide")
def create_product_fusion_guide(job_id: str, req: ProductFusionShot) -> dict:
    """Composite the product image onto the person image as a placement guide.

    ``req.product_region`` is read as fractions of the person image
    (``x``, ``y``, ``w``, ``h``). The product is cut out with
    ``product_image_alpha``, shrunk to fit the region, centred inside it and
    alpha-composited onto the person image (itself limited to 1600x1600).
    The result is saved as ``assets/<asset_id>.jpg`` in the job directory.

    Returns:
        A storyboard reference dict (``kind="asset"``, ``frame_idx=-1``)
        pointing at the new guide image.

    Raises:
        HTTPException: 404 when the job is unknown; 400 when the person
            image, product image or region is missing, or when compositing
            or saving fails.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    person_path = storyboard_ref_path(job_id, req.person_image)
    product_path = storyboard_ref_path(job_id, req.product_image)
    if not person_path or not person_path.exists():
        raise HTTPException(400, "person image required")
    if not product_path or not product_path.exists():
        raise HTTPException(400, "product image required")
    region = req.product_region
    if not region or region.w <= 0 or region.h <= 0:
        raise HTTPException(400, "product region required")

    # Smallest region side, as a fraction of the image. The origin is capped
    # at 1 - min_span so the minimum-size region still fits inside the image
    # instead of spilling past the right/bottom edge.
    min_span = 0.02
    x = max(0.0, min(1.0 - min_span, float(region.x)))
    y = max(0.0, min(1.0 - min_span, float(region.y)))
    w = max(min_span, min(1.0 - x, float(region.w)))
    h = max(min_span, min(1.0 - y, float(region.h)))

    try:
        # Image.open holds the file handle until closed; convert() and
        # product_image_alpha() both return independent in-memory images.
        with Image.open(person_path) as person_file:
            base = person_file.convert("RGB")
        base.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
        with Image.open(product_path) as product_file:
            product = product_image_alpha(product_file)

        bw, bh = base.size
        # Pixel box, clamped so rounding can never place it outside the image.
        box_x = min(bw - 1, int(round(x * bw)))
        box_y = min(bh - 1, int(round(y * bh)))
        box_w = max(1, min(bw - box_x, int(round(w * bw))))
        box_h = max(1, min(bh - box_y, int(round(h * bh))))

        product.thumbnail((box_w, box_h), Image.Resampling.LANCZOS)
        # thumbnail() keeps the aspect ratio, so centre the product in the box.
        px = box_x + max(0, (box_w - product.width) // 2)
        py = box_y + max(0, (box_h - product.height) // 2)

        guide = base.convert("RGBA")
        guide.alpha_composite(product, (px, py))
        out = guide.convert("RGB")

        asset_id = uuid.uuid4().hex[:12]
        out_dir = job_dir(job_id) / "assets"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{asset_id}.jpg"
        out.save(out_path, "JPEG", quality=94)
    except Exception as e:
        raise HTTPException(400, f"product fusion guide failed: {e}") from e

    return {
        "kind": "asset",
        "frame_idx": -1,
        "element_id": asset_id,
        "cutout_id": asset_id,
        "label": f"产品融合引导图 · {req.image_model or 'gpt-image-2'}",
    }
|
||
|
||
|
||
def fallback_product_fusion_descriptions() -> list[str]:
    """Return the 20 canned product-fusion action descriptions.

    A fresh list is built on every call, so callers may slice or extend the
    result without affecting later calls.
    """
    canned = (
        "清晨卧室柔光里,透明骨架人把白色 SKG 颈部按摩仪轻戴到后颈,微微闭眼露出放松微笑。",
        "现代客厅沙发旁,透明骨架人双手扶住 SKG 机身两侧,肩线慢慢放低,表情从紧绷变舒适。",
        "居家办公桌前,透明骨架人轻按 SKG 侧边控制键,颈部骨架区域清晰可见,神情安静享受。",
        "暖色卧室床边,透明骨架人佩戴 SKG 后轻轻仰头,白色骨架与透明外壳干净明亮,画面高级。",
        "落地窗自然光下,透明骨架人坐姿端正,SKG 产品贴合后颈,嘴角微扬呈现轻松舒缓状态。",
        "简洁浴室镜前,透明骨架人用双手调整 SKG 贴合角度,眼神柔和,产品白色机身清楚可辨。",
        "午后阳台休息区,透明骨架人戴着 SKG 慢慢侧头伸展,肩颈线条舒展,表情舒适而不夸张。",
        "高端影棚白色背景中,透明骨架人平稳转身展示 SKG 佩戴效果,产品比例真实,轮廓清晰。",
        "健身后休息长椅上,透明骨架人把 SKG 放上肩颈,呼吸放慢,脸上出现明显放松感。",
        "办公会议间隙,透明骨架人靠在椅背上佩戴 SKG,轻轻闭眼,画面传达短暂恢复和舒适休息。",
        "夜晚卧室暖灯下,透明骨架人坐在床沿使用 SKG,肩颈骨架被柔和光线照亮,神情安稳享受。",
        "城市公寓客厅里,透明骨架人一边看向窗外一边使用 SKG,动作自然,产品贴合不漂移。",
        "极简桌面场景中,透明骨架人拿起 SKG 靠近颈部,镜头轻推展示产品材质和佩戴准备动作。",
        "木质休闲椅上,透明骨架人佩戴 SKG 后轻轻呼气,肩部下沉,脸部呈现舒缓满足的微笑。",
        "白色商业摄影场景里,透明骨架人用指尖轻触 SKG 按键,产品细节清晰,人物状态轻松专业。",
        "温暖客厅地毯旁,透明骨架人坐姿放松,SKG 稳定贴合后颈,闭眼感受舒适放松的瞬间。",
        "窗边阅读角落中,透明骨架人戴着 SKG 翻开书页,动作慢而自然,表情平和享受。",
        "办公室午休场景里,透明骨架人把 SKG 戴稳后靠回椅背,眼睛半闭,颈肩明显放松。",
        "干净产品广告场景中,透明骨架人轻扶 SKG 两端展示佩戴贴合度,微笑自然,产品不变形。",
        "收尾特写镜头里,透明骨架人佩戴 SKG 后缓慢抬头微笑,白色骨架清楚,整体干净高级。",
    )
    return list(canned)
|
||
|
||
|
||
@app.post("/jobs/{job_id}/product-fusion/descriptions")
def generate_product_fusion_descriptions(job_id: str, req: ProductFusionDescriptionReq) -> dict:
    """Ask the LLM for 20 product-fusion action descriptions, with a canned fallback.

    Up to six shots from ``req.shots`` are summarised (first/last frame
    labels, up to four product-angle labels, existing action text) and
    appended to the prompt. The reply is expected to be a JSON object with a
    ``descriptions`` list; a short list is topped up from the canned
    fallback so the result holds 20 entries.

    Returns:
        ``{"descriptions": [...], "mode": "llm" | "fallback"}``. ``mode`` is
        ``"fallback"`` when no API key is configured, when the call or the
        parsing fails, or when the reply holds no usable description.

    Raises:
        HTTPException: 404 when the job is unknown.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    fallback = fallback_product_fusion_descriptions()
    if not LLM_API_KEY:
        return {"descriptions": fallback, "mode": "fallback"}

    shot_lines = []
    for i, shot in enumerate((req.shots or [])[:6], start=1):
        first = (shot.first_image or {}).get("label") or "首帧未填"
        last = (shot.last_image or {}).get("label") or "尾帧未填"
        products = [
            (ref or {}).get("label") or f"产品角度{idx + 1}未填"
            for idx, ref in enumerate((shot.product_images or [])[:4])
        ]
        # Always describe exactly four product-angle slots.
        while len(products) < 4:
            products.append(f"产品角度{len(products) + 1}未填")
        shot_lines.append(
            f"{i}. 首帧={first};尾帧={last};产品角度={' / '.join(products)};已有描述={shot.action_text or '空'}"
        )

    prompt = (
        "你是 SKG 产品短视频分镜导演。请写 20 条中文产品融合动作描述,"
        "每条 35-70 字,必须说明透明骨架人在什么场景下使用产品、产品如何佩戴/展示、脸部如何舒适享受。"
        "产品是 SKG 白色 U 形颈部/肩颈按摩仪,四张产品角度图是同一产品的身份真源;不要写医疗治疗承诺,不要出现竞品。"
        "输出 JSON:{\"descriptions\":[\"...\", \"...\"]}。\n\n"
        + "\n".join(shot_lines)
    )

    try:
        resp = llm().chat.completions.create(
            model=REWRITE_MODEL,
            messages=[
                {"role": "system", "content": "只输出合法 JSON,不要解释。"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.5,
        )
        text = (resp.choices[0].message.content or "").strip()
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # The reply may be wrapped in a markdown fence or surrounded by
            # prose despite the system prompt; retry on the outermost {...}.
            start, end = text.find("{"), text.rfind("}")
            if start < 0 or end <= start:
                raise
            data = json.loads(text[start:end + 1])
        raw = data.get("descriptions") if isinstance(data, dict) else None
        if not isinstance(raw, list):
            # A string here would otherwise be iterated character by character.
            raw = []
        descriptions = [s for s in (str(x).strip() for x in raw) if s]
    except Exception:
        # Best effort: any LLM or parsing failure degrades to the canned list.
        return {"descriptions": fallback, "mode": "fallback"}

    if not descriptions:
        # Nothing usable came from the model, so do not label the result "llm".
        return {"descriptions": fallback, "mode": "fallback"}
    return {"descriptions": (descriptions + fallback)[:20], "mode": "llm"}
|
||
|
||
|
||
@app.get("/jobs/{job_id}/assets/{asset_id}.jpg")
def get_storyboard_asset(job_id: str, asset_id: str):
    """Serve ``assets/<asset_id>.jpg`` from the job directory as a JPEG.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    asset_file = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
    if asset_file.exists():
        return FileResponse(asset_file, media_type="image/jpeg")
    raise HTTPException(404, "asset not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-videos/{video_id}", response_model=Job)
def delete_storyboard_video(job_id: str, video_id: str) -> Job:
    """Delete one video task of the Video Gen node; no status check is made.

    The entry is dropped from ``job.generated_videos`` and its
    ``storyboard_videos/<video_id>`` directory is removed on a best-effort
    basis.

    Raises:
        HTTPException: 404 when the job or the video entry is unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    remaining = []
    removed = None
    for video in job.generated_videos:
        if video.id != video_id:
            remaining.append(video)
        elif removed is None:
            removed = video
    if len(remaining) == len(job.generated_videos):
        raise HTTPException(404, "generated video not found")

    video_dir = job_dir(job_id) / "storyboard_videos" / video_id
    if video_dir.exists():
        try:
            shutil.rmtree(video_dir)
        except OSError:
            # Leftover files must not block removal of the list entry.
            pass

    if removed:
        msg = f"删除视频任务 · 分镜 {removed.frame_idx + 1}"
    else:
        msg = "删除视频任务"
    update(job, generated_videos=remaining, message=msg)
    return job
|
||
|
||
|
||
@app.put("/jobs/{job_id}/frames/{idx}/storyboard", response_model=Job)
def update_storyboard(job_id: str, idx: int, req: UpdateStoryboardReq) -> Job:
    """Replace the storyboard arrangement of frame ``idx`` with the request's fields.

    A new ``StoryboardScene`` is built from ``req`` (text fields stripped,
    duration floored at 0) and assigned to every frame whose ``index``
    equals ``idx``.

    Raises:
        HTTPException: 404 when the job or the frame is unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(candidate.index == idx for candidate in job.frames):
        raise HTTPException(404, "frame not found")

    def build_scene():
        # Built per matching frame so no scene instance is shared.
        return StoryboardScene(
            duration=max(0.0, float(req.duration)),
            first_image=req.first_image,
            last_image=req.last_image,
            product_images=list(req.product_images),
            subject_images=list(req.subject_images),
            product_fusion_shots=list(req.product_fusion_shots),
            visual_mode=req.visual_mode,
            needs_product=bool(req.needs_product),
            needs_subject=bool(req.needs_subject),
            first_frame_plan=req.first_frame_plan.strip(),
            last_frame_plan=req.last_frame_plan.strip(),
            product_placement=req.product_placement.strip(),
            subject_image=req.subject_image,
            scene_image=req.scene_image,
            product_image=req.product_image,
            action_image=req.action_image,
            subject=req.subject.strip(),
            product=req.product.strip(),
            scene=req.scene.strip(),
            action=req.action.strip(),
            reference_ids=list(req.reference_ids),
        )

    frames_after = []
    for current in job.frames:
        if current.index == idx:
            current.storyboard = build_scene()
        frames_after.append(current)

    update(job, frames=frames_after, message=f"分镜 {idx + 1} 编排已更新")
    return job
|
||
|
||
|
||
class PushStoryboardImageReq(BaseModel):
    """Request body for pushing one image into the storyboard arrangement area."""

    # Which kind of image is being pushed.
    kind: Literal["keyframe", "cutout", "asset"]
    # Index of the frame the image belongs to.
    frame_idx: int
    # Optional identifiers of the element / cutout behind the image.
    element_id: str | None = Field(default=None)
    cutout_id: str | None = Field(default=None)
    # Display label; empty by default.
    label: str = Field(default="")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/storyboard-images", response_model=Job)
def push_storyboard_image(job_id: str, req: PushStoryboardImageReq) -> Job:
    """Push one image (a keyframe itself or an extracted element image) to the storyboard area.

    The push is skipped, and the job returned unchanged, when an entry with
    the same kind, frame_idx, element_id and cutout_id already exists.

    Raises:
        HTTPException: 404 when the job is unknown.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    # Duplicate guard: identical (kind, frame_idx, element_id, cutout_id).
    incoming = (req.kind, req.frame_idx, req.element_id, req.cutout_id)
    already_pushed = any(
        (entry.kind, entry.frame_idx, entry.element_id, entry.cutout_id) == incoming
        for entry in job.storyboard_images
    )
    if already_pushed:
        return job

    entry = StoryboardImage(
        ref_id=uuid.uuid4().hex[:8],
        kind=req.kind,
        frame_idx=req.frame_idx,
        element_id=req.element_id,
        cutout_id=req.cutout_id,
        label=req.label.strip(),
        created_at=_time.time(),
    )
    update(
        job,
        storyboard_images=[*job.storyboard_images, entry],
        message=f"上推到分镜头编排 · {req.label or req.kind}",
    )
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-images/{ref_id}", response_model=Job)
def remove_storyboard_image(job_id: str, ref_id: str) -> Job:
    """Remove one image from the storyboard arrangement area by its ``ref_id``.

    Raises:
        HTTPException: 404 when the job is unknown or no entry has that id.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    survivors = [entry for entry in job.storyboard_images if entry.ref_id != ref_id]
    nothing_removed = len(survivors) == len(job.storyboard_images)
    if nothing_removed:
        raise HTTPException(404, "storyboard image not found")

    update(job, storyboard_images=survivors, message="从分镜头编排移除一张图")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg")
def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str):
    """Serve ``elements/<idx:03d>_<element_id>_<cutout_id>.jpg`` as a JPEG.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    filename = f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    cutout_file = job_dir(job_id) / "elements" / filename
    if cutout_file.exists():
        return FileResponse(cutout_file, media_type="image/jpeg")
    raise HTTPException(404, "cutout not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
    """Legacy (v1, single image) cutout route.

    Serves ``elements/<idx:03d>_<element_id>.jpg`` when present, otherwise
    the older ``.png`` file with the same stem.

    Raises:
        HTTPException: 404 when neither file exists.
    """
    elements_dir = job_dir(job_id) / "elements"
    p = elements_dir / f"{idx:03d}_{element_id}.jpg"
    if p.exists():
        return FileResponse(p, media_type="image/jpeg")

    legacy = elements_dir / f"{idx:03d}_{element_id}.png"
    if legacy.exists():
        # The fallback file is a PNG, so declare it as one even though the
        # route itself ends in .jpg.
        return FileResponse(legacy, media_type="image/png")
    raise HTTPException(404, "cutout not found")
|
||
|
||
|
||
# ---------- 删除:关键帧 / 单张生成图 ----------
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}", response_model=Job)
def delete_frame(job_id: str, idx: int) -> Job:
    """Delete a whole keyframe and clean up its files.

    Removed on a best-effort basis: the original frame, the cleaned version,
    every element cutout and every generated image whose file name starts
    with the frame's zero-padded index. The frame is then dropped from
    ``job.frames``.

    Raises:
        HTTPException: 404 when the job or the frame is unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(candidate.index == idx for candidate in job.frames):
        raise HTTPException(404, "frame not found")

    def discard(path):
        # Deletion is silent: the file may be missing or undeletable.
        try:
            path.unlink()
        except OSError:
            pass

    root = job_dir(job_id)
    prefix = f"{idx:03d}"

    for folder in ("frames", "cleaned"):
        single = root / folder / f"{prefix}.jpg"
        if single.exists():
            discard(single)

    # All element cutouts of this frame (file names start with "<prefix>_").
    elements_dir = root / "elements"
    if elements_dir.exists():
        for ext in ("png", "jpg"):
            for cutout in elements_dir.glob(f"{prefix}_*.{ext}"):
                discard(cutout)

    # All generated images of this frame.
    gen_dir = root / "gen"
    if gen_dir.exists():
        for generated in gen_dir.glob(f"{prefix}_*.jpg"):
            discard(generated)

    surviving = [frame for frame in job.frames if frame.index != idx]
    update(job, frames=surviving, message=f"删除分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/gen/{gen_id}", response_model=Job)
def delete_generated(job_id: str, idx: int, gen_id: str) -> Job:
    """Delete one generated image of a frame: the file and the list entry.

    The file ``gen/<idx:03d>_<gen_id>.jpg`` is removed first (best effort),
    then the entry is filtered out of the frame's ``generated_images``.

    Raises:
        HTTPException: 404 when the job, the frame or the list entry is
            unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(candidate.index == idx for candidate in job.frames):
        raise HTTPException(404, "frame not found")

    image_file = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_file.exists():
        try:
            image_file.unlink()
        except OSError:
            pass

    removed_any = False
    for current in job.frames:
        if current.index != idx:
            continue
        kept = [g for g in current.generated_images if g.id != gen_id]
        removed_any = len(kept) < len(current.generated_images)
        current.generated_images = kept

    if not removed_any:
        raise HTTPException(404, "generated image not found")
    update(job, frames=list(job.frames), message=f"删除生成图 · 分镜 {idx + 1}")
    return job
|