3676 lines
147 KiB
Python
3676 lines
147 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import json
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import time
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Literal
|
||
|
||
import httpx
|
||
from dotenv import load_dotenv
|
||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Load settings from a local .env file (if present) before any os.getenv() below.
load_dotenv()

# ---- Storage / CORS / product library ----
# Root for per-job working directories; created eagerly at import time.
JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)
# Comma-separated allowed CORS origins; blank entries are dropped.
CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290").split(",") if o.strip()]
# Product image library (defaults to product_library/skg-products next to this file);
# manifest.json inside it lists the library items.
PRODUCT_LIBRARY_DIR = Path(
    os.getenv("PRODUCT_LIBRARY_DIR", Path(__file__).resolve().parent / "product_library" / "skg-products")
).resolve()
PRODUCT_LIBRARY_MANIFEST = PRODUCT_LIBRARY_DIR / "manifest.json"

# ---- LLM gateway and model names ----
# An empty LLM_BASE_URL makes llm() fall back to the OpenAI SDK's default endpoint.
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
VISION_MODEL = os.getenv("VISION_MODEL", "gemini-2.5-flash")
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "gemini-3-pro-image-preview")
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"

# ---- Voice-over rewriting + MiniMax TTS ----
AUDIO_PRODUCT_BRIEF = os.getenv(
    "AUDIO_PRODUCT_BRIEF",
    "SKG 智能按摩产品,主打日常肩颈、腰背、眼部、膝盖或足部放松;广告表达要高级、干净、可信,不做医疗疗效承诺。",
).strip()
AUDIO_REWRITE_MODEL = os.getenv("AUDIO_REWRITE_MODEL", REWRITE_MODEL).strip() or REWRITE_MODEL
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY", "").strip()
# Like the other base-URL / model settings here, fall back to the default when the
# variable is present but blank; otherwise TTS requests would be built on "".
MINIMAX_TTS_BASE_URL = (
    os.getenv("MINIMAX_TTS_BASE_URL", "https://api.minimax.io").strip().rstrip("/")
    or "https://api.minimax.io"
)
MINIMAX_TTS_MODEL = os.getenv("MINIMAX_TTS_MODEL", "speech-2.8-turbo").strip() or "speech-2.8-turbo"
MINIMAX_TTS_VOICE_ID = os.getenv(
    "MINIMAX_TTS_VOICE_ID",
    "Chinese (Mandarin)_Reliable_Executive",
).strip() or "Chinese (Mandarin)_Reliable_Executive"

# ---- Poe API (selected for video generation by video_uses_poe()) ----
POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1"
POE_API_KEY = os.getenv("POE_API_KEY", "").strip()
|
||
|
||
|
||
def env_video_model(name: str, default: str) -> str:
    """Resolve a concrete video model ID from the environment variable *name*.

    Returns *default* when the variable is unset or blank, and also when it holds
    one of the legacy business aliases ("seedance", "kling", "veo", "veo3", "voe",
    compared case-insensitively) that older .env files used in place of real model
    IDs. Any other value is returned with surrounding whitespace removed.
    """
    legacy_aliases = {"seedance", "kling", "veo", "veo3", "voe"}
    configured = os.getenv(name, "").strip()
    if configured and configured.lower() not in legacy_aliases:
        return configured
    return default
|
||
|
||
|
||
# Business alias -> concrete video model ID (each overridable via env; see env_video_model()).
# "veo" and "voe" are extra spellings of "veo3" and all three read VIDEO_MODEL_VEO3.
VIDEO_MODEL_ALIASES = {
    "seedance": env_video_model("VIDEO_MODEL_SEEDANCE", "seedance-2-fast"),
    "kling": env_video_model("VIDEO_MODEL_KLING", "kling-omni"),
    "veo3": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
    "veo": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
    "voe": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
}
# Explicit video API endpoint/key; when empty, video_api_base()/video_api_key() fall back
# to Poe or the LLM gateway.
VIDEO_API_BASE_URL = os.getenv("VIDEO_API_BASE_URL", "").strip()
VIDEO_API_KEY = os.getenv("VIDEO_API_KEY", "").strip()
VIDEO_CREATE_PATH = os.getenv("VIDEO_CREATE_PATH", "/videos").strip() or "/videos"
# Candidate create-video endpoint paths (presumably tried in order by the video client —
# confirm). dict.fromkeys() removes duplicates while keeping first-seen order, so a path
# that is both VIDEO_CREATE_PATH and a built-in fallback is not listed twice.
VIDEO_CREATE_PATHS = list(dict.fromkeys(
    p.strip()
    for p in os.getenv("VIDEO_CREATE_PATHS", f"{VIDEO_CREATE_PATH},/videos/generations,/video/generations").split(",")
    if p.strip()
))
# Path templates; "{id}" is filled in by video_path().
VIDEO_STATUS_PATH = os.getenv("VIDEO_STATUS_PATH", "/videos/{id}").strip() or "/videos/{id}"
VIDEO_CONTENT_PATH = os.getenv("VIDEO_CONTENT_PATH", "/videos/{id}/content").strip() or "/videos/{id}/content"
# Name of the request field that carries the clip duration.
VIDEO_DURATION_FIELD = os.getenv("VIDEO_DURATION_FIELD", "seconds").strip() or "seconds"
# Polling timeout, clamped to at least 60 seconds.
VIDEO_POLL_TIMEOUT_SECONDS = max(60, int(os.getenv("VIDEO_POLL_TIMEOUT_SECONDS", "900")))
|
||
|
||
# OpenAI client (works with any OpenAI-compatible gateway, including SKG ezlink).
from openai import OpenAI

_llm_client: OpenAI | None = None


def llm() -> OpenAI:
    """Return the shared OpenAI client, creating it lazily on first call.

    The client targets LLM_BASE_URL (or the SDK default when that is empty) and
    authenticates with LLM_API_KEY.

    Raises:
        RuntimeError: if LLM_API_KEY is not configured.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY)
    return _llm_client
|
||
|
||
# Pipeline states:
# created → downloading → downloaded (stops here until the user clicks "analyze") → splitting → frames_extracted
# → transcribing → transcribed | failed
JobStatus = Literal[
    "created", "downloading", "downloaded",
    "splitting", "frames_extracted",
    "transcribing", "transcribed", "failed",
]

# Keyframe count setting (env KEYFRAME_COUNT, default 12).
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "12"))
# What a frame-extraction run optimises for; UI labels live in FRAME_TARGET_LABELS.
FrameExtractTarget = Literal["transparent_human", "balanced", "subject", "transition", "expression", "motion"]
# Presumably whether newly extracted frames replace or extend the existing ones — confirm.
FrameExtractMode = Literal["replace", "append"]
# Scan quality tier; "auto" is resolved to a concrete tier by _resolve_frame_quality().
FrameExtractQuality = Literal["auto", "fast", "accurate", "ultra"]
# Queued analyze request: (str, int, target, mode, quality); presumably (job_id, frame count, ...) — confirm.
AnalyzeTask = tuple[str, int, FrameExtractTarget, FrameExtractMode, FrameExtractQuality]
# Canvas colour for generated asset images.
AssetBackground = Literal["white", "black"]
# "source" keeps the source dimensions; numeric values are the target edge length (see _asset_target_size()).
AssetSize = Literal["source", "1024", "1536", "2048"]
AssetQuality = Literal["hd"]
# Subject category; selects the default view set in _subject_view_labels().
SubjectKind = Literal["object", "living"]
# Free-form view key such as "front" or "expression_smile" (see SUBJECT_VIEW_LABELS).
SubjectView = str
SceneMode = Literal["remove_subject", "similar", "style"]
SceneStyle = Literal["source", "premium_product", "clean_studio", "warm_lifestyle", "cinematic"]
# Chinese UI labels for each frame-extraction target.
FRAME_TARGET_LABELS: dict[FrameExtractTarget, str] = {
    "transparent_human": "透明骨架人",
    "balanced": "综合关键帧",
    "subject": "清晰主体",
    "transition": "转场变化",
    "expression": "表情瞬间",
    "motion": "动作峰值",
}

# Description (English + Chinese) of the "transparent skeleton human" character to look for;
# presumably sent to the vision model when scoring frames for the transparent_human target — confirm.
TRANSPARENT_HUMAN_POSITIVE_PROMPT = (
    "Target subject: transparent human character, translucent human body, glass-like human body, clear acrylic skin, "
    "transparent vinyl skin, visible clean white skeleton inside, skeleton visible inside transparent body, "
    "white bones inside clear body, non-horror skeleton character, friendly transparent humanoid, 3D commercial character, "
    "premium wellness character, transparent body with visible spine, transparent body with visible rib cage. "
    "中文目标:透明人体、半透明人体、玻璃人体、亚克力人体、果冻质感人体、外层透明皮肤、身体内部可见骨架、"
    "透明身体里的白色骨骼、干净白色骨架、非恐怖骷髅人、3D广告角色、透明骨架人、可见脊柱、可见肋骨、"
    "可见颈椎、可见骨盆、可见四肢骨骼、透明皮肤包裹骨架。"
)
# Counterpart list (English + Chinese) of content that should be rejected.
TRANSPARENT_HUMAN_NEGATIVE_PROMPT = (
    "Avoid: normal human, ordinary skeleton, skeleton only without transparent body, horror skeleton, gore, blood, corpse, "
    "zombie, organs, veins, autopsy, surgery, hospital, dark horror scene, blurry person, heavily occluded person, "
    "person too small, product only, background only, no visible skeleton, no transparent body, transparent clothing only. "
    "反向排除:普通真人、普通骷髅、只有骨架没有透明外壳、恐怖骷髅、血腥、腐烂、僵尸、尸体、器官、血管、"
    "解剖、医院、手术、黑暗恐怖场景、模糊人物、遮挡严重、人物太远、只有产品没有人、只有背景没有人、"
    "看不到骨架、看不到透明身体、透明衣服但不是透明身体。"
)
# Checklist a frame must meet to count as a qualified transparent-human frame.
TRANSPARENT_HUMAN_QUALIFIED_STANDARD = (
    "A qualified frame must satisfy all core conditions: 1) there is a humanoid character; "
    "2) the outer body is transparent or translucent; 3) a clean white skeleton is clearly visible inside the body; "
    "4) the transparent body and inner skeleton belong to the same character, not a background overlay; "
    "5) the character should occupy at least about 35% of frame height and be easy to inspect; "
    "6) no severe blur, occlusion, or deformation; 7) clean premium commercial wellness style, non-horror."
)
# Chinese UI labels for each scan quality tier.
FRAME_QUALITY_LABELS: dict[FrameExtractQuality, str] = {
    "auto": "自动",
    "fast": "快速",
    "accurate": "精细",
    "ultra": "极准",
}
||
class GeneratedImage(BaseModel):
    """An image generated for a keyframe (stored in KeyFrame.generated_images)."""

    id: str  # uuid hex, 12 chars
    prompt: str
    model: str
    mode: str = "edit"  # "edit" (with a reference image) | "text" (text prompt only)
    url: str  # /jobs/{job_id}/frames/{idx}/gen/{id}.jpg
    selected: bool = False
    created_at: float = 0.0  # presumably a time.time() timestamp — confirm
|
||
|
||
|
||
class GeneratedVideo(BaseModel):
    """A video generation task and its result, stored in Job.generated_videos."""

    id: str
    provider_id: str = ""  # presumably the task id returned by the video API — confirm
    frame_idx: int  # index of the related keyframe
    prompt: str
    model: str = ""
    status: Literal["queued", "in_progress", "completed", "failed"] = "queued"
    url: str = ""
    poster_url: str = ""
    duration: float = 4.0  # presumably seconds (cf. VIDEO_DURATION_FIELD) — confirm
    progress: int = 0  # presumably a 0-100 percentage — confirm
    error: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class VideoSourceRef(BaseModel):
    """Reference, by URL, to a video-generation source: an image or the source video."""

    kind: Literal["image", "source_video"] = "image"
    url: str = ""
|
||
|
||
|
||
class StoryboardScene(BaseModel):
    """Storyboard arrangement: one scene description per selected storyboard frame.

    v2: four image slots plus a duration (copy-paste mode) — one image each for
    subject / scene / product / action.
    v1 fields are kept for backward compatibility (subject/product/scene/action/reference_ids)."""

    duration: float = 0
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    product_fusion_shots: list[dict] = Field(default_factory=list)
    # Four image slots: each dict holds {kind, frame_idx, element_id?, cutout_id?, label}
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # v1 compatibility
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    reference_ids: list[str] = []
|
||
|
||
|
||
class StoryboardImage(BaseModel):
    """An image the user has pushed up into the storyboard area from elsewhere in the UI."""

    ref_id: str  # uuid hex, 8 chars
    kind: Literal["keyframe", "cutout", "asset"]  # asset = grouped material such as scene images or subject views
    frame_idx: int
    element_id: str | None = None  # set for cutouts
    cutout_id: str | None = None  # set for cutouts (versioned id; legacy data may equal element_id)
    label: str = ""  # display name
    created_at: float = 0.0
|
||
|
||
|
||
class QualityReport(BaseModel):
    """Image quality check result, as produced by _image_quality_report()."""

    width: int = 0
    height: int = 0
    short_side: int = 0
    sharpness: float = 0.0  # Laplacian variance on a downscaled copy; higher is sharper
    risk: Literal["ok", "warn", "bad"] = "ok"
    warnings: list[str] = Field(default_factory=list)  # human-readable (Chinese) warnings
|
||
|
||
|
||
class TransparentHumanFrameScore(BaseModel):
    """Per-criterion scores for how well a keyframe shows the transparent-human character.

    Score ranges are not defined here; presumably filled in by the vision model — confirm.
    """

    transparent_body_score: int = 0
    skeleton_visible_score: int = 0
    human_prominence_score: int = 0
    clarity_score: int = 0
    commercial_style_score: int = 0
    product_usefulness_score: int = 0
    total_score: int = 0
    qualified: bool = False  # presumably judged against TRANSPARENT_HUMAN_QUALIFIED_STANDARD — confirm
    reject_reason: str = ""
    notes: str = ""
|
||
|
||
|
||
class SceneAsset(BaseModel):
    """A scene image derived from a keyframe (stored in KeyFrame.scene_assets)."""

    id: str
    label: str = ""
    url: str = ""
    width: int = 0
    height: int = 0
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    scene_mode: SceneMode = "remove_subject"  # how the scene was derived (see SceneMode)
    scene_style: SceneStyle = "source"
    quality_report: QualityReport | None = None
    created_at: float = 0.0
|
||
|
||
|
||
class SubjectAsset(BaseModel):
    """One rendered view of an extracted subject (stored in KeyElement.subject_assets)."""

    id: str
    view: SubjectView  # view key, e.g. "front" (labels in SUBJECT_VIEW_LABELS)
    label: str = ""
    url: str = ""
    width: int = 0
    height: int = 0
    background: AssetBackground = "white"
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    source_frame_indices: list[int] = Field(default_factory=list)  # keyframes used as references
    ai_completed: bool = True
    created_at: float = 0.0
|
||
|
||
|
||
class ProductLibraryItem(BaseModel):
    """One image entry from the product library's manifest.json."""

    id: str
    handle: str
    title: str
    product_type: str = ""
    image_type: str = "gallery"
    image_index: int = 0
    filename: str  # path relative to PRODUCT_LIBRARY_DIR
    url: str
    width: int = 0
    height: int = 0
    source_path: str = ""
    white_score: float = 0.0  # presumably how white the background is — confirm
    near_white_score: float = 0.0
    has_people: bool = False
    tags: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class ProductFusionRegion(BaseModel):
    """Rectangle (x, y, w, h) marking where the product should appear.

    Presumably normalised 0-1 fractions like the other region dicts in this file — confirm.
    """

    x: float = 0
    y: float = 0
    w: float = 0
    h: float = 0
|
||
|
||
|
||
class ProductFusionShot(BaseModel):
    """Settings for one product-fusion shot: reference images, action text, and the models to use."""

    id: str = ""
    product_image: dict | None = None  # storyboard-style image reference dict
    person_image: dict | None = None
    product_region: ProductFusionRegion | None = None
    scene_image: dict | None = None
    action_text: str = ""
    duration: float = 5
    image_model: str = "gpt-image-2"
    video_model: str = "seedance"  # business alias (cf. VIDEO_MODEL_ALIASES)
    guide_image: dict | None = None
|
||
|
||
|
||
class KeyElement(BaseModel):
    """An element detected in, or extracted by the user from, a keyframe.

    Repeated extractions accumulate multiple images so the user can pick the one they like."""

    id: str  # uuid hex, 8 chars
    name_zh: str
    name_en: str = ""
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    region: dict | None = None
    # IDs of all extracted images (each cutout call appends a new id) → /jobs/.../elements/{element_id}/cutouts/{cutout_id}.jpg
    cutouts: list[str] = []
    # Legacy v1 single-image field: only used as a rendering fallback; new extractions no longer write it
    cutout_id: str | None = None
    cutout_background: Literal["white", "black"] = "white"
    subject_kind: SubjectKind = "object"
    subject_assets: list[SubjectAsset] = Field(default_factory=list)
    created_at: float = 0.0
|
||
|
||
|
||
class KeyFrame(BaseModel):
    """A keyframe extracted from the job's video, plus everything derived from it."""

    index: int
    timestamp: float  # presumably seconds into the source video — confirm
    url: str
    description: dict | None = None  # vision-model analysis {scene, objects, style, suggested_prompt}
    transparent_human_score: TransparentHumanFrameScore | None = None
    cleaned_url: str | None = None  # cleaned version awaiting apply → /jobs/{id}/frames/{idx}/cleaned.jpg
    cleaned_applied: bool = False  # whether the cleaned version replaced the original (cleaned_url is null afterwards)
    quality_report: QualityReport | None = None
    scene_assets: list[SceneAsset] = Field(default_factory=list)
    elements: list[KeyElement] = []  # extracted elements (persisted)
    storyboard: StoryboardScene | None = None  # storyboard arrangement for this frame
    generated_images: list[GeneratedImage] = []
|
||
|
||
|
||
class TranscriptSegment(BaseModel):
    """One timed segment of the video transcript."""

    index: int
    start: float
    end: float
    en: str  # transcribed source text
    zh: str = ""  # Chinese translation; empty until translated
|
||
|
||
|
||
class AudioScript(BaseModel):
    """State of the voice-over script rewrite and its synthesized voice for a job."""

    status: Literal["idle", "rewriting", "completed", "failed"] = "idle"
    source_text: str = ""
    source_zh: str = ""
    rewritten_text: str = ""
    product_brief: str = ""  # brief the rewrite is based on (cf. AUDIO_PRODUCT_BRIEF)
    rewrite_model: str = ""
    voice_provider: str = ""
    voice_model: str = ""
    voice_id: str = ""
    voice_url: str = ""
    error: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class Job(BaseModel):
    """One processing job and all data derived from it; persisted as state.json by save_state()."""

    id: str
    url: str  # source URL the job was created for
    status: JobStatus = "created"
    progress: int = 0
    message: str = ""
    video_url: str = ""
    duration: float = 0.0
    width: int = 0
    height: int = 0
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
    audio_script: AudioScript = Field(default_factory=AudioScript)
    storyboard_images: list[StoryboardImage] = Field(default_factory=list)
    generated_videos: list[GeneratedVideo] = Field(default_factory=list)
    error: str = ""
|
||
|
||
|
||
# In-memory job registry keyed by job id; repopulated from JOBS_DIR at startup by lifespan().
JOBS: dict[str, Job] = {}
# Pending frame-analysis requests (see AnalyzeTask).
ANALYZE_QUEUE: list[AnalyzeTask] = []
# Flag tracking whether the analyze-queue worker is running.
ANALYZE_WORKER_RUNNING = False
|
||
|
||
|
||
def job_dir(job_id: str) -> Path:
    """Return JOBS_DIR/<job_id>, creating the directory (and parents) if it does not exist."""
    path = JOBS_DIR.joinpath(job_id)
    path.mkdir(parents=True, exist_ok=True)
    return path
|
||
|
||
|
||
def save_state(job: Job) -> None:
    """Persist *job* as indented JSON to <job_dir>/state.json.

    The JSON goes to a uniquely named temp file in the same directory first and
    is then moved over state.json with os.replace(). An interrupted write therefore
    cannot leave a truncated state.json behind; lifespan() skips state files it
    cannot parse, so a truncated file would drop the job on restart.
    """
    target = job_dir(job.id) / "state.json"
    tmp = target.with_name(f".state.{uuid.uuid4().hex}.tmp")
    try:
        # NOTE(review): platform-default text encoding, unchanged; lifespan() reads
        # with the same default. Switch both to UTF-8 together if needed.
        tmp.write_text(job.model_dump_json(indent=2))
        os.replace(tmp, target)
    finally:
        # No-op after a successful replace; removes the leftover temp file on failure.
        tmp.unlink(missing_ok=True)
|
||
|
||
|
||
def update(job: Job, **kw) -> None:
    """Assign every keyword argument as an attribute of *job*, then persist the job."""
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    save_state(job)
|
||
|
||
|
||
def public_api_base() -> str:
    """Return the LLM gateway base URL (OpenAI's public API when unset), minus any trailing slash."""
    base = LLM_BASE_URL if LLM_BASE_URL else "https://api.openai.com/v1"
    return base.rstrip("/")
|
||
|
||
|
||
def video_uses_poe() -> bool:
    """Tell whether video generation targets Poe.

    With an explicit VIDEO_API_BASE_URL this is true only when it equals
    POE_API_BASE_URL (trailing slashes ignored); otherwise it is true whenever
    POE_API_KEY is set.
    """
    if not VIDEO_API_BASE_URL:
        return bool(POE_API_KEY)
    return VIDEO_API_BASE_URL.rstrip("/") == POE_API_BASE_URL.rstrip("/")
|
||
|
||
|
||
def video_uses_ark() -> bool:
    """Tell whether the resolved video API base is the Ark host or the ai.skg.com/doubao proxy."""
    ark_markers = ("ark.cn-beijing.volces.com", "ai.skg.com/doubao")
    base = video_api_base()
    return any(marker in base for marker in ark_markers)
|
||
|
||
|
||
def video_api_base() -> str:
    """Resolve the video API base URL without a trailing slash.

    Priority: VIDEO_API_BASE_URL, then POE_API_BASE_URL when POE_API_KEY is set,
    then the LLM gateway (falling back to OpenAI's public API).
    """
    if VIDEO_API_BASE_URL:
        base = VIDEO_API_BASE_URL
    elif POE_API_KEY:
        base = POE_API_BASE_URL
    else:
        base = LLM_BASE_URL or "https://api.openai.com/v1"
    return base.rstrip("/")
|
||
|
||
|
||
def video_api_key() -> str:
    """Pick the video API key: VIDEO_API_KEY, else POE_API_KEY when Poe is targeted, else LLM_API_KEY."""
    if not VIDEO_API_KEY:
        return POE_API_KEY if video_uses_poe() else LLM_API_KEY
    return VIDEO_API_KEY
|
||
|
||
|
||
def video_path(template: str, **values: str) -> str:
    """Fill a path *template* such as "/videos/{id}" with *values*, guaranteeing a leading "/"."""
    filled = template.format(**values)
    if filled.startswith("/"):
        return filled
    return "/" + filled
|
||
|
||
|
||
def ensure_video_api_configured() -> None:
    """Abort the request with HTTP 503 when no API key is available for video generation."""
    if video_api_key():
        return
    raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API")
|
||
|
||
|
||
def storyboard_ref_path(job_id: str, ref: dict | None) -> Path | None:
|
||
if not ref:
|
||
return None
|
||
try:
|
||
kind = ref.get("kind")
|
||
frame_idx = int(ref.get("frame_idx"))
|
||
except Exception:
|
||
return None
|
||
if kind == "keyframe":
|
||
clean = job_dir(job_id) / "cleaned" / f"{frame_idx:03d}.jpg"
|
||
if clean.exists():
|
||
return clean
|
||
p = job_dir(job_id) / "frames" / f"{frame_idx:03d}.jpg"
|
||
return p if p.exists() else None
|
||
if kind == "cutout":
|
||
element_id = (ref.get("element_id") or "").strip()
|
||
cutout_id = (ref.get("cutout_id") or "").strip()
|
||
if not element_id:
|
||
return None
|
||
candidates = []
|
||
if cutout_id and cutout_id != element_id:
|
||
candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}_{cutout_id}.jpg")
|
||
candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}.jpg")
|
||
candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}.png")
|
||
for p in candidates:
|
||
if p.exists():
|
||
return p
|
||
if kind == "asset":
|
||
asset_id = (ref.get("element_id") or ref.get("cutout_id") or "").strip()
|
||
if not asset_id:
|
||
return None
|
||
p = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
|
||
return p if p.exists() else None
|
||
return None
|
||
|
||
|
||
def load_product_library_items() -> list[ProductLibraryItem]:
    """Load the "items" list of the product library's manifest.json.

    Returns an empty list when the manifest file does not exist.

    Raises:
        HTTPException: 500 if the manifest cannot be read, parsed or validated.
    """
    manifest = PRODUCT_LIBRARY_MANIFEST
    if not manifest.exists():
        return []
    try:
        payload = json.loads(manifest.read_text(encoding="utf-8"))
        raw_items = payload.get("items", [])
        return [ProductLibraryItem(**entry) for entry in raw_items]
    except Exception as e:
        raise HTTPException(500, f"product library manifest invalid: {e}")
|
||
|
||
|
||
def find_product_library_item(product_id: str) -> ProductLibraryItem:
    """Return the product-library item whose id equals *product_id* (whitespace-trimmed).

    Raises:
        HTTPException: 404 when no item matches (or 500 from the manifest loader).
    """
    wanted = product_id.strip()
    match = next((item for item in load_product_library_items() if item.id == wanted), None)
    if match is None:
        raise HTTPException(404, "product library item not found")
    return match
|
||
|
||
|
||
def product_library_file(item: ProductLibraryItem) -> Path:
    """Return the image file for *item*, confined to PRODUCT_LIBRARY_DIR.

    Raises:
        HTTPException: 400 if the manifest filename resolves outside the library
            directory; 404 if the resolved file does not exist.
    """
    resolved = PRODUCT_LIBRARY_DIR.joinpath(item.filename).resolve()
    try:
        resolved.relative_to(PRODUCT_LIBRARY_DIR)
    except ValueError:
        raise HTTPException(400, "invalid product library path")
    if resolved.exists():
        return resolved
    raise HTTPException(404, "product library image missing")
|
||
|
||
|
||
def storyboard_ref_url(job_id: str, ref: dict | None) -> str:
|
||
if not ref:
|
||
return ""
|
||
kind = ref.get("kind")
|
||
frame_idx = ref.get("frame_idx")
|
||
if kind == "keyframe" and frame_idx is not None:
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}.jpg"
|
||
if kind == "cutout" and frame_idx is not None and ref.get("element_id"):
|
||
element_id = ref.get("element_id")
|
||
cutout_id = ref.get("cutout_id")
|
||
if cutout_id and cutout_id != element_id:
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutouts/{cutout_id}.jpg"
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutout.jpg"
|
||
if kind == "asset" and ref.get("element_id"):
|
||
return f"/jobs/{job_id}/assets/{ref.get('element_id')}.jpg"
|
||
return ""
|
||
|
||
|
||
def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None:
    """Letterbox *src* onto a *size* canvas and save it as a JPEG at *dst*.

    The picture is shrunk (Image.thumbnail never enlarges) to fit inside *size*,
    centred on a near-black (8, 8, 10) background, and saved at quality 94.
    Parent directories of *dst* are created first.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    box_w, box_h = size
    with Image.open(src) as raw:
        picture = raw.convert("RGB")
    picture.thumbnail(size, Image.Resampling.LANCZOS)
    board = Image.new("RGB", size, (8, 8, 10))
    board.paste(picture, ((box_w - picture.width) // 2, (box_h - picture.height) // 2))
    board.save(dst, "JPEG", quality=94)
|
||
|
||
|
||
def update_generated_video(job_id: str, video_id: str, **kw) -> None:
    """Apply field updates *kw* to video *video_id* of job *job_id*, then persist the job.

    The matching entry is rebuilt via GeneratedVideo(**data), so the new values go
    through model validation. Unknown job ids are ignored.
    """
    job = JOBS.get(job_id)
    if not job:
        return

    def _patched(video: GeneratedVideo) -> GeneratedVideo:
        if video.id != video_id:
            return video
        return GeneratedVideo(**{**video.model_dump(), **kw})

    update(job, generated_videos=[_patched(v) for v in job.generated_videos])
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: reload persisted jobs from JOBS_DIR before serving requests.

    Every sub-directory of JOBS_DIR containing a state.json is parsed into a Job and
    registered in JOBS under the directory name. Broken state files are skipped so
    a single corrupt job cannot block startup, but each skip is now logged instead
    of being silently swallowed.
    """
    for p in JOBS_DIR.iterdir():
        state_file = p / "state.json"
        if not (p.is_dir() and state_file.exists()):
            continue
        try:
            JOBS[p.name] = Job.model_validate_json(state_file.read_text())
        except Exception as e:
            print(f"[restore job skipped] {p.name}: {e}", flush=True)
    yield
|
||
|
||
|
||
app = FastAPI(lifespan=lifespan, title="SKG TK 二创 API")

# Cross-origin policy: only the configured front-end origins (CORS_ORIGINS),
# with credentials and any method/header allowed.
_cors_options = {
    "allow_origins": CORS_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
||
|
||
|
||
# ---------- Pipeline 实现 ----------
|
||
|
||
def run(cmd: list[str], cwd: Path | None = None) -> str:
|
||
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
|
||
if res.returncode != 0:
|
||
# ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾)
|
||
tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:]
|
||
raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}")
|
||
return res.stdout
|
||
|
||
|
||
# ---- 启发式选帧工具 ----
|
||
import imagehash
|
||
import numpy as np
|
||
from PIL import Image, ImageChops, ImageEnhance, ImageFilter, ImageOps
|
||
|
||
|
||
def _sharpness_from_gray(g: np.ndarray) -> float:
|
||
"""Laplacian variance:值越大越清晰,模糊/转场帧值低。"""
|
||
lap = (-4 * g[1:-1, 1:-1]
|
||
+ g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:])
|
||
return float(lap.var())
|
||
|
||
|
||
def _frame_metrics(img_path: Path, idx: int, timestamp: float, metric_width: int = 160) -> dict | None:
    """Compute cheap ranking features for one low-resolution candidate frame.

    The features only rank candidates; the final frames are still extracted at full
    size from the original video. Returns None if the image cannot be read.

    The result holds path/idx/timestamp, a perceptual hash, the downscaled grayscale
    array, whole-frame and centre sharpness, brightness, contrast and colourfulness.
    "scene_score" and "motion" start at 0.0 and are filled in later from
    neighbouring frames by _attach_temporal_metrics().
    """
    try:
        with Image.open(img_path) as source:
            rgb = source.convert("RGB")
            phash = imagehash.phash(rgb)
            width, height = rgb.size
            resized_h = max(1, round(metric_width * height / max(width, 1)))
            small = rgb.resize((metric_width, resized_h))
    except Exception:
        return None

    pixels = np.asarray(small, dtype=np.float32)
    red, green, blue = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
    # Rec. 601 luma kept in the 0-255 range so it lines up with the sharpness/contrast thresholds.
    gray = (0.299 * red + 0.587 * green + 0.114 * blue).astype(np.float32)

    # Central half of the frame in each dimension (at least one pixel).
    rows, cols = gray.shape
    top, left = rows // 4, cols // 4
    bottom = max(top + 1, rows * 3 // 4)
    right = max(left + 1, cols * 3 // 4)
    center = gray[top:bottom, left:right]

    # Colourfulness from the rg / yb opponent channels (Hasler–Süsstrunk form).
    rg = red - green
    yb = 0.5 * (red + green) - blue
    colorfulness = float(np.sqrt(rg.var() + yb.var()) + 0.3 * np.sqrt(rg.mean() ** 2 + yb.mean() ** 2))

    return {
        "path": img_path,
        "idx": idx,
        "timestamp": timestamp,
        "hash": phash,
        "gray": gray,
        "sharp": _sharpness_from_gray(gray),
        "center_sharp": _sharpness_from_gray(center),
        "brightness": float(gray.mean()),
        "contrast": float(gray.std()),
        "colorfulness": colorfulness,
        "scene_score": 0.0,
        "motion": 0.0,
    }
|
||
|
||
|
||
def _physical_memory_gb() -> float:
|
||
try:
|
||
page_size = os.sysconf("SC_PAGE_SIZE")
|
||
pages = os.sysconf("SC_PHYS_PAGES")
|
||
return float(page_size * pages) / (1024 ** 3)
|
||
except Exception:
|
||
return 0.0
|
||
|
||
|
||
def _resolve_frame_quality(duration: float, quality: FrameExtractQuality) -> FrameExtractQuality:
|
||
if quality != "auto":
|
||
return quality
|
||
cores = os.cpu_count() or 4
|
||
memory_gb = _physical_memory_gb()
|
||
strong_machine = cores >= 10 and (memory_gb == 0.0 or memory_gb >= 32)
|
||
if strong_machine and duration <= 180:
|
||
return "ultra"
|
||
if strong_machine and duration <= 600:
|
||
return "accurate"
|
||
if cores >= 8 and duration <= 240:
|
||
return "accurate"
|
||
return "fast"
|
||
|
||
|
||
def _scan_profile(duration: float, quality: FrameExtractQuality) -> tuple[float, int, int, int]:
|
||
"""返回 scan_fps / scan_width / metric_width / estimated_count。"""
|
||
if quality == "ultra":
|
||
base_fps, scan_width, cap, metric_width = 12.0, 960, 1800, 320
|
||
elif quality == "accurate":
|
||
base_fps, scan_width, cap, metric_width = 8.0, 720, 900, 240
|
||
else:
|
||
base_fps, scan_width, cap, metric_width = 2.0, 360, 240, 160
|
||
|
||
estimated = max(1, min(int(duration * base_fps), cap))
|
||
scan_fps = max(0.02, min(base_fps, estimated / max(duration, 0.1)))
|
||
return scan_fps, scan_width, metric_width, estimated
|
||
|
||
|
||
def _image_quality_report(img_path: Path, region: dict | None = None) -> QualityReport:
    """Check whether an image is large and sharp enough to feed video generation.

    Args:
        img_path: image file to inspect.
        region: optional subject box whose "w"/"h" are multiplied by the image
            size (presumably 0-1 fractions — confirm); a box with a side under
            512 px is flagged.

    Returns:
        A QualityReport with the size, the Laplacian sharpness of a copy at most
        512 px wide, human-readable warnings and a risk level:
          - "bad" for a short side under 480 px or sharpness under 12;
          - otherwise "warn" whenever any warning was raised;
          - otherwise "ok".
        An unreadable image yields risk="bad".
    """
    warnings: list[str] = []
    try:
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        width, height = img.size
        metric_width = min(512, width)
        metric_height = max(1, round(metric_width * height / max(width, 1)))
        small = img.resize((metric_width, metric_height))
        gray = np.asarray(ImageOps.grayscale(small), dtype=np.float32)
        sharp = _sharpness_from_gray(gray)
    except Exception:
        return QualityReport(risk="bad", warnings=["无法读取图片质量信息"])

    short_side = min(width, height)
    if short_side < 720:
        warnings.append(f"短边 {short_side}px 低于 720px,生视频可能偏糊")
    if sharp < 30:
        warnings.append("清晰度偏低,高清增强后仍可能有细节损失")

    if region:
        try:
            rw = int(float(region.get("w", 0)) * width)
            rh = int(float(region.get("h", 0)) * height)
            if min(rw, rh) < 512:
                warnings.append(f"主体框约 {rw}×{rh}px,主体素材偏小")
        except Exception:
            # Best-effort check: a malformed region simply skips this warning.
            pass

    # Derive the risk from whether any warning was raised, not from message wording:
    # matching the substrings "低于"/"偏小" missed the sharpness warning ("偏低"),
    # so blurry-but-not-terrible images were reported as "ok" despite a warning.
    risk: Literal["ok", "warn", "bad"] = "warn" if warnings else "ok"
    if short_side < 480 or sharp < 12:
        risk = "bad"
    return QualityReport(width=width, height=height, short_side=short_side, sharpness=round(sharp, 2), risk=risk, warnings=warnings)
|
||
|
||
|
||
def _asset_target_size(source_path: Path, size: AssetSize, square: bool = False) -> tuple[int, int]:
|
||
try:
|
||
with Image.open(source_path) as raw:
|
||
src_w, src_h = raw.size
|
||
except Exception:
|
||
src_w, src_h = 1024, 1024
|
||
if size == "source":
|
||
return max(1, src_w), max(1, src_h)
|
||
side = int(size)
|
||
if square:
|
||
return side, side
|
||
if src_w >= src_h:
|
||
return side, max(1, round(side * src_h / max(src_w, 1)))
|
||
return max(1, round(side * src_w / max(src_h, 1))), side
|
||
|
||
|
||
def _normalize_asset_image(
    img_bytes: bytes,
    out_path: Path,
    source_path: Path,
    size: AssetSize,
    background: AssetBackground = "white",
    square: bool = False,
    fill_subject: bool = False,
) -> tuple[int, int]:
    """Fit encoded image bytes onto a solid-colour canvas and save it as a JPEG.

    Args:
        img_bytes: encoded image data to normalise.
        out_path: destination JPEG path; its parent directories are created.
        source_path: image whose dimensions size the canvas via _asset_target_size().
        size: "source" or a numeric edge length (see _asset_target_size()).
        background: canvas colour, white (255, 255, 255) or black (0, 0, 0).
        square: force a square canvas for numeric sizes.
        fill_subject: crop to the content that differs from the background colour
            (plus 6% padding per side) and fit it within 92% x 94% of the canvas.

    Returns:
        The (width, height) of the saved canvas.
    """
    import io as _io
    target_w, target_h = _asset_target_size(source_path, size, square=square)
    bg = (255, 255, 255) if background == "white" else (0, 0, 0)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with Image.open(_io.BytesIO(img_bytes)) as raw:
        img = raw.convert("RGB")
        if fill_subject:
            # Pixels whose difference from the background (converted to grayscale)
            # exceeds 18 count as subject; their bounding box becomes the crop.
            diff = ImageChops.difference(img, Image.new("RGB", img.size, bg))
            mask = diff.convert("L").point(lambda px: 255 if px > 18 else 0)
            bbox = mask.getbbox()
            if bbox:
                left, top, right, bottom = bbox
                pad_x = round((right - left) * 0.06)
                pad_y = round((bottom - top) * 0.06)
                img = img.crop((
                    max(0, left - pad_x),
                    max(0, top - pad_y),
                    min(img.width, right + pad_x),
                    min(img.height, bottom + pad_y),
                ))
            max_w = max(1, round(target_w * 0.92))
            max_h = max(1, round(target_h * 0.94))
            # NOTE(review): thumbnail() only ever shrinks, so a subject smaller than this
            # box is centred at its native size rather than enlarged — confirm intended.
            img.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)
        else:
            img.thumbnail((target_w, target_h), Image.Resampling.LANCZOS)
        # Centre the (possibly cropped and shrunk) image on a background-coloured canvas.
        canvas = Image.new("RGB", (target_w, target_h), bg)
        canvas.paste(img, ((target_w - img.width) // 2, (target_h - img.height) // 2))
        canvas.save(out_path, "JPEG", quality=95)
    return target_w, target_h
|
||
|
||
|
||
def _asset_url(job_id: str, asset_id: str) -> str:
|
||
return f"/jobs/{job_id}/assets/{asset_id}.jpg"
|
||
|
||
|
||
def _find_frame(job: Job, idx: int) -> KeyFrame:
|
||
frame = next((f for f in job.frames if f.index == idx), None)
|
||
if not frame:
|
||
raise HTTPException(404, "frame not found")
|
||
return frame
|
||
|
||
|
||
def _source_frame_path(job_id: str, idx: int) -> Path:
    """Return cleaned/NNN.jpg for frame *idx* if it exists, else frames/NNN.jpg.

    The fallback path is returned without checking that it exists.
    """
    base = job_dir(job_id)
    filename = f"{idx:03d}.jpg"
    cleaned = base / "cleaned" / filename
    return cleaned if cleaned.exists() else base / "frames" / filename
|
||
|
||
|
||
def _focus_source_for_element(job_id: str, idx: int, el: KeyElement) -> tuple[Path, Path | None]:
    """Choose the image to send to the model when extracting element *el* from frame *idx*.

    Without a region, the full source frame (cleaned version if available) is used.
    With a region, the box is clamped to the frame and enlarged 1.6x around its centre
    (then clamped again). If the crop is larger than 8 px each way, it is written to
    a temporary JPEG that is used instead.

    Returns:
        (model_src, tmp_focus): the image path to use, and the temporary crop file
        (None when no crop was written) — presumably deleted by the caller afterwards.
        Any crop failure is logged and falls back to the full frame; a partially
        written temp file is removed in that case.
    """
    import tempfile as _tempfile

    src = _source_frame_path(job_id, idx)
    if not el.region:
        return src, None

    tmp_path: Path | None = None
    try:
        # Close the source file deterministically.
        with Image.open(src) as raw:
            im = raw.convert("RGB")
        W, H = im.size
        r = el.region
        x = max(0.0, min(1.0, float(r.get("x", 0))))
        y = max(0.0, min(1.0, float(r.get("y", 0))))
        w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
        h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
        cx, cy = x + w / 2, y + h / 2
        ew, eh = w * 1.6, h * 1.6
        x0 = max(0.0, cx - ew / 2)
        y0 = max(0.0, cy - eh / 2)
        x1 = min(1.0, cx + ew / 2)
        y1 = min(1.0, cy + eh / 2)
        left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
        if right - left > 8 and bottom - top > 8:
            cropped = im.crop((left, top, right, bottom))
            # mkstemp + close: Pillow then writes to a path nobody holds open. An open
            # NamedTemporaryFile cannot be reopened by name on Windows, and was never
            # closed when the save failed.
            fd, name = _tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            tmp_path = Path(name)
            cropped.save(tmp_path, format="JPEG", quality=92)
            return tmp_path, tmp_path
    except Exception as e:
        print(f"[focus source crop failed, fallback to full frame] {e}", flush=True)
        # Do not leak a half-written temp file when the crop could not be saved.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
    return src, None
|
||
|
||
|
||
def _make_reference_contact_sheet(job_id: str, frame_indices: list[int], out_path: Path) -> Path | None:
    """Tile up to six distinct source frames into one reference contact sheet.

    Repeated indices and missing frame files are skipped. Each frame is fitted into
    a 420x420 light-grey (245) cell. Cells are laid out two per row when there are
    two frames, three per row otherwise. The sheet is saved as JPEG (quality 92).

    Returns:
        *out_path*, or None when fewer than two frames could be found and opened.
    """
    cell = 420
    background = (245, 245, 245)

    # Collect up to six existing frame files, skipping repeated indices.
    sources: list[Path] = []
    visited: set[int] = set()
    for frame_idx in frame_indices:
        if frame_idx in visited:
            continue
        visited.add(frame_idx)
        candidate = _source_frame_path(job_id, frame_idx)
        if not candidate.exists():
            continue
        sources.append(candidate)
        if len(sources) >= 6:
            break
    if len(sources) < 2:
        return None

    tiles: list[Image.Image] = []
    for source in sources:
        try:
            with Image.open(source) as raw:
                picture = raw.convert("RGB")
            picture.thumbnail((cell, cell), Image.Resampling.LANCZOS)
            tile = Image.new("RGB", (cell, cell), background)
            tile.paste(picture, ((cell - picture.width) // 2, (cell - picture.height) // 2))
        except Exception:
            continue
        tiles.append(tile)
    if len(tiles) < 2:
        return None

    columns = 2 if len(tiles) <= 2 else 3
    row_count = -(-len(tiles) // columns)  # ceiling division
    sheet = Image.new("RGB", (columns * cell, row_count * cell), background)
    for position, tile in enumerate(tiles):
        row, col = divmod(position, columns)
        sheet.paste(tile, (col * cell, row * cell))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sheet.save(out_path, "JPEG", quality=92)
    return out_path
|
||
|
||
|
||
# Chinese UI labels for subject-asset view / expression / action keys.
# _subject_view_labels() falls back to the key itself (underscores -> spaces) for keys not listed here.
SUBJECT_VIEW_LABELS: dict[str, str] = {
    "front": "正面",
    "back": "背面",
    "left": "左侧",
    "right": "右侧",
    "three_quarter_left": "左前 45°",
    "three_quarter_right": "右前 45°",
    "side": "侧面",
    "side_walk": "侧面走路",
    "top": "顶部视角",
    "bottom": "底部视角",
    "expression_neutral": "中性表情",
    "expression_smile": "微笑表情",
    "expression_happy": "开心表情",
    "expression_angry": "生气表情",
    "expression_sad": "难过表情",
    "expression_relaxed": "放松表情",
    "expression_serious": "严肃表情",
    "expression_surprised": "惊讶表情",
    "action_walk": "走路动作",
    "action_turn": "转身动作",
    "action_sit": "坐姿动作",
    "action_hold": "手持动作",
    "action_use": "使用动作",
}
|
||
|
||
|
||
def _subject_view_labels(kind: SubjectKind, requested: list[str] | None = None) -> list[tuple[SubjectView, str]]:
|
||
if requested:
|
||
normalized: list[str] = []
|
||
for raw in requested:
|
||
key = "".join(ch for ch in str(raw).strip().lower() if ch.isalnum() or ch == "_")
|
||
if key and key not in normalized:
|
||
normalized.append(key)
|
||
return [(key, SUBJECT_VIEW_LABELS.get(key, key.replace("_", " "))) for key in normalized[:12]]
|
||
if kind == "living":
|
||
return [
|
||
("front", "正面站立"),
|
||
("back", "背面站立"),
|
||
("left", "左侧站立"),
|
||
("right", "右侧站立"),
|
||
("three_quarter_left", "左前 45° 站立"),
|
||
("three_quarter_right", "右前 45° 站立"),
|
||
]
|
||
return [
|
||
("front", "正面"),
|
||
("back", "背面"),
|
||
("left", "左侧"),
|
||
("right", "右侧"),
|
||
("top", "顶部"),
|
||
("bottom", "底部"),
|
||
]
|
||
|
||
|
||
def _attach_temporal_metrics(items: list[dict]) -> None:
|
||
"""相邻低清帧差异:转场 / 动作目标依赖它,不需要逐帧高分辨率扫描。"""
|
||
for i, it in enumerate(items):
|
||
prev_delta = 0.0
|
||
next_delta = 0.0
|
||
if i > 0:
|
||
prev_delta = float(np.mean(np.abs(it["gray"] - items[i - 1]["gray"])) / 255.0)
|
||
if i + 1 < len(items):
|
||
next_delta = float(np.mean(np.abs(items[i + 1]["gray"] - it["gray"])) / 255.0)
|
||
it["scene_score"] = max(prev_delta, next_delta)
|
||
it["motion"] = (prev_delta + next_delta) / 2.0
|
||
|
||
|
||
def _normalize_item_metrics(items: list[dict]) -> None:
|
||
for key in ("sharp", "center_sharp", "contrast", "colorfulness", "scene_score", "motion"):
|
||
vals = [float(it.get(key, 0.0)) for it in items if float(it.get(key, 0.0)) > 0]
|
||
cap = float(np.percentile(vals, 95)) if vals else 1.0
|
||
if cap <= 0:
|
||
cap = 1.0
|
||
for it in items:
|
||
it[f"{key}_n"] = min(float(it.get(key, 0.0)) / cap, 1.0)
|
||
|
||
|
||
def _target_score(item: dict, target: FrameExtractTarget) -> float:
|
||
sharp = float(item.get("sharp_n", 0.0))
|
||
center = float(item.get("center_sharp_n", 0.0))
|
||
contrast = float(item.get("contrast_n", 0.0))
|
||
color = float(item.get("colorfulness_n", 0.0))
|
||
scene = float(item.get("scene_score_n", 0.0))
|
||
motion = float(item.get("motion_n", 0.0))
|
||
|
||
if target == "transparent_human":
|
||
# 透明骨架人仍先依赖本地清晰度 / 中心主体 / 对比度筛候选,
|
||
# 后续再交给 Vision 逐张语义验收。
|
||
score = center * 0.45 + sharp * 0.30 + contrast * 0.15 + color * 0.10
|
||
elif target == "subject":
|
||
score = center * 0.48 + sharp * 0.25 + contrast * 0.17 + color * 0.10
|
||
elif target == "transition":
|
||
score = scene * 0.55 + sharp * 0.28 + contrast * 0.12 + color * 0.05
|
||
elif target == "expression":
|
||
# 没有额外视觉模型时,表情/动物瞬间只能用中心细节 + 清晰 + 轻微动作变化做本地近似。
|
||
score = center * 0.40 + sharp * 0.24 + motion * 0.18 + contrast * 0.12 + color * 0.06
|
||
elif target == "motion":
|
||
score = motion * 0.45 + sharp * 0.30 + center * 0.15 + contrast * 0.10
|
||
else:
|
||
score = sharp * 0.45 + scene * 0.22 + center * 0.15 + contrast * 0.12 + color * 0.06
|
||
|
||
brightness = float(item.get("brightness", 0.0))
|
||
raw_contrast = float(item.get("contrast", 0.0))
|
||
if raw_contrast < 4 or brightness < 8 or brightness > 247:
|
||
return score * 0.15
|
||
if raw_contrast < 9:
|
||
return score * 0.65
|
||
return score
|
||
|
||
|
||
def _select_keyframes(candidates: list[dict], n: int, target: FrameExtractTarget, dup_threshold: int = 8) -> list[dict]:
|
||
"""
|
||
candidates: 按时间排序的低清候选帧评分项
|
||
n: 目标帧数
|
||
dup_threshold: pHash 汉明距离 < 此值视为相似(默认 8,64bit hash 大致 ~12.5% 像素差)
|
||
"""
|
||
if len(candidates) <= n:
|
||
return candidates
|
||
|
||
_attach_temporal_metrics(candidates)
|
||
_normalize_item_metrics(candidates)
|
||
for it in candidates:
|
||
it["score"] = _target_score(it, target)
|
||
|
||
# 去重:相似帧保留当前目标下分数更高的
|
||
deduped: list[dict] = []
|
||
for it in candidates:
|
||
dup = None
|
||
for kept in deduped:
|
||
if (it["hash"] - kept["hash"]) < dup_threshold:
|
||
dup = kept
|
||
break
|
||
if dup is None:
|
||
deduped.append(it)
|
||
elif it["score"] > dup["score"]:
|
||
deduped[deduped.index(dup)] = it
|
||
|
||
# 时序分桶:把候选时间轴等分 n 段,每段取当前目标下最优的
|
||
total = len(candidates)
|
||
buckets: list[list[dict]] = [[] for _ in range(n)]
|
||
for it in deduped:
|
||
b = min(int(it["idx"] * n / total), n - 1)
|
||
buckets[b].append(it)
|
||
|
||
selected: list[dict] = []
|
||
for b in buckets:
|
||
if b:
|
||
selected.append(max(b, key=lambda x: x["score"]))
|
||
|
||
# 空桶补足:从未选的 deduped 里按目标分数补
|
||
chosen_paths = {it["path"] for it in selected}
|
||
remaining = sorted([it for it in deduped if it["path"] not in chosen_paths],
|
||
key=lambda x: -x["score"])
|
||
while len(selected) < n and remaining:
|
||
selected.append(remaining.pop(0))
|
||
|
||
# 按时间排序输出
|
||
selected.sort(key=lambda x: x["idx"])
|
||
return selected
|
||
|
||
|
||
def _rank_keyframe_candidates(candidates: list[dict], target: FrameExtractTarget, limit: int, dup_threshold: int = 8) -> list[dict]:
|
||
if not candidates:
|
||
return []
|
||
_attach_temporal_metrics(candidates)
|
||
_normalize_item_metrics(candidates)
|
||
for it in candidates:
|
||
it["score"] = _target_score(it, target)
|
||
deduped: list[dict] = []
|
||
for it in sorted(candidates, key=lambda x: -float(x.get("score", 0.0))):
|
||
if any((it["hash"] - kept["hash"]) < dup_threshold for kept in deduped):
|
||
continue
|
||
deduped.append(it)
|
||
if len(deduped) >= limit:
|
||
break
|
||
return deduped
|
||
|
||
|
||
def _score_transparent_human_frame(img_path: Path) -> TransparentHumanFrameScore:
    """Ask the vision model to grade one frame against the transparent-human rubric.

    The JPEG at *img_path* is sent inline (base64 data URI) to VISION_MODEL in JSON
    mode. The six dimension scores are clamped to their caps and summed into
    ``total_score`` (max 100). ``qualified`` requires both the model's own verdict
    and every per-dimension floor plus total >= 72.

    Model, network or parse failures, and replies that are not a JSON object, come
    back as an unqualified score with ``reject_reason`` set rather than being
    raised. Reading *img_path* itself can still raise OSError.
    """
    if not LLM_API_KEY:
        return TransparentHumanFrameScore(
            qualified=False,
            reject_reason="LLM_API_KEY 未配置,无法进行透明骨架人语义验收",
        )
    img_b64 = base64.b64encode(img_path.read_bytes()).decode("ascii")
    prompt = (
        "You are a strict keyframe quality inspector for a SKG transparent-human video recreation workflow. "
        + TRANSPARENT_HUMAN_POSITIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_QUALIFIED_STANDARD + "\n\n"
        "Score this single frame using exactly these dimensions:\n"
        "- transparent_body_score: 0-25, clear transparent/translucent outer human body shell.\n"
        "- skeleton_visible_score: 0-25, clean white skeleton clearly visible inside the body.\n"
        "- human_prominence_score: 0-15, character centered/large/easy to identify, ideally >=35% frame height.\n"
        "- clarity_score: 0-15, no severe motion blur, occlusion, or deformation.\n"
        "- commercial_style_score: 0-10, clean premium non-horror advertising/wellness style.\n"
        "- product_usefulness_score: 0-10, useful for later SKG product video generation; neck/shoulder/waist/eye/foot/knee area visible when relevant.\n"
        "Reject if any of these is true: normal human only; ordinary skeleton only; product/background only; transparent person too far; severe blur; more than half occluded; horror/corpse/autopsy/surgery/hospital; unable to judge.\n"
        "Output strict JSON only with keys: transparent_body_score, skeleton_visible_score, human_prominence_score, clarity_score, commercial_style_score, product_usefulness_score, qualified, reject_reason, notes."
    )
    try:
        resp = llm().chat.completions.create(
            model=VISION_MODEL,
            messages=[{"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
            ]}],
            response_format={"type": "json_object"},
            temperature=0.1,
            max_tokens=1200,
        )
        raw = (resp.choices[0].message.content or "").strip()
        if raw.startswith("```"):
            # Some gateways wrap JSON-mode output in a markdown fence; keep the object only.
            import re as _re
            match = _re.search(r"\{[\s\S]*\}", raw)
            raw = match.group(0) if match else raw
        data = json.loads(raw)
    except Exception as e:
        return TransparentHumanFrameScore(qualified=False, reject_reason=f"AI 评分失败:{e}")

    # A bare list/string/number would otherwise crash on data.get() below and abort
    # the whole analysis instead of just rejecting this frame.
    if not isinstance(data, dict):
        return TransparentHumanFrameScore(
            qualified=False,
            reject_reason=f"AI 评分失败:模型未返回 JSON 对象({type(data).__name__})",
        )

    def score(name: str, cap: int) -> int:
        # Round the model's value to an int in [0, cap]; anything unparsable counts as 0.
        try:
            value = int(round(float(data.get(name, 0))))
        except Exception:
            value = 0
        return max(0, min(cap, value))

    def flag(name: str) -> bool:
        # bool("false") is True, so string booleans must be parsed explicitly.
        value = data.get(name)
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    item = TransparentHumanFrameScore(
        transparent_body_score=score("transparent_body_score", 25),
        skeleton_visible_score=score("skeleton_visible_score", 25),
        human_prominence_score=score("human_prominence_score", 15),
        clarity_score=score("clarity_score", 15),
        commercial_style_score=score("commercial_style_score", 10),
        product_usefulness_score=score("product_usefulness_score", 10),
        reject_reason=str(data.get("reject_reason", "") or ""),
        notes=str(data.get("notes", "") or ""),
    )
    item.total_score = (
        item.transparent_body_score
        + item.skeleton_visible_score
        + item.human_prominence_score
        + item.clarity_score
        + item.commercial_style_score
        + item.product_usefulness_score
    )
    # The model's verdict alone is not trusted: every dimension must also clear its floor.
    item.qualified = flag("qualified") and (
        item.transparent_body_score >= 18
        and item.skeleton_visible_score >= 18
        and item.human_prominence_score >= 8
        and item.clarity_score >= 8
        and item.commercial_style_score >= 6
        and item.product_usefulness_score >= 4
        and item.total_score >= 72
    )
    if not item.qualified and not item.reject_reason:
        item.reject_reason = f"透明骨架人评分不足,总分 {item.total_score}/100"
    return item
|
||
|
||
|
||
def ffprobe_meta(mp4: Path) -> dict:
    """Return ffprobe's parsed JSON report (all streams plus container format) for *mp4*."""
    command = [
        "ffprobe", "-v", "error",
        "-print_format", "json",
        "-show_streams", "-show_format",
        str(mp4),
    ]
    report = run(command)
    return json.loads(report)
|
||
|
||
|
||
async def pipeline_download(job_id: str) -> None:
    """Stage 1: download the source video (or reuse an uploaded one) and probe it.

    Produces ``source.mp4`` in the job directory and leaves the job in the
    ``downloaded`` state, waiting for the user to start analysis. The yt-dlp
    download and the ffprobe call are blocking subprocesses, so they run in worker
    threads via ``asyncio.to_thread`` and the event loop keeps serving requests
    meanwhile (the same pattern ``pipeline_transcribe`` uses). Errors are recorded
    on the job with status "failed" instead of being raised.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if mp4.exists():
            update(job, status="downloading", message="本地上传 · 跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            await asyncio.to_thread(run, [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                job.url,
            ])
            if not mp4.exists():
                raise RuntimeError("下载完成但找不到 source.mp4")

        meta = await asyncio.to_thread(ffprobe_meta, mp4)
        v_stream = next((s for s in meta.get("streams", []) if s.get("codec_type") == "video"), None)
        duration = float(meta["format"]["duration"])
        update(
            job,
            status="downloaded",
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            width=int(v_stream["width"]) if v_stream else 0,
            height=int(v_stream["height"]) if v_stream else 0,
            progress=25,
            message=f"视频就绪 · {duration:.1f}s · 等待解析",
        )
    except Exception as e:
        update(job, status="failed", error=str(e), message="下载失败")
|
||
|
||
|
||
async def pipeline_analyze(
    job_id: str,
    frame_count: int = KEYFRAME_COUNT,
    target: FrameExtractTarget = "transparent_human",
    mode: FrameExtractMode = "replace",
    quality: FrameExtractQuality = "auto",
) -> None:
    """Stage 2: split out the audio track and extract keyframes for *job_id*.

    ASR / translation form a separate copy track and do not block the visual asset
    flow. Blocking work (ffmpeg runs, per-frame metric scoring, candidate ranking and
    the vision-model checks) is pushed to worker threads with ``asyncio.to_thread``,
    so the event loop keeps serving requests while a job is analysed. This matches
    how ``pipeline_transcribe`` offloads its blocking calls.

    frame_count: number of keyframes wanted, clamped to 1..20.
    target: extraction target that weights the local scores. "transparent_human"
        additionally runs each candidate through vision-model acceptance and
        replaces rejected frames with the next candidate.
    mode: "replace" wipes existing frames; any other value appends to job.frames,
        skipping timestamps within 0.35s of an existing frame.
    quality: scan-quality preset; "auto" is resolved from the video duration.

    Errors are recorded on the job with status "failed" instead of being raised.
    The temporary scan directory is removed on success and on failure.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    scan_dir = d / "frame_scan"
    try:
        mp4 = d / "source.mp4"
        if not mp4.exists():
            raise RuntimeError("source.mp4 不存在,先完成下载")

        wav = d / "audio.wav"
        if wav.exists():
            update(job, status="splitting", message="复用音轨 · 准备抽帧…", progress=35)
        else:
            update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35)
            await asyncio.to_thread(run, [
                "ffmpeg", "-y", "-i", str(mp4),
                "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
                str(wav),
            ])

        n = max(1, min(int(frame_count), 20))
        target_label = FRAME_TARGET_LABELS.get(target, FRAME_TARGET_LABELS["balanced"])
        duration = max(float(job.duration or 1.0), 0.1)
        effective_quality = _resolve_frame_quality(duration, quality)
        effective_quality_label = FRAME_QUALITY_LABELS.get(effective_quality, FRAME_QUALITY_LABELS["accurate"])
        quality_label = f"自动·{effective_quality_label}" if quality == "auto" else effective_quality_label
        scan_fps, scan_width, metric_width, estimated_scan_count = _scan_profile(duration, effective_quality)

        update(job, message=f"本地{quality_label}扫描 · {target_label} · 约 {estimated_scan_count} 帧…", progress=45)
        frames_dir = d / "frames"
        replacing = mode == "replace"
        existing_frames = list(job.frames) if not replacing else []
        if replacing and frames_dir.exists():
            shutil.rmtree(frames_dir)
        frames_dir.mkdir(parents=True, exist_ok=True)
        if scan_dir.exists():
            shutil.rmtree(scan_dir)
        scan_dir.mkdir(parents=True)

        # 1) Low-resolution, low-fps scan. Scan images are only used to score
        #    candidates and never become final keyframes.
        await asyncio.to_thread(run, [
            "ffmpeg", "-y", "-i", str(mp4),
            "-vf", f"fps={scan_fps:.4f},scale={scan_width}:-2",
            "-q:v", "4",
            str(scan_dir / "s_%05d.jpg"),
        ])

        scan_paths = sorted(scan_dir.glob("s_*.jpg"))
        if not scan_paths:
            raise RuntimeError("低清扫描没有生成候选帧")

        def score_scan_frames() -> list[dict]:
            # CPU-bound per-frame metrics; runs in a worker thread.
            scored: list[dict] = []
            for i, p in enumerate(scan_paths):
                # Approximate timestamp of scan frame i, kept just inside the video end.
                t = min(i / scan_fps, max(duration - 0.05, 0.0))
                metrics = _frame_metrics(p, i, t, metric_width)
                if metrics:
                    scored.append(metrics)
            return scored

        candidates = await asyncio.to_thread(score_scan_frames)
        if not candidates:
            raise RuntimeError("候选帧评分失败")

        # 2) Targeted selection: pHash dedup + sharpness / central detail /
        #    transition change / motion strength. For transparent humans the
        #    candidate pool is widened first and each frame is then accepted or
        #    rejected by the vision model; rejected frames are replaced automatically.
        semantic_transparent = target == "transparent_human"
        if semantic_transparent:
            selection_count = min(len(candidates), min(max(n * 10, 24), 48))
            update(job, message=f"{quality_label}筛选透明骨架人候选 · 本地 {selection_count} / {len(candidates)} 张…", progress=58)
            chosen = await asyncio.to_thread(_rank_keyframe_candidates, candidates, target, selection_count)
        else:
            selection_count = n if replacing else min(len(candidates), max(n * 4, n + len(existing_frames) + 2))
            update(job, message=f"{quality_label}筛选 · {target_label} · {n} / {len(candidates)} 张…", progress=60)
            chosen = await asyncio.to_thread(_select_keyframes, candidates, selection_count, target)

        # 3) Grab high-quality frames from the original video only at the chosen timestamps.
        renamed: list[KeyFrame] = []
        # Transparent-human candidates stay in score order (best first); others go in time order.
        chosen_sorted = chosen if semantic_transparent else sorted(chosen, key=lambda it: float(it["timestamp"]))
        existing_timestamps = [float(f.timestamp) for f in existing_frames]
        next_idx = max((int(f.index) for f in existing_frames), default=-1) + 1
        rejected_by_ai = 0
        for attempt, item in enumerate(chosen_sorted, start=1):
            if len(renamed) >= n:
                break
            t = float(item["timestamp"])
            if not replacing and any(abs(t - old) < 0.35 for old in existing_timestamps):
                continue
            idx = next_idx + len(renamed)
            dst = frames_dir / f"{idx:03d}.jpg"
            await asyncio.to_thread(run, [
                "ffmpeg", "-y", "-ss", f"{t:.3f}", "-i", str(mp4),
                "-frames:v", "1",
                "-pix_fmt", "yuvj420p", "-q:v", "3",
                str(dst),
            ])
            transparent_score: TransparentHumanFrameScore | None = None
            if semantic_transparent:
                update(
                    job,
                    message=f"AI 验收透明骨架人 · 已通过 {len(renamed)}/{n} · 候选 {attempt}/{len(chosen_sorted)}…",
                    progress=min(68, 60 + int(attempt / max(1, len(chosen_sorted)) * 8)),
                )
                transparent_score = await asyncio.to_thread(_score_transparent_human_frame, dst)
                if not transparent_score.qualified:
                    rejected_by_ai += 1
                    try:
                        dst.unlink()
                    except OSError:
                        pass
                    reason = transparent_score.reject_reason or f"总分 {transparent_score.total_score}/100"
                    update(job, message=f"AI 退回候选帧 · {reason[:48]} · 自动换下一帧", progress=65)
                    # The slot index is reused by the next candidate.
                    continue
            renamed.append(KeyFrame(
                index=idx,
                timestamp=round(t, 2),
                url=f"/jobs/{job_id}/frames/{idx}.jpg",
                transparent_human_score=transparent_score,
            ))
            existing_timestamps.append(t)

        if semantic_transparent and not renamed:
            raise RuntimeError("AI 未找到合格透明骨架人帧:需要透明/半透明人体外壳 + 清楚白色骨架 + 非恐怖广告感")

        merged_frames = sorted(existing_frames + renamed, key=lambda f: f.timestamp)
        action_label = "追加" if not replacing else "抽取"
        if semantic_transparent:
            final_message = (
                f"已按「{quality_label} · {target_label}」AI验收 {action_label} {len(renamed)} 张"
                f" · 退回 {rejected_by_ai} 张"
                f" · 共 {len(merged_frames)} 张"
            )
        else:
            final_message = f"已按「{quality_label} · {target_label}」{action_label} {len(renamed)} 张关键帧 · 共 {len(merged_frames)} 张"
        update(
            job,
            status="frames_extracted",
            frames=merged_frames,
            progress=70,
            message=final_message,
        )
    except Exception as e:
        update(job, status="failed", error=str(e), message="解析失败")
    finally:
        # 4) Scan images are scratch data for scoring; remove them on failure too.
        shutil.rmtree(scan_dir, ignore_errors=True)
|
||
|
||
|
||
async def analyze_queue_worker() -> None:
    """Drain ANALYZE_QUEUE one job at a time.

    Each queue entry is (job_id, frame_count, target, mode, quality). Entries whose
    job is no longer in JOBS are skipped. After each processed job, every job still
    waiting gets a "queued" status showing how many jobs are ahead of it.
    ANALYZE_WORKER_RUNNING is True while the loop runs and is reset on exit.

    pipeline_analyze records its own failures on the job. Anything that still
    escapes it is logged and the worker moves on, so one broken job cannot strand
    the rest of the queue.
    """
    global ANALYZE_WORKER_RUNNING
    ANALYZE_WORKER_RUNNING = True
    try:
        while ANALYZE_QUEUE:
            job_id, frames, target, mode, quality = ANALYZE_QUEUE.pop(0)
            if job_id not in JOBS:
                continue
            try:
                await pipeline_analyze(job_id, frames, target, mode, quality)
            except Exception as e:
                print(f"[analyze queue job crashed, continuing] {job_id}: {e}", flush=True)
            for pos, (queued_job_id, *_rest) in enumerate(ANALYZE_QUEUE, start=1):
                queued_job = JOBS.get(queued_job_id)
                if queued_job:
                    update(queued_job, status="splitting", progress=30, message=f"排队等待抽帧 · 前方 {pos - 1} 个任务")
    finally:
        ANALYZE_WORKER_RUNNING = False
|
||
|
||
|
||
# ---------- 音频转写 + 翻译 + SKG 改写 + MiniMax 配音 ----------
|
||
|
||
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* with ASR_MODEL and return its segments as start/end/text dicts.

    Requests ``verbose_json`` with segment-level timestamps. If the gateway returns
    no segments but does return text, the whole transcript becomes one segment
    spanning 0 to the reported duration (0 when absent).
    """
    with wav.open("rb") as audio_file:
        response = llm().audio.transcriptions.create(
            file=(wav.name, audio_file, "audio/wav"),
            model=ASR_MODEL,
            response_format="verbose_json",
            timestamp_granularities=["segment"],
        )
    payload = response.model_dump() if hasattr(response, "model_dump") else response
    segments = payload.get("segments") or []
    if segments or not payload.get("text"):
        return segments
    # Fallback: gateway returned only the full text.
    whole_length = float(payload.get("duration", 0) or 0)
    return [{"start": 0.0, "end": whole_length, "text": payload["text"]}]
|
||
|
||
|
||
def _translate_sync(segments: list[dict]) -> list[str]:
|
||
"""批量翻译为中文,按段返回"""
|
||
payload = [{"i": i, "en": s.get("text", "").strip()} for i, s in enumerate(segments)]
|
||
prompt = (
|
||
"你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
|
||
"严格返回 JSON 数组,不要任何 markdown 或多余文字,schema: "
|
||
'[{"i": 0, "zh": "..."}, ...]\n\n输入:\n'
|
||
+ json.dumps(payload, ensure_ascii=False)
|
||
)
|
||
resp = llm().chat.completions.create(
|
||
model=TRANSLATE_MODEL,
|
||
messages=[{"role": "user", "content": prompt}],
|
||
response_format={"type": "json_object"},
|
||
temperature=0.2,
|
||
)
|
||
content = resp.choices[0].message.content or "[]"
|
||
try:
|
||
data = json.loads(content)
|
||
if isinstance(data, dict):
|
||
for k in ("data", "items", "result", "translations"):
|
||
if k in data and isinstance(data[k], list):
|
||
data = data[k]
|
||
break
|
||
if not isinstance(data, list):
|
||
data = []
|
||
except json.JSONDecodeError:
|
||
data = []
|
||
zh_by_idx: dict[int, str] = {}
|
||
for it in data:
|
||
if isinstance(it, dict) and "i" in it:
|
||
zh_by_idx[int(it["i"])] = str(it.get("zh", ""))
|
||
return [zh_by_idx.get(i, "") for i in range(len(segments))]
|
||
|
||
|
||
def _transcript_join(segments: list[TranscriptSegment], field: Literal["en", "zh"]) -> str:
|
||
lines: list[str] = []
|
||
for s in segments:
|
||
text = (s.zh if field == "zh" else s.en).strip()
|
||
if text:
|
||
lines.append(f"[{s.start:.1f}-{s.end:.1f}s] {text}")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _fallback_audio_script(segments: list[TranscriptSegment]) -> str:
|
||
joined = " ".join((s.zh or s.en).strip() for s in segments if (s.zh or s.en).strip())
|
||
if not joined:
|
||
return "日常疲惫不用硬扛。戴上 SKG,让肩颈慢慢放松,跟着呼吸找回轻松状态。"
|
||
return (
|
||
"把日常紧绷交给 SKG。贴合身体需要放松的位置,热敷与按摩节奏自然陪伴,"
|
||
"让每一次短暂休息都更轻松、更有质感。"
|
||
)
|
||
|
||
|
||
def _rewrite_audio_script_sync(segments: list[TranscriptSegment]) -> tuple[str, str]:
    """Rewrite the reference transcript into a short SKG voice-over script.

    Returns (script, error). On success error is "". Without LLM_API_KEY, or when
    the model call or JSON parsing fails, the script is the local template from
    ``_fallback_audio_script`` and error explains why. A missing, null or blank
    "rewritten_text" also falls back to the template; a JSON null used to become
    the literal text "None", which TTS would read aloud.
    """
    fallback = _fallback_audio_script(segments)
    if not LLM_API_KEY:
        return fallback, "LLM_API_KEY 未配置,使用本地 SKG 模板"
    source_text = _transcript_join(segments, "en")
    source_zh = _transcript_join(segments, "zh")
    prompt = (
        "你是 SKG 短视频口播编导。根据参考视频音频转写,抽取它的表达结构、情绪节奏和可复用卖点,"
        "改写成适合 SKG 按摩/放松产品二创视频的中文口播文案。\n"
        "要求:\n"
        "1. 输出 35-90 个中文字,适合 8-18 秒短视频配音。\n"
        "2. 口语化、干净、高级,能直接给 TTS 朗读。\n"
        "3. 不承诺治疗、治愈、医学疗效,不夸大。\n"
        "4. 不复刻原视频品牌/人物/价格/平台话术,只保留表达结构。\n"
        "5. 如果参考转写信息不足,按产品信息生成通用 SKG 放松口播。\n"
        '严格返回 JSON:{"rewritten_text":"..."}。\n\n'
        f"SKG 产品信息:{AUDIO_PRODUCT_BRIEF}\n\n"
        f"英文转写:\n{source_text or '无'}\n\n"
        f"中文翻译:\n{source_zh or '无'}"
    )
    try:
        resp = llm().chat.completions.create(
            model=AUDIO_REWRITE_MODEL,
            messages=[
                {"role": "system", "content": "只输出合法 JSON,不要解释,不要 markdown。"},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0.45,
            max_tokens=600,
        )
        raw = (resp.choices[0].message.content or "").strip()
        if raw.startswith("```"):
            # Strip a markdown fence some gateways add around JSON-mode output.
            import re as _re
            match = _re.search(r"\{[\s\S]*\}", raw)
            raw = match.group(0) if match else raw
        data = json.loads(raw)
        text = str(data.get("rewritten_text") or "").strip()
        return (text or fallback), ""
    except Exception as e:
        return fallback, f"改写失败,使用本地模板:{e}"
|
||
|
||
|
||
def _minimax_tts_url() -> str:
    """Return the MiniMax T2A v2 endpoint.

    MINIMAX_TTS_BASE_URL may be either the API host or the full endpoint; the
    ``/v1/t2a_v2`` path is appended only when it is not already there.
    """
    endpoint_path = "/v1/t2a_v2"
    base = MINIMAX_TTS_BASE_URL
    return base if base.endswith(endpoint_path) else base + endpoint_path
|
||
|
||
|
||
def _minimax_tts_sync(job_id: str, text: str) -> str:
    """Synthesize *text* with MiniMax T2A v2 and store it as the job's audio_script.mp3.

    The stripped text is truncated to 9500 characters and requested as mono 32 kHz
    mp3 (hex-encoded in the response). Returns the public URL of the saved file.

    Raises RuntimeError when the API key is missing, the text is blank, MiniMax
    reports a non-zero status code, or no decodable hex audio is returned. HTTP
    errors from httpx propagate unchanged.
    """
    if not MINIMAX_API_KEY:
        raise RuntimeError("MINIMAX_API_KEY 未配置,未生成配音")
    cleaned = text.strip()
    if not cleaned:
        raise RuntimeError("改写文案为空,未生成配音")

    voice_setting = {"voice_id": MINIMAX_TTS_VOICE_ID, "speed": 1, "vol": 1, "pitch": 0}
    audio_setting = {"sample_rate": 32000, "bitrate": 128000, "format": "mp3", "channel": 1}
    request_body = {
        "model": MINIMAX_TTS_MODEL,
        "text": cleaned[:9500],
        "stream": False,
        "language_boost": "Chinese",
        "output_format": "hex",
        "voice_setting": voice_setting,
        "audio_setting": audio_setting,
    }
    response = httpx.post(
        _minimax_tts_url(),
        headers={"Authorization": f"Bearer {MINIMAX_API_KEY}", "Content-Type": "application/json"},
        json=request_body,
        timeout=90,
    )
    response.raise_for_status()
    body = response.json()

    # MiniMax signals application-level failures in base_resp, even on HTTP 200.
    status = body.get("base_resp") or {}
    if int(status.get("status_code", 0) or 0) != 0:
        raise RuntimeError(status.get("status_msg") or "MiniMax TTS 返回失败")

    hex_audio = ((body.get("data") or {}).get("audio") or "").strip()
    if not hex_audio:
        raise RuntimeError("MiniMax TTS 未返回 audio hex")
    try:
        audio_bytes = bytes.fromhex(hex_audio)
    except ValueError as err:
        raise RuntimeError(f"MiniMax TTS audio hex 无法解析:{err}") from err

    job_dir(job_id).joinpath("audio_script.mp3").write_bytes(audio_bytes)
    return f"/jobs/{job_id}/audio-script.mp3"
|
||
|
||
|
||
def _build_audio_script_sync(job_id: str, segments: list[TranscriptSegment]) -> AudioScript:
    """Rewrite the transcript into SKG voice-over copy and synthesize it with MiniMax.

    Always returns an AudioScript with status "completed". Rewrite and TTS problems
    are not raised: their messages are joined with ";" into ``error``, and a failed
    TTS leaves ``voice_url`` empty.
    """
    rewritten, rewrite_error = _rewrite_audio_script_sync(segments)
    try:
        voice_url, voice_error = _minimax_tts_sync(job_id, rewritten), ""
    except Exception as exc:
        voice_url, voice_error = "", str(exc)
    problems = [msg for msg in (rewrite_error, voice_error) if msg]
    return AudioScript(
        status="completed",
        source_text=_transcript_join(segments, "en"),
        source_zh=_transcript_join(segments, "zh"),
        rewritten_text=rewritten,
        product_brief=AUDIO_PRODUCT_BRIEF,
        rewrite_model=AUDIO_REWRITE_MODEL,
        voice_provider="minimax",
        voice_model=MINIMAX_TTS_MODEL,
        voice_id=MINIMAX_TTS_VOICE_ID,
        voice_url=voice_url,
        error=";".join(problems),
        created_at=time.time(),
    )
|
||
|
||
|
||
async def pipeline_transcribe(job_id: str) -> None:
    """Transcribe, translate and rewrite the job's audio track into SKG voice-over copy.

    Flow: audio.wav → ASR segments → Chinese translation → rewritten script plus
    MiniMax voice-over. Without LLM_API_KEY a fixed two-segment mock transcript
    replaces ASR and translation. Partial results (English-only transcript, then
    the translated one with a "rewriting" audio script) are published to the job
    as they arrive. Any exception marks both the job and its audio script failed.
    """
    job = JOBS[job_id]
    wav = job_dir(job_id) / "audio.wav"

    def pending_script(segs: list[TranscriptSegment]) -> AudioScript:
        # Placeholder published while the rewrite + TTS step is running.
        return AudioScript(
            status="rewriting",
            source_text=_transcript_join(segs, "en"),
            source_zh=_transcript_join(segs, "zh"),
            product_brief=AUDIO_PRODUCT_BRIEF,
            rewrite_model=AUDIO_REWRITE_MODEL,
            voice_provider="minimax",
            voice_model=MINIMAX_TTS_MODEL,
            voice_id=MINIMAX_TTS_VOICE_ID,
        )

    try:
        if not wav.exists():
            raise RuntimeError("audio.wav 不存在")

        if not LLM_API_KEY:
            # No key: use canned mock data so the rest of the flow can be exercised.
            update(job, status="transcribing", message="ASR (mock) …", progress=75)
            await asyncio.sleep(1.0)
            mock = [
                TranscriptSegment(
                    index=0, start=0.0, end=3.5,
                    en="Welcome back, today we're testing something new.",
                    zh="欢迎回来,今天我们要测试一些新东西。",
                ),
                TranscriptSegment(
                    index=1, start=3.5, end=7.2,
                    en="This device looks really sleek and minimal.",
                    zh="这个设备看起来非常时尚和简约。",
                ),
            ]
            update(
                job,
                transcript=mock,
                audio_script=pending_script(mock),
                message="ASR mock 完成,生成 SKG 改写文案…",
                progress=92,
            )
            mock_script = await asyncio.to_thread(_build_audio_script_sync, job_id, mock)
            update(
                job,
                transcript=mock,
                status="transcribed",
                progress=100,
                audio_script=mock_script,
                message="转录完成(MOCK · 未设 LLM_API_KEY)",
            )
            return

        # 1) ASR
        update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78)
        raw_segments = await asyncio.to_thread(_transcribe_sync, wav)
        if not raw_segments:
            raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)")

        # Publish English-only segments first so the UI can show them early;
        # the Chinese text is filled in after translation.
        english = [
            TranscriptSegment(
                index=pos,
                start=float(seg.get("start", 0)),
                end=float(seg.get("end", 0)),
                en=str(seg.get("text", "")).strip(),
                zh="",
            )
            for pos, seg in enumerate(raw_segments)
        ]
        update(job, transcript=english, message=f"ASR 完成 · {len(english)} 段,开始翻译…", progress=88)

        # 2) Translation
        translations = await asyncio.to_thread(_translate_sync, raw_segments)
        bilingual = [
            TranscriptSegment(
                index=seg.index,
                start=seg.start,
                end=seg.end,
                en=seg.en,
                zh=translations[pos] if pos < len(translations) else "",
            )
            for pos, seg in enumerate(english)
        ]
        update(
            job,
            transcript=bilingual,
            audio_script=pending_script(bilingual),
            message="翻译完成,生成 SKG 改写文案与 MiniMax 配音…",
            progress=94,
        )
        final_script = await asyncio.to_thread(_build_audio_script_sync, job_id, bilingual)
        update(
            job,
            transcript=bilingual,
            status="transcribed",
            progress=100,
            audio_script=final_script,
            message=f"转录完成 · {len(bilingual)} 段({ASR_MODEL} + {TRANSLATE_MODEL})",
        )
    except Exception as e:
        update(
            job,
            status="failed",
            audio_script=AudioScript(status="failed", error=str(e), created_at=time.time()),
            error=str(e),
            message="转录失败",
        )
|
||
|
||
|
||
def _image_edit_call(
    image_path: Path,
    prompt: str,
    model: str | None = None,
    models: list[str] | None = None,
    fallback_text: bool = False,
    max_attempts: int = 3,
    max_side: int = 1024,
) -> tuple[bytes, str]:
    """Generic image-edit call with retries and an optional text-only fallback.

    Returns (image_bytes, effective_mode), where effective_mode is "edit" or "text".
    Raises RuntimeError on failure (missing LLM_API_KEY, a 401/403 response, no
    image data after all attempts, or a response without b64_json).
    The input image is downscaled so its longest side is at most max_side (default
    1024) before base64 encoding, so that large images do not push the Gemini
    function-call input past its limit and cause incomplete_generation.
    models: models to rotate through across retries; when omitted, the single
    model (or IMAGE_MODEL) is retried."""
    # Function-local imports (httpx is re-bound locally as well).
    import base64 as b64lib
    import io as _io
    import time as _time
    import httpx
    from PIL import Image as _PILImage
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    # Model priority: models list > single model argument > IMAGE_MODEL.
    if models and len(models) > 0:
        models_cycle = list(models)
    else:
        models_cycle = [model or IMAGE_MODEL]
    # NOTE(review): `model` is reassigned here and on success below, but it is not
    # returned or read afterwards.
    model = models_cycle[0]
    # Downscale to fit within max_side and re-encode as JPEG (quality 88).
    try:
        im = _PILImage.open(image_path)
        if max(im.size) > max_side:
            im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
        buf = _io.BytesIO()
        im.convert("RGB").save(buf, format="JPEG", quality=88)
        img_bytes_in = buf.getvalue()
    except Exception:
        # If PIL fails, send the original file bytes (still labelled image/jpeg below).
        img_bytes_in = image_path.read_bytes()
    img_b64 = b64lib.b64encode(img_bytes_in).decode("ascii")
    data_uri = f"data:image/jpeg;base64,{img_b64}"

    # Attempt plan: max_attempts image-edit calls, then one text-only generation
    # when fallback_text is set.
    plan: list[str] = ["edit"] * max_attempts
    if fallback_text:
        plan.append("text")

    last_err = ""
    resp_data: dict = {}
    effective_mode = "edit"
    for attempt, current_mode in enumerate(plan):
        # Model rotation: attempt N uses the N-th model (the last one once the list runs out).
        current_model = models_cycle[min(attempt, len(models_cycle) - 1)]
        try:
            if current_mode == "edit":
                # Edit mode: POST prompt + data-URI image to {LLM_BASE_URL}/images/generations.
                with httpx.Client(timeout=120) as client:
                    r = client.post(
                        f"{LLM_BASE_URL}/images/generations",
                        headers={
                            "Authorization": f"Bearer {LLM_API_KEY}",
                            "Content-Type": "application/json",
                        },
                        json={"model": current_model, "prompt": prompt, "image": data_uri, "n": 1},
                    )
                    r.raise_for_status()
                    resp_data = r.json()
            else:
                # Text mode: plain generation through the SDK client; the input image is not sent.
                resp = llm().images.generate(model=current_model, prompt=prompt, n=1)
                resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}
            # Any non-empty "data" counts as success.
            if resp_data.get("data"):
                effective_mode = current_mode
                model = current_model  # record the model that actually succeeded
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]} · model={current_model}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            # With model rotation, every HTTP error is retried on the next model
            # except auth failures (401/403), which abort immediately.
            sc = e.response.status_code
            fatal = sc in (401, 403)
            last_err = f"HTTP {sc}: {body[:200]} · model={current_model}"
            if fatal:
                raise RuntimeError(f"image edit HTTP {sc}: {body[:300]}")
        except Exception as e:
            last_err = f"{type(e).__name__}: {e} · model={current_model}"

        # Log and wait 1s before every attempt except the last one.
        if attempt < len(plan) - 1:
            next_model = models_cycle[min(attempt + 1, len(models_cycle) - 1)]
            tag = f"retry {attempt + 1}/{len(plan)} → {next_model}"
            print(f"[image edit {tag}] {last_err}", flush=True)
            _time.sleep(1.0)

    # resp_data holds the last parsed response (empty dict if none was parsed).
    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise RuntimeError(f"image edit failed after {len(plan)} attempts: {last_err}")
    b64 = data_arr[0].get("b64_json")
    if not b64:
        raise RuntimeError("image edit returned no b64_json")
    return b64lib.b64decode(b64), effective_mode
|
||
|
||
|
||
# ---------- API 路由 ----------
|
||
|
||
# Request body for creating a job. A class docstring is deliberately omitted
# because pydantic would publish it as the OpenAPI schema description.
class CreateJobReq(BaseModel):
    # Source video URL. Presumably stored as job.url, which pipeline_download passes
    # to yt-dlp; confirm against the create-job endpoint.
    url: str
|
||
|
||
|
||
# Request body for POST /translate (single-text translation).
class TranslateReq(BaseModel):
    # Text to translate; translate_text strips it and returns "" for blank input.
    text: str
    # Target language: "en" → English (default), "zh" → Simplified Chinese.
    target: Literal["en", "zh"] = "en"
|
||
|
||
|
||
@app.post("/translate")
def translate_text(req: TranslateReq) -> dict:
    """Translate one short text, e.g. a custom image-prompt element label (zh→en).

    Returns {"text": translation}, with surrounding quote marks stripped. Blank
    input returns {"text": ""} without calling the model.
    Raises HTTPException 503 when LLM_API_KEY is not configured, and 500 when the
    model call fails.
    """
    import re as _re

    text = req.text.strip()
    if not text:
        return {"text": ""}
    if not LLM_API_KEY:
        raise HTTPException(503, "LLM_API_KEY 未配置")

    target_label = "English" if req.target == "en" else "Simplified Chinese"
    prompt = (
        f"Translate the following text into concise {target_label}, suitable as an element "
        "label in an image-generation prompt. Output only the translation itself — no quotes, "
        "no punctuation, no explanation, no markdown.\n\n"
        f"Input: {text}"
    )
    try:
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=200,
        )
        message = resp.choices[0].message
        out = (message.content or "").strip()
        if not out:
            # Some reasoning models leave content empty; take the last line of the reasoning
            # text. Whitespace-only reasoning used to raise IndexError here (surfacing as a 500).
            reasoning_lines = (getattr(message, "reasoning_content", "") or "").strip().splitlines()
            if reasoning_lines:
                out = reasoning_lines[-1].strip()
        out = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip()
        return {"text": out}
    except Exception as e:
        raise HTTPException(500, f"translate failed: {e}") from e
|
||
|
||
|
||
@app.get("/health")
def health() -> dict:
    """Liveness probe plus a snapshot of the configured models and providers.

    API keys are reported only as booleans (``*_configured``), never as values.
    """
    if video_uses_poe():
        video_provider = "poe"
    elif video_uses_ark():
        video_provider = "ark"
    else:
        video_provider = "custom"

    model_info = {
        "asr": ASR_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
        "audio_rewrite": AUDIO_REWRITE_MODEL,
        "minimax_tts": MINIMAX_TTS_MODEL,
        "minimax_voice": MINIMAX_TTS_VOICE_ID,
        "minimax_configured": bool(MINIMAX_API_KEY),
        "video": VIDEO_MODEL,
        "video_aliases": VIDEO_MODEL_ALIASES,
        "video_provider": video_provider,
        "video_base_url": video_api_base(),
        "video_configured": bool(video_api_key()),
        "video_create_paths": VIDEO_CREATE_PATHS,
    }
    return {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "base_url": LLM_BASE_URL or "openai-default",
        "models": model_info,
    }
|
||
|
||
|
||
class JobSummary(BaseModel):
    """Compact per-job record returned by ``GET /jobs`` (built by list_jobs)."""

    id: str
    url: str
    status: JobStatus
    progress: int = 0
    message: str = ""
    duration: float = 0.0
    width: int = 0
    height: int = 0
    video_url: str = ""
    # len(job.frames)
    frame_count: int = 0
    # len(job.generated_videos)
    video_count: int = 0
    # URL of the first keyframe, or "" when the job has no frames yet.
    thumbnail: str = ""
    error: str = ""
    # mtime of jobs/<id>/state.json (0.0 if missing); used as the sort key.
    mtime: float = 0.0
|
||
|
||
|
||
@app.get("/jobs", response_model=list[JobSummary])
def list_jobs(limit: int | None = None) -> list[JobSummary]:
    """List all jobs as compact summaries, newest first (by state.json mtime).

    The frontend uses this to backfill history when no ``?job=`` is given.

    Args:
        limit: if a positive int, return at most this many summaries.
    """
    items: list[JobSummary] = []
    # Iterate over a snapshot. FastAPI runs plain `def` endpoints in a threadpool,
    # so other requests may insert into or pop from JOBS while we loop.
    for job_id, job in list(JOBS.items()):
        state_path = JOBS_DIR / job_id / "state.json"
        try:
            mtime = state_path.stat().st_mtime
        except OSError:
            # Never saved, or deleted concurrently: sort it last instead of failing.
            mtime = 0.0
        thumb = f"/jobs/{job_id}/frames/{job.frames[0].index}.jpg" if job.frames else ""
        items.append(JobSummary(
            id=job.id,
            url=job.url,
            status=job.status,
            progress=job.progress,
            message=job.message,
            duration=job.duration,
            width=job.width,
            height=job.height,
            video_url=job.video_url,
            frame_count=len(job.frames),
            video_count=len(job.generated_videos),
            thumbnail=thumb,
            error=job.error,
            mtime=mtime,
        ))
    items.sort(key=lambda s: s.mtime, reverse=True)
    if limit is not None and limit > 0:
        items = items[:limit]
    return items
|
||
|
||
|
||
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Register a job for ``req.url``, persist its state, and schedule pipeline_download.

    Raises HTTPException(400) when the URL is blank.
    """
    source_url = req.url.strip()
    if not source_url:
        raise HTTPException(400, "url required")

    new_id = uuid.uuid4().hex[:12]
    job = Job(id=new_id, url=source_url)
    JOBS[new_id] = job
    save_state(job)
    bg.add_task(pipeline_download, new_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Create a job from an uploaded video file and schedule pipeline_download.

    The upload is streamed in 1 MiB chunks to ``<job_dir>/source.mp4`` (whatever the
    original container extension). The job URL is recorded as ``upload://<filename>``.

    Raises:
        HTTPException(400): no filename, or an extension outside the accepted set.
        HTTPException(500): the written file is missing or empty.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")

    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    try:
        with mp4.open("wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
    except BaseException:
        # Covers I/O errors, client disconnects and cancellation. Do not leave an
        # orphan job directory behind; the error itself is re-raised unchanged.
        shutil.rmtree(d, ignore_errors=True)
        raise
    if not mp4.exists() or mp4.stat().st_size == 0:
        shutil.rmtree(d, ignore_errors=True)
        raise HTTPException(500, "upload failed")

    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(
    job_id: str,
    bg: BackgroundTasks,
    frames: int = KEYFRAME_COUNT,
    target: FrameExtractTarget = "transparent_human",
    mode: FrameExtractMode = "replace",
    quality: FrameExtractQuality = "auto",
) -> Job:
    """Enqueue keyframe extraction for a job; start the queue worker if it is not running.

    The job moves to status "splitting". Its message shows whether it waits in the
    queue or is about to be processed.

    Raises:
        HTTPException(404): unknown job.
        HTTPException(409): job is not in an analyzable state.
    """
    global ANALYZE_WORKER_RUNNING

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    allowed = ("downloaded", "frames_extracted", "transcribed", "failed")
    if job.status not in allowed:
        # The old message listed only "downloaded/failed", although four states are accepted.
        raise HTTPException(409, f"status must be {'/'.join(allowed)}, got {job.status}")
    ANALYZE_QUEUE.append((job_id, frames, target, mode, quality))
    position = len(ANALYZE_QUEUE)
    update(
        job,
        status="splitting",
        progress=30,
        message="排队等待抽帧" if ANALYZE_WORKER_RUNNING or position > 1 else "准备抽帧…",
    )
    # No `await` occurs between this check and the assignment. Other coroutines on
    # the event loop therefore cannot interleave. NOTE(review): whether the worker
    # clears the flag from another thread is not visible here.
    if not ANALYZE_WORKER_RUNNING:
        ANALYZE_WORKER_RUNNING = True
        bg.add_task(analyze_queue_worker)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames", response_model=Job)
def add_manual_frame(job_id: str, t: float) -> Job:
    """Extract one frame at ``t`` seconds from source.mp4 and append it to job.frames.

    The new frame gets index ``max(existing) + 1``. The file name keeps using that
    index, while the list is re-sorted by timestamp.

    Raises:
        HTTPException(404): unknown job.
        HTTPException(400): video not ready, source.mp4 missing, ``t`` not a finite
            non-negative number, or ffmpeg produced no image (e.g. ``t`` past the end).
        HTTPException(500): ffmpeg failed.
    """
    import math

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not job.video_url:
        raise HTTPException(400, "video not ready")
    if not math.isfinite(t) or t < 0:
        raise HTTPException(400, f"invalid timestamp: {t}")
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    if not mp4.exists():
        raise HTTPException(400, "source.mp4 missing")
    frames_dir = d / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    # New index = max(existing) + 1. The list is sorted by timestamp, but the file
    # name uses the index so it stays stable.
    next_idx = max((f.index for f in job.frames), default=-1) + 1
    out = frames_dir / f"{next_idx:03d}.jpg"
    # A stale file from an earlier frame with the same index would otherwise hide
    # an ffmpeg run that wrote nothing.
    out.unlink(missing_ok=True)
    try:
        run([
            "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
            "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3",
            str(out),
        ])
    except RuntimeError as e:
        raise HTTPException(500, f"ffmpeg failed: {e}") from e
    # Seeking past the end can exit successfully without writing an image.
    if not out.exists() or out.stat().st_size == 0:
        raise HTTPException(400, f"no frame could be extracted at {t:.2f}s")

    new_frame = KeyFrame(
        index=next_idx,
        timestamp=round(float(t), 2),
        url=f"/jobs/{job_id}/frames/{next_idx}.jpg",
    )
    merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp)
    update(job, frames=merged, message=f"已手动加帧({t:.1f}s),共 {len(merged)} 张")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    """Return the full in-memory Job record; 404 if the id is unknown."""
    if not (job := JOBS.get(job_id)):
        raise HTTPException(404, "job not found")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}")
def delete_job(job_id: str) -> dict[str, bool | str]:
    """Drop a job from memory and remove its directory from disk.

    The resolved path must lie strictly below JOBS_DIR. This rejects ids such as
    ``..`` or ``.`` that would escape or equal it. The call returns 404 only when
    the job is neither in memory nor on disk.
    """
    target_dir = (JOBS_DIR / job_id).resolve()
    if JOBS_DIR not in target_dir.parents:
        raise HTTPException(400, "invalid job id")

    dropped = JOBS.pop(job_id, None)
    on_disk = target_dir.exists()
    if not dropped and not on_disk:
        raise HTTPException(404, "job not found")
    if on_disk:
        shutil.rmtree(target_dir)
    return {"ok": True, "id": job_id}
|
||
|
||
|
||
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    """Schedule pipeline_transcribe; allowed only when status is "frames_extracted".

    Raises HTTPException 404 (unknown job) or 409 (wrong status).
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if job.status == "frames_extracted":
        bg.add_task(pipeline_transcribe, job_id)
        return job
    raise HTTPException(409, f"status must be frames_extracted, got {job.status}")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the job's source.mp4, or 404 if it has not been written yet."""
    video_file = job_dir(job_id) / "source.mp4"
    if video_file.exists():
        return FileResponse(video_file, media_type="video/mp4")
    raise HTTPException(404, "video not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/audio-script.mp3")
def get_audio_script(job_id: str):
    """Serve the job's synthesized audio_script.mp3, or 404 if it does not exist."""
    audio_file = job_dir(job_id) / "audio_script.mp3"
    if audio_file.exists():
        return FileResponse(audio_file, media_type="audio/mpeg")
    raise HTTPException(404, "audio script not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve keyframe ``idx``, stored on disk as frames/<idx zero-padded to 3>.jpg."""
    frame_file = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if frame_file.exists():
        return FileResponse(frame_file, media_type="image/jpeg")
    raise HTTPException(404, "frame not found")
|
||
|
||
|
||
class GenerateReq(BaseModel):
    """Request body for ``POST /jobs/{job_id}/frames/{idx}/generate``."""

    prompt: str
    extra_prompt: str = ""  # ✓ elements to include (positive), appended as "Include: ..."
    negative_prompt: str = ""  # ✗ elements to avoid (negative), appended as "Avoid: ..."
    model: str = ""  # empty → IMAGE_MODEL default
    mode: str = "edit"  # "edit" sends a reference image; any other value runs text-only
    from_selected: bool = False  # True: prefer the frame's selected generated image as reference (iteration); else the original keyframe
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job)
def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
    """Generate a new image for keyframe ``idx`` (image-to-image or text-to-image).

    In "edit" mode the reference image is POSTed to ``{LLM_BASE_URL}/images/generations``
    with an ``image`` data URI. Up to 3 such attempts are made, then one text-only
    fallback. Other modes make a single text-only ``llm().images.generate`` call.
    Transient HTTP failures (5xx, incomplete_generation, rate_limit, timeout) and
    other exceptions are retried. A non-transient HTTP error aborts immediately.

    The result is saved as jobs/<id>/gen/<idx>_<gen_id>.jpg and appended, unselected,
    to ``frame.generated_images``.

    Raises:
        HTTPException(404): job, frame, or frame file missing.
        HTTPException(400): every prompt part is empty.
        HTTPException(500): generation failed or returned no ``b64_json``.
    """
    import base64 as b64lib
    import time as _time

    import httpx

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # i2i reference: with from_selected and an existing selected generated image,
    # iterate on that image; otherwise use the original keyframe.
    reference_path = frame_path
    if req.from_selected:
        sel = next((g for g in frame.generated_images if g.selected), None)
        if sel:
            sel_path = job_dir(job_id) / "gen" / f"{idx:03d}_{sel.id}.jpg"
            if sel_path.exists():
                reference_path = sel_path

    # Join only the non-empty parts. An empty main prompt previously produced a
    # leading ". Include: ..." fragment. With a non-empty prompt the result is
    # byte-identical to the old "<prompt>. Include: x. Avoid: y".
    prompt_parts: list[str] = []
    if req.prompt.strip():
        prompt_parts.append(req.prompt.strip())
    if req.extra_prompt.strip():
        prompt_parts.append(f"Include: {req.extra_prompt.strip()}")
    if req.negative_prompt.strip():
        prompt_parts.append(f"Avoid: {req.negative_prompt.strip()}")
    full_prompt = ". ".join(prompt_parts)
    if not full_prompt:
        raise HTTPException(400, "prompt required")

    model = req.model or IMAGE_MODEL
    gen_id = uuid.uuid4().hex[:12]

    img_b64: str | None = None
    if req.mode == "edit":
        img_b64 = b64lib.b64encode(reference_path.read_bytes()).decode("ascii")

    # Try i2i up to 3 times; if all fail, degrade to one text-only attempt.
    plan: list[str] = ([req.mode] * 3) if req.mode == "edit" else [req.mode]
    if req.mode == "edit":
        plan.append("text")
    resp_data: dict = {}
    last_err = ""
    effective_mode = req.mode
    for attempt, current_mode in enumerate(plan):
        try:
            if current_mode == "edit":
                data_uri = f"data:image/jpeg;base64,{img_b64}"
                # The OpenAI SDK has no `image` parameter here, so use raw httpx.
                with httpx.Client(timeout=120) as client:
                    r = client.post(
                        f"{LLM_BASE_URL}/images/generations",
                        headers={
                            "Authorization": f"Bearer {LLM_API_KEY}",
                            "Content-Type": "application/json",
                        },
                        json={
                            "model": model,
                            "prompt": full_prompt,
                            "image": data_uri,
                            "n": 1,
                        },
                    )
                    r.raise_for_status()
                    resp_data = r.json()
            else:
                # text-only
                resp = llm().images.generate(model=model, prompt=full_prompt, n=1)
                resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}

            if resp_data.get("data"):
                effective_mode = current_mode
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            transient = (
                e.response.status_code >= 500
                or "incomplete_generation" in body
                or "rate_limit" in body
                or "timeout" in body.lower()
            )
            last_err = f"HTTP {e.response.status_code}: {body[:200]}"
            if not transient:
                raise HTTPException(500, f"image gen HTTP {e.response.status_code}: {body[:300]}")
        except Exception as e:
            last_err = f"{type(e).__name__}: {e}"

        if attempt < len(plan) - 1:
            next_mode = plan[attempt + 1]
            tag = f"fallback → {next_mode}" if next_mode != current_mode else f"retry {attempt + 1}/{len(plan)}"
            print(f"[image gen {tag}] {last_err}", flush=True)
            _time.sleep(1.5 * (attempt + 1))  # linear backoff between attempts

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise HTTPException(500, f"image gen failed after {len(plan)} attempts: {last_err}")

    item = data_arr[0]
    b64 = item.get("b64_json")
    if not b64:
        raise HTTPException(500, "image gen returned no b64_json")

    # Save locally to jobs/<id>/gen/<idx>_<gen_id>.jpg
    gen_dir = job_dir(job_id) / "gen"
    gen_dir.mkdir(parents=True, exist_ok=True)
    out_path = gen_dir / f"{idx:03d}_{gen_id}.jpg"
    out_path.write_bytes(b64lib.b64decode(b64))

    new_gen = GeneratedImage(
        id=gen_id,
        prompt=full_prompt,
        model=model,
        mode=effective_mode,
        url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg",
        selected=False,
        created_at=_time.time(),
    )

    # Write back into job.frames
    for f in job.frames:
        if f.index == idx:
            f.generated_images = f.generated_images + [new_gen]
    update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve a generated image stored as gen/<idx:03d>_<gen_id>.jpg."""
    image_file = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_file.exists():
        return FileResponse(image_file, media_type="image/jpeg")
    raise HTTPException(404, "generated image not found")
|
||
|
||
|
||
class SelectGenReq(BaseModel):
    """Request body for selecting/deselecting one generated image of a frame."""

    # True = select this image, False = deselect it.
    selected: bool
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job)
def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job:
    """Select or deselect one generated image of frame ``idx`` (single-select per frame).

    Selecting ``gen_id`` clears the flag on every other image of that frame.
    Deselecting touches only ``gen_id``. Previously, deselecting an image that was
    not selected also cleared the frame's actual selection.

    Raises:
        HTTPException(404): unknown job, frame, or generated image. These used to
            return 200 without changing anything.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    target = next((g for g in frame.generated_images if g.id == gen_id), None)
    if not target:
        raise HTTPException(404, "generated image not found")

    if req.selected:
        # Single-select: at most one selected image per frame.
        for g in frame.generated_images:
            g.selected = False
    target.selected = req.selected
    update(job, frames=job.frames)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job)
def describe_frame(job_id: str, idx: int) -> Job:
    """Ask the vision model for a structured JSON description of keyframe ``idx``.

    The parsed JSON object is stored in ``frame.description``. At most 3 attempts are
    made. An attempt is retried on an API error, empty output, invalid JSON, or JSON
    that is not an object. Objects are now required, so a list/string reply can no
    longer be stored as the description.

    Raises:
        HTTPException(404): job, frame, or frame file not found.
        HTTPException(500): every attempt failed (detail = last error).
    """
    import base64 as b64lib
    import re as _re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame file not found")

    img_b64 = b64lib.b64encode(p.read_bytes()).decode("ascii")
    image_url = f"data:image/jpeg;base64,{img_b64}"  # loop-invariant

    prompt = (
        "请识别这张图,输出严格 JSON(不要 markdown 不要解释,不要思考):\n"
        '{\n'
        ' "scene": "一句话描述场景",\n'
        ' "objects": [{"name": "物体名(中文)", "position": "在画面哪里", "color": "颜色", "extract_prompt": "用于提取该元素的英文 prompt"}],\n'
        ' "style": "整体风格 / 打光 / 色调(一句话)",\n'
        ' "suggested_prompt": "适合用作下游生图的完整英文 prompt",\n'
        ' "transparent_human_assessment": {"transparent_body_score": 0, "skeleton_visible_score": 0, "human_prominence_score": 0, "clarity_score": 0, "commercial_style_score": 0, "product_usefulness_score": 0, "qualified": false, "reject_reason": "如果不合格说明原因"}\n'
        '}\n'
        "要求:objects 列出 3-8 个画面里**可独立提取**的主要元素,extract_prompt 用于后续 image edit 模型。"
        "transparent_human_assessment 按透明骨架人标准评分:"
        + TRANSPARENT_HUMAN_POSITIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT + " "
        + TRANSPARENT_HUMAN_QUALIFIED_STANDARD
    )

    last_err = ""
    data = None
    for attempt in range(3):
        content = ""
        try:
            resp = llm().chat.completions.create(
                model=VISION_MODEL,
                messages=[{"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ]}],
                response_format={"type": "json_object"},
                temperature=0.3,
                max_tokens=3000,
            )
            message = resp.choices[0].message
            content = (message.content or "").strip()
            if not content:
                # Thinking models may leave content empty; dig the JSON out of reasoning_content.
                rc = getattr(message, "reasoning_content", "") or ""
                m = _re.search(r"\{[\s\S]*\}", rc)
                content = m.group(0) if m else ""
            # Strip a ```json ... ``` fence.
            content = _re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()
            if not content:
                last_err = f"empty content (attempt {attempt + 1})"
                continue
            parsed = json.loads(content)
        except json.JSONDecodeError as e:
            last_err = f"json decode (attempt {attempt + 1}): {e} · raw[:200]={content[:200]}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        except Exception as e:
            last_err = f"vision call (attempt {attempt + 1}): {e}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        if not isinstance(parsed, dict):
            last_err = f"non-object JSON (attempt {attempt + 1}): {type(parsed).__name__}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        data = parsed
        break

    if data is None:
        raise HTTPException(500, last_err or "vision failed after 3 retries")

    # Write back into the job
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.description = data
        new_frames.append(f)
    update(job, frames=new_frames, message=f"识别完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
# ---------- 清洗水印 / 元素提取(关键帧二阶段加工) ----------
|
||
|
||
class CleanupReq(BaseModel):
    """Optional request body for ``POST /jobs/{job_id}/frames/{idx}/cleanup``."""

    # Rectangles in relative 0-1 coordinates that limit the cleanup area;
    # empty / None = clean the whole image.
    regions: list[dict] | None = None  # [{"x","y","w","h"}, ...]
|
||
|
||
|
||
def _region_to_phrase(r: dict) -> str:
|
||
"""把相对坐标矩形转成简短方位描述给 prompt 用(避免百分号 / 括号触发模型异常)"""
|
||
x = max(0.0, min(1.0, float(r.get("x", 0))))
|
||
y = max(0.0, min(1.0, float(r.get("y", 0))))
|
||
w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
|
||
h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
|
||
if w <= 0 or h <= 0:
|
||
return ""
|
||
cx, cy = x + w / 2, y + h / 2
|
||
hpos = "left" if cx < 0.4 else "right" if cx > 0.6 else "middle"
|
||
vpos = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
|
||
if hpos == "middle" and vpos == "middle":
|
||
return "center"
|
||
if hpos == "middle":
|
||
return vpos
|
||
if vpos == "middle":
|
||
return hpos
|
||
return f"{vpos} {hpos}"
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
    """Clean a keyframe with the image-edit model: drop watermarks, @usernames, captions, logos.

    The clean version is written to jobs/<id>/cleaned/<idx>.jpg. Its URL (with a
    cache-busting ``?t=``) goes to ``frame.cleaned_url`` and ``cleaned_applied`` is
    reset to False. Optional ``req.regions`` restrict the cleanup to those areas.

    Raises:
        HTTPException(404): job, frame, or frame file missing.
        HTTPException(500): the image-edit call failed.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    requested = (req.regions if req else None) or []
    # Deduplicate while keeping order; drop degenerate rectangles ("").
    zone_words = list(dict.fromkeys(w for w in map(_region_to_phrase, requested) if w))

    # Phrase the task as "recreate a clean copy" instead of "erase / remove only X".
    # This keeps Gemini off its mask/inpainting function-call path, which in
    # practice always hit incomplete_generation on the SKG gateway.
    if not zone_words:
        prompt = (
            "Recreate this image as a clean version without watermarks, captions, "
            "hashtags, usernames, or platform logos. Keep the composition and style."
        )
    else:
        zones = f"the {zone_words[0]} area" if len(zone_words) == 1 else ", ".join(zone_words) + " areas"
        prompt = (
            f"Recreate this image as a clean version: remove the text and graphics in {zones}, "
            "keep the rest of the scene identical."
        )

    # Model rotation: when nano-banana-pro (IMAGE_MODEL) fails, fall back to flash models.
    candidate_models = [
        IMAGE_MODEL,
        "gemini-3.1-flash-image-preview",
        "gemini-2.5-flash-image",
    ]
    try:
        img_bytes, _ = _image_edit_call(
            frame_path, prompt, models=candidate_models, fallback_text=False, max_attempts=3,
        )
    except RuntimeError as e:
        raise HTTPException(500, f"cleanup failed: {e}")

    cleaned_dir = job_dir(job_id) / "cleaned"
    cleaned_dir.mkdir(parents=True, exist_ok=True)
    (cleaned_dir / f"{idx:03d}.jpg").write_bytes(img_bytes)

    for f in job.frames:
        if f.index != idx:
            continue
        f.cleaned_url = f"/jobs/{job_id}/frames/{idx}/cleaned.jpg?t={int(_time.time())}"
        f.cleaned_applied = False  # a fresh cleanup is pending again
    update(job, frames=list(job.frames), message=f"清洗完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve the pending cleaned version of frame ``idx`` (cleaned/<idx:03d>.jpg)."""
    cleaned_file = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if cleaned_file.exists():
        return FileResponse(cleaned_file, media_type="image/jpeg")
    raise HTTPException(404, "cleaned frame not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def discard_cleaned(job_id: str, idx: int) -> Job:
    """Discard the pending (not yet applied) cleaned version of frame ``idx``.

    Deletes cleaned/<idx>.jpg if present (best effort) and clears ``frame.cleaned_url``.
    An already applied cleanup, which overwrote frames/<idx>.jpg, is left untouched.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    pending = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if pending.exists():
        try:
            pending.unlink()
        except OSError:
            pass  # best effort; the URL is cleared below regardless

    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
    update(job, frames=list(job.frames), message=f"丢弃清洗版 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup/apply", response_model=Job)
def apply_cleaned(job_id: str, idx: int) -> Job:
    """Replace keyframe ``idx`` with its cleaned version by overwriting frames/<idx>.jpg.

    On the first apply, the original is backed up to orig/<idx>.jpg; later applies
    keep that first backup. Afterwards the cleaned file is removed (best effort),
    ``cleaned_url`` is cleared and ``cleaned_applied`` is set to True.

    Raises HTTPException(404) for an unknown job/frame or when no cleaned version exists.
    """
    import shutil as _shutil

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    base_dir = job_dir(job_id)
    cleaned_path = base_dir / "cleaned" / f"{idx:03d}.jpg"
    if not cleaned_path.exists():
        raise HTTPException(404, "no cleaned version to apply")
    frame_path = base_dir / "frames" / f"{idx:03d}.jpg"

    backup_dir = base_dir / "orig"
    backup_dir.mkdir(parents=True, exist_ok=True)
    backup_path = backup_dir / f"{idx:03d}.jpg"
    # Keep only the very first original, so repeated applies never overwrite it.
    if frame_path.exists() and not backup_path.exists():
        _shutil.copy2(frame_path, backup_path)

    _shutil.copy2(cleaned_path, frame_path)
    # The cleaned file is now "applied" and no longer a separate pending version.
    try:
        cleaned_path.unlink()
    except OSError:
        pass

    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
            f.cleaned_applied = True
    update(job, frames=list(job.frames), message=f"已替换分镜 {idx + 1} 为清洗版")
    return job
|
||
|
||
|
||
class AddElementReq(BaseModel):
    """Request body for adding a key element to a frame."""

    # Chinese label; required (add_element rejects a blank value).
    name_zh: str
    # English label; when blank, add_element tries to auto-translate name_zh.
    name_en: str = ""
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    # Relative rectangle {"x","y","w","h"}; cutout_element crops around it when set.
    region: dict | None = None
|
||
|
||
|
||
class UpdateElementReq(BaseModel):
    """Partial update of an element; a field left as None is not changed."""

    name_zh: str | None = None  # if given, must be non-blank after strip
    name_en: str | None = None
    position: str | None = None
|
||
|
||
|
||
class GenerateSceneAssetReq(BaseModel):
    """Request body for ``POST /jobs/{job_id}/frames/{idx}/scene-asset``."""

    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    # remove_subject / similar / style: how the background plate relates to the source.
    scene_mode: SceneMode = "remove_subject"
    # Visual style preset appended to the prompt ("source" = keep the original style).
    scene_style: SceneStyle = "source"
    # Free-form user direction; generate_scene_asset uses at most 1200 characters.
    prompt: str = ""
    # Reference frames (deduplicated, capped at 8); None → only the target frame.
    source_frame_indices: list[int] | None = None
|
||
|
||
|
||
class GenerateSubjectAssetsReq(BaseModel):
    """Request body for subject-asset generation (the handler is not visible in this section)."""

    subject_kind: SubjectKind = "object"
    background: AssetBackground = "white"
    quality: AssetQuality = "hd"
    size: AssetSize = "source"
    # presumably reference frame indices, as in GenerateSceneAssetReq — TODO confirm
    source_frame_indices: list[int] | None = None
    # presumably names of the views to render; None = handler default — TODO confirm
    views: list[str] | None = None
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements", response_model=Job)
def add_element(job_id: str, idx: int, req: AddElementReq) -> Job:
    """Append a key element to frame ``idx``, auto-translating name_zh → English if name_en is blank.

    Translation is best effort. It is skipped when LLM_API_KEY is unset, and any
    failure is logged and leaves ``name_en`` empty instead of failing the request.

    Raises:
        HTTPException(404): unknown job or frame.
        HTTPException(400): blank ``name_zh``.
    """
    import re as _re
    import time as _time

    def _translate_to_english(label: str) -> str:
        # Same prompt/model settings as /translate, but English-only.
        prompt = (
            "Translate the following text into concise English, suitable as an element label "
            "in an image-generation prompt. Output only the translation — no quotes, no punctuation, "
            f"no explanation.\n\nInput: {label}"
        )
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=200,
        )
        message = resp.choices[0].message
        text = (message.content or "").strip()
        if not text:
            # Thinking models may answer only in reasoning_content; use its last line.
            reasoning = getattr(message, "reasoning_content", "") or ""
            if reasoning:
                text = reasoning.strip().splitlines()[-1].strip()
        return _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", text).strip()

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")
    name_zh = req.name_zh.strip()
    if not name_zh:
        raise HTTPException(400, "name_zh required")

    name_en = req.name_en.strip()
    if not name_en and LLM_API_KEY:
        try:
            name_en = _translate_to_english(name_zh)
        except Exception as e:
            print(f"[add_element translate failed] {e}", flush=True)
            name_en = ""

    element = KeyElement(
        id=uuid.uuid4().hex[:8],
        name_zh=name_zh,
        name_en=name_en,
        position=req.position.strip(),
        source=req.source,
        region=req.region,
        created_at=_time.time(),
    )
    for f in job.frames:
        if f.index == idx:
            f.elements = f.elements + [element]
    update(job, frames=list(job.frames), message=f"加入元素 · 分镜 {idx + 1} · {name_zh}")
    return job
|
||
|
||
|
||
@app.patch("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def update_element(job_id: str, idx: int, element_id: str, req: UpdateElementReq) -> Job:
    """Edit an element's labels/position in place, so a wrong extraction can be corrected.

    Only fields given (not None) are changed; strings are stripped.

    Raises:
        HTTPException(404): unknown job, or no element ``element_id`` on frame ``idx``.
        HTTPException(400): ``name_zh`` given but blank.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    matches = [
        element
        for f in job.frames if f.index == idx
        for element in f.elements if element.id == element_id
    ]
    if not matches:
        raise HTTPException(404, "element not found")

    new_name_zh = None
    if req.name_zh is not None:
        new_name_zh = req.name_zh.strip()
        if not new_name_zh:
            raise HTTPException(400, "name_zh required")

    for element in matches:
        if new_name_zh is not None:
            element.name_zh = new_name_zh
        if req.name_en is not None:
            element.name_en = req.name_en.strip()
        if req.position is not None:
            element.position = req.position.strip()

    update(job, frames=list(job.frames), message=f"更新元素 · 分镜 {idx + 1} · {new_name_zh or element_id}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def delete_element(job_id: str, idx: int, element_id: str) -> Job:
    """Remove an element from frame ``idx`` and delete its extracted images (all versions).

    The deleted files are elements/<idx>_<id>.jpg|.png and elements/<idx>_<id>_*.jpg.
    File deletion is best effort. Raises HTTPException(404) for an unknown job or element.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    removed = False
    for f in job.frames:
        if f.index != idx:
            continue
        survivors = [e for e in f.elements if e.id != element_id]
        removed = len(survivors) < len(f.elements)
        f.elements = survivors
        if not removed:
            continue
        elements_dir = job_dir(job_id) / "elements"
        if not elements_dir.exists():
            continue
        patterns = (
            f"{idx:03d}_{element_id}.jpg",
            f"{idx:03d}_{element_id}.png",
            f"{idx:03d}_{element_id}_*.jpg",
        )
        for stale in (path for pattern in patterns for path in elements_dir.glob(pattern)):
            try:
                stale.unlink()
            except OSError:
                pass

    if not removed:
        raise HTTPException(404, "element not found")
    update(job, frames=list(job.frames), message=f"删除元素 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/scene-asset", response_model=Job)
def generate_scene_asset(job_id: str, idx: int, req: GenerateSceneAssetReq) -> Job:
    """Generate one clean, high-definition scene (background plate) reference image for frame ``idx``.

    One scene image per frame is normally enough. Re-running appends to
    ``frame.scene_assets`` so earlier results stay available for manual comparison.
    Scene generation comes after subject assets: names of elements that already have
    subject assets are preferred as the subjects to remove. The background is then
    reconstructed and rendered according to ``req.scene_mode`` (remove_subject /
    similar / style) and ``req.scene_style``.

    When more than one ``source_frame_indices`` is given (deduplicated, capped at 8),
    a contact sheet of those frames is sent instead of the single source frame. The
    sheet is a temp file that is always removed.

    Raises:
        HTTPException(404): unknown job or missing source frame file. ``_find_frame``
            presumably raises for an unknown ``idx`` — not visible here.
        HTTPException(500): the image-edit call failed.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    _find_frame(job, idx)  # validation only; the frame object itself is not used
    src = _source_frame_path(job_id, idx)
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    source_indices = [int(x) for x in (req.source_frame_indices or [idx]) if isinstance(x, int) or str(x).isdigit()]
    if not source_indices:
        source_indices = [idx]
    source_indices = list(dict.fromkeys(source_indices))[:8]

    # Prefer elements that already have subject assets (confirmed subjects);
    # otherwise fall back to any named element across the job.
    confirmed_subjects = [
        (e.name_en or e.name_zh).strip()
        for ref_frame in job.frames
        for e in (ref_frame.elements or [])
        if (e.subject_assets or [])
    ]
    if not confirmed_subjects:
        confirmed_subjects = [
            (e.name_en or e.name_zh).strip()
            for ref_frame in job.frames
            for e in (ref_frame.elements or [])
            if (e.name_en or e.name_zh).strip()
        ]
    # Deduplicate *before* truncating so up to three distinct names survive. The
    # fallback list used to be cut to 3 first, which lost names when it had duplicates.
    confirmed_subjects = list(dict.fromkeys(x for x in confirmed_subjects if x))[:3]
    subject_clause = (
        "Confirmed foreground subject(s) to remove: " + ", ".join(confirmed_subjects) + ". "
        if confirmed_subjects
        else "Remove the main foreground subject from the frame if present. "
    )
    mode_clause = {
        "remove_subject": (
            "Keep the original environment, camera angle, perspective, composition, lighting direction, color mood, and spatial layout. "
            "The result should be an empty clean scene/background plate with the subject removed and the occluded background reconstructed."
        ),
        "similar": (
            "Create a similar but not identical scene/background plate: keep the same camera angle, rough spatial layout, lighting direction, and usage context, "
            "but vary props, surface details, textures, and small environmental details so it is not a duplicate of the source."
        ),
        "style": (
            "Create a scene/background plate with the same camera angle and spatial layout, but reinterpret the environment in the selected visual style. "
            "Keep it believable and useful for image-to-video generation."
        ),
    }[req.scene_mode]
    style_clause = {
        "source": "Follow the original source style.",
        "premium_product": "Use a premium product-advertising style: polished, high-end, clean commercial lighting, refined materials.",
        "clean_studio": "Use a clean studio style: simple surfaces, controlled lighting, minimal distractions.",
        "warm_lifestyle": "Use a warm lifestyle style: realistic lived-in details, soft natural light, approachable atmosphere.",
        "cinematic": "Use a cinematic style: dramatic but natural lighting, richer depth, filmic contrast, not fantasy.",
    }[req.scene_style]
    user_prompt = req.prompt.strip()
    user_prompt_clause = (
        "User scene direction: " + user_prompt[:1200] + " "
        if user_prompt
        else ""
    )
    reference_clause = (
        f"Use the selected reference frame contact sheet as visual evidence for location, composition, lighting, materials, and atmosphere. Reference frame indices: {', '.join(str(i + 1) for i in source_indices)}. "
        if len(source_indices) > 1
        else "Use the provided frame as the primary visual reference. "
    )
    prompt = (
        "Create one clean high-definition scene/background reference image from this frame. "
        + subject_clause
        + "Do not include the removed subject, duplicate people, animals, products, text, watermark, platform UI, captions, usernames, hashtags, logos, or overlay graphics. "
        + reference_clause
        + user_prompt_clause
        + mode_clause + " "
        + style_clause + " "
        + "Enhance clarity and texture while avoiding over-smoothing, warped geometry, or changing important perspective details. "
        + "Do not create multiple views. Do not isolate objects."
    )
    models = [IMAGE_MODEL, "gemini-3.1-flash-image-preview", "gemini-2.5-flash-image"]

    model_src = src
    sheet_tmp: Path | None = None
    try:
        if len(source_indices) > 1:
            sheet_tmp = job_dir(job_id) / "tmp" / f"scene_refs_{idx:03d}_{uuid.uuid4().hex[:6]}.jpg"
            sheet = _make_reference_contact_sheet(job_id, source_indices, sheet_tmp)
            if sheet:
                model_src = sheet
        try:
            img_bytes, _ = _image_edit_call(model_src, prompt, models=models, fallback_text=False, max_attempts=3, max_side=1280)
        except RuntimeError as e:
            raise HTTPException(500, f"scene asset failed: {e}") from e
    finally:
        # Also runs when building the sheet itself fails, which used to leak the temp file.
        if sheet_tmp and sheet_tmp.exists():
            try:
                sheet_tmp.unlink()
            except OSError:
                pass

    asset_id = f"scene_{idx:03d}_{uuid.uuid4().hex[:8]}"
    out_path = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
    width, height = _normalize_asset_image(img_bytes, out_path, src, req.size, "white", square=False)
    report = _image_quality_report(out_path)
    scene = SceneAsset(
        id=asset_id,
        label=f"分镜 {idx + 1} 场景图",
        url=_asset_url(job_id, asset_id),
        width=width,
        height=height,
        quality=req.quality,
        size=req.size,
        scene_mode=req.scene_mode,
        scene_style=req.scene_style,
        quality_report=report,
        created_at=_time.time(),
    )

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.quality_report = _image_quality_report(src)
            f.scene_assets = (f.scene_assets or []) + [scene]
        new_frames.append(f)
    update(job, frames=new_frames, message=f"场景图生成完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
    """AI element extraction: every call adds one more extracted image to the element.

    Calls the image-edit model (nano-banana) to produce a complete, sharp image of the
    element on a pure white background, asking it to reconstruct parts that are cropped
    or occluded in the source. For elements with a ``region``, the region plus 30%
    padding on each side is cropped out first and sent as the focus image.

    The result is written to ``elements/{idx:03d}_{element_id}_{cutout_id}.jpg``; the new
    id is appended to ``element.cutouts`` and becomes ``element.cutout_id`` if unset.

    Raises:
        HTTPException(404): unknown job, frame, element, or missing source frame file.
        HTTPException(500): ``_image_edit_call`` raised ``RuntimeError``.
    """
    from PIL import Image as _PILImage
    import tempfile as _tempfile

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    # Prefer the cleaned frame when it exists, otherwise the raw extracted frame.
    cleaned_path = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    src = cleaned_path if cleaned_path.exists() else job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    out_dir = job_dir(job_id) / "elements"
    out_dir.mkdir(parents=True, exist_ok=True)
    new_cutout_id = uuid.uuid4().hex[:8]
    out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg"

    # Build the prompt before any temp file exists so a failure here cannot leak one.
    target = (el.name_en or el.name_zh).strip()
    prompt = (
        f"Identify the {target} in this image. "
        f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. "
        f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), "
        "intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. "
        "Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. "
        "Preserve the element's original color palette, style, lighting character, and proportions. "
        "Output must be a clean, high-quality asset image suitable for downstream composition."
    )
    models = [IMAGE_MODEL, "gemini-2.5-flash-image"]

    tmp_focus: Path | None = None
    model_src = src
    img_bytes: bytes
    try:
        if el.region:
            # Best effort: any crop failure falls back to sending the full frame.
            try:
                with _PILImage.open(src) as raw:
                    im = raw.convert("RGB")
                W, H = im.size
                r = el.region
                x = max(0.0, min(1.0, float(r.get("x", 0))))
                y = max(0.0, min(1.0, float(r.get("y", 0))))
                w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
                h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
                cx, cy = x + w / 2, y + h / 2
                # 30% padding on each side (1.6x overall) gives the model context for completion.
                ew, eh = w * 1.6, h * 1.6
                x0 = max(0.0, cx - ew / 2)
                y0 = max(0.0, cy - eh / 2)
                x1 = min(1.0, cx + ew / 2)
                y1 = min(1.0, cy + eh / 2)
                left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
                if right - left > 8 and bottom - top > 8:
                    with _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                        # Record the path before saving so the finally below removes it even if save fails.
                        tmp_focus = Path(tmp.name)
                        im.crop((left, top, right, bottom)).save(tmp, format="JPEG", quality=92)
                    model_src = tmp_focus
            except Exception as e:
                print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True)
                model_src = src

        try:
            img_bytes, _mode = _image_edit_call(
                model_src, prompt, models=models, fallback_text=False, max_attempts=3,
            )
        except RuntimeError as e:
            raise HTTPException(500, f"extract failed: {e}") from e
    finally:
        if tmp_focus and tmp_focus.exists():
            try:
                tmp_focus.unlink()
            except OSError:
                pass
    out_path.write_bytes(img_bytes)

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for e in f.elements:
                if e.id == element_id:
                    e.cutouts = (e.cutouts or []) + [new_cutout_id]
                    if not e.cutout_id:
                        e.cutout_id = new_cutout_id
        new_frames.append(f)
    update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/subject-assets", response_model=Job)
def generate_subject_assets(job_id: str, idx: int, element_id: str, req: GenerateSubjectAssetsReq) -> Job:
    """Generate a multi-view asset pack for one subject element.

    If ``source_frame_indices`` is given, up to 6 selected key frames (always including
    *idx*) are tiled into a reference sheet, meaning all of them show the same subject.
    One image is generated per view returned by ``_subject_view_labels``; each is
    normalised into ``assets/`` and appended to ``element.subject_assets``.

    Raises:
        HTTPException(404): unknown job or element.
        HTTPException(500): image generation failed for a view.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = _find_frame(job, idx)
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    source_indices = [int(x) for x in (req.source_frame_indices or [idx]) if isinstance(x, int) or str(x).isdigit()]
    if idx not in source_indices:
        source_indices = [idx] + source_indices
    # De-duplicate while keeping order; cap the reference sheet at 6 frames.
    source_indices = list(dict.fromkeys(source_indices))[:6]

    # Same source frame is used for every view and for the quality report; resolve it once.
    ref_frame = _source_frame_path(job_id, idx)

    target = (el.name_en or el.name_zh).strip()
    bg_phrase = "pure white" if req.background == "white" else "pure black"
    kind_phrase = "person, animal, or living character" if req.subject_kind == "living" else "object or product-like subject"
    transparent_character_clause = (
        TRANSPARENT_HUMAN_POSITIVE_PROMPT
        + " The generated living character must be a friendly transparent humanoid with transparent or translucent outer body and clean white skeleton visible inside the same body. "
        + TRANSPARENT_HUMAN_NEGATIVE_PROMPT
        + " Do not render a normal human, ordinary skeleton-only character, horror skeleton, medical anatomy, organs, veins, blood, corpse, zombie, hospital, surgery, or autopsy visual. "
        if req.subject_kind == "living"
        else ""
    )
    models = [IMAGE_MODEL, "gemini-3.1-flash-image-preview", "gemini-2.5-flash-image"]
    generated: list[SubjectAsset] = []

    tmp_focus: Path | None = None
    sheet_tmp: Path | None = None
    try:
        # Temp inputs are created inside the try so the finally always cleans them up.
        model_src, tmp_focus = _focus_source_for_element(job_id, idx, el)
        if len(source_indices) > 1:
            sheet_tmp = job_dir(job_id) / "tmp" / f"subject_refs_{idx:03d}_{element_id}_{uuid.uuid4().hex[:6]}.jpg"
            sheet = _make_reference_contact_sheet(job_id, source_indices, sheet_tmp)
            if sheet:
                model_src = sheet

        for view, view_label in _subject_view_labels(req.subject_kind, req.views):
            if req.subject_kind == "living":
                if view.startswith("expression_"):
                    emotion = view_label.replace("表情", "")
                    view_prompt = f"full-body upright standing character reference with a clear {emotion} facial expression"
                elif view.startswith("action_") or view == "side_walk":
                    view_prompt = f"full-body upright standing character reference, {view_label}, same identity and proportions"
                else:
                    view_prompt = f"full-body upright standing character reference, {view_label}"
            else:
                view_prompt = f"complete object/product reference, {view_label} view"
            prompt = (
                f"Use the reference image(s) only as visual evidence to redraw the same {target}; do not crop, cut out, paste, or extract pixels from the source. "
                f"Generate one newly rendered {view_prompt} of the same subject. "
                f"The subject is a {kind_phrase}. If multiple frames are shown, treat them as evidence of one same subject, not multiple subjects. "
                "Preserve identity, proportions, silhouette, material, colors, styling, and distinctive details across all generated views. "
                "The subject must be complete, centered, full body or full object, head-to-feet visible when applicable, not cropped by the canvas. "
                "Make the subject large and readable: it should occupy about 85-95% of the image height with only small margins. "
                f"Create a high-definition standalone asset on a solid {bg_phrase} background. "
                "No extra objects, no props, no additional products, no background elements, no original scene fragments, no shadows from the original scene, no text, no watermark, no UI. "
                "If the source is incomplete, partially visible, occluded, or low resolution, reconstruct the missing parts by redrawing a clean complete subject while staying consistent with the reference. "
                "For living subjects, keep a normal upright standing pose for the standard views; do not create sitting, walking, medical, horror, or distorted anatomy unless explicitly requested by the view label. "
                + transparent_character_clause
            )
            try:
                img_bytes, _mode = _image_edit_call(model_src, prompt, models=models, fallback_text=False, max_attempts=3, max_side=1280)
            except RuntimeError as e:
                raise HTTPException(500, f"subject asset {view} failed: {e}") from e

            asset_id = f"subject_{idx:03d}_{element_id}_{view}_{uuid.uuid4().hex[:8]}"
            out_path = job_dir(job_id) / "assets" / f"{asset_id}.jpg"
            width, height = _normalize_asset_image(img_bytes, out_path, ref_frame, req.size, req.background, square=False, fill_subject=True)
            generated.append(SubjectAsset(
                id=asset_id,
                view=view,
                label=f"{el.name_zh} · {view_label}",
                url=_asset_url(job_id, asset_id),
                width=width,
                height=height,
                background=req.background,
                quality=req.quality,
                size=req.size,
                source_frame_indices=source_indices,
                created_at=_time.time(),
            ))
    finally:
        for p in (tmp_focus, sheet_tmp):
            if p and p.exists():
                try:
                    p.unlink()
                except OSError:
                    pass

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.quality_report = _image_quality_report(ref_frame, el.region)
            for e in f.elements:
                if e.id == element_id:
                    e.subject_kind = req.subject_kind
                    e.cutout_background = req.background
                    e.subject_assets = (e.subject_assets or []) + generated
        new_frames.append(f)
    update(job, frames=new_frames, message=f"主体资产包生成完成 · {el.name_zh} · {len(generated)} 张")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}", response_model=Job)
def delete_cutout(job_id: str, idx: int, element_id: str, cutout_id: str) -> Job:
    """Delete one extracted image (cutout) of an element.

    The cutout is removed from ``element.cutouts``. If the legacy ``cutout_id`` field
    points at it, that field moves to the first remaining cutout (or ``None``). The
    image file is deleted only after the cutout has been found, so a 404 never
    touches disk.

    Raises:
        HTTPException(404): unknown job, or the cutout does not belong to the element.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    removed = False
    for f in job.frames:
        if f.index != idx:
            continue
        for e in f.elements:
            if e.id != element_id:
                continue
            if cutout_id in (e.cutouts or []):
                e.cutouts = [c for c in e.cutouts if c != cutout_id]
                removed = True
            # Legacy single-cutout field: re-point it, and count a match here as a removal
            # so elements that only carry cutout_id can still be deleted.
            if e.cutout_id == cutout_id:
                e.cutout_id = e.cutouts[0] if e.cutouts else None
                removed = True
    if not removed:
        raise HTTPException(404, "cutout not found in element")

    p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    if p.exists():
        try:
            p.unlink()
        except OSError:
            pass

    update(job, frames=list(job.frames), message="删除提取图")
    return job
|
||
|
||
|
||
class UpdateStoryboardReq(BaseModel):
    # Body for PUT /jobs/{job_id}/frames/{idx}/storyboard.
    # Image fields are storyboard reference dicts (e.g. {"kind": "asset", ...}).
    duration: float = 0
    first_image: dict | None = None
    last_image: dict | None = None
    product_images: list[dict] = Field(default_factory=list)
    product_fusion_shots: list[dict] = Field(default_factory=list)
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # v1 fields (the frontend may omit them)
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    # default_factory, matching the other list fields, so no list object is shared as a default.
    reference_ids: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class GenerateStoryboardVideoReq(BaseModel):
    # Body for POST /jobs/{job_id}/frames/{idx}/storyboard/video.
    prompt: str
    # Requested clip length in seconds; video_seconds() maps it to what the backend accepts.
    duration: float = Field(default=4)
    # Storyboard reference dicts; the first non-empty of first/subject/product/scene/action
    # becomes the first frame of the clip.
    first_image: dict | None = Field(default=None)
    last_image: dict | None = Field(default=None)
    # Only the first 6 product references are used by the endpoint.
    product_images: list[dict] = Field(default_factory=list)
    subject_image: dict | None = Field(default=None)
    scene_image: dict | None = Field(default=None)
    product_image: dict | None = Field(default=None)
    action_image: dict | None = Field(default=None)
    source_ref: VideoSourceRef | None = Field(default=None)
    # Empty string means "use the configured default" (see resolve_video_model).
    model: str = Field(default="")
    size: str = Field(default="720x1280")
|
||
|
||
|
||
class ProductFusionDescriptionReq(BaseModel):
    # Body for POST /jobs/{job_id}/product-fusion/descriptions; only the first 6 shots are used.
    shots: list[ProductFusionShot] = Field(default_factory=list)
|
||
|
||
|
||
def video_seconds(duration: float) -> str:
    """Convert a requested clip length (seconds) into the duration string the backend accepts.

    Ark: a whole number of seconds clamped to [4, 15]; non-positive requests become "5".
    Other gateways: fixed buckets — up to 6 s -> "4", up to 10 s -> "8", else "12".
    """
    if not video_uses_ark():
        for upper, bucket in ((6, "4"), (10, "8")):
            if duration <= upper:
                return bucket
        return "12"
    if duration <= 0:
        return "5"
    whole = round(duration)
    return str(min(15, max(4, whole)))
|
||
|
||
|
||
def resolve_video_model(raw: str | None) -> str:
    """Resolve a requested video model name to the id sent to the provider.

    Blank or whitespace-only input falls back to ``VIDEO_MODEL`` (then "seedance").
    Names found in ``VIDEO_MODEL_ALIASES`` (case-insensitive) are mapped; anything else
    is passed through unchanged.

    Raises:
        HTTPException(400): a Sora model was requested (Sora is disabled).
    """
    # Strip *before* the fallback so "   " is treated as "not provided" rather than "".
    requested = (raw or "").strip() or (VIDEO_MODEL or "").strip() or "seedance"
    lowered = requested.lower()
    if lowered in {"sora", "sora-2", "sora_2"}:
        raise HTTPException(400, "Sora 已停用,请选择 Seedance / Kling / Veo 3")
    return VIDEO_MODEL_ALIASES.get(lowered, requested)
|
||
|
||
|
||
def normalize_video_status(status: str | None) -> Literal["queued", "in_progress", "completed", "failed"]:
|
||
s = (status or "queued").lower()
|
||
if s in {"completed", "complete", "succeeded", "success", "done"}:
|
||
return "completed"
|
||
if s in {"failed", "failure", "error", "cancelled", "canceled", "expired"}:
|
||
return "failed"
|
||
if s in {"running", "processing", "in_progress", "generating", "started"}:
|
||
return "in_progress"
|
||
return "queued"
|
||
|
||
|
||
def video_progress(data: dict, fallback: int) -> int:
    """Read a 0-100 progress value from a provider payload.

    The first key present among ``progress``, ``percentage``, ``percent`` wins, even if its
    value is unusable. Non-numeric values yield *fallback*; the result is clamped to [0, 100].
    """
    raw = fallback
    for key in ("progress", "percentage", "percent"):
        if key in data:
            raw = data[key]
            break
    try:
        pct = int(float(raw))
    except Exception:
        pct = fallback
    return min(100, max(0, pct))
|
||
|
||
|
||
def video_url_from_response(data: dict) -> str:
    """Find the downloadable video URL in a provider create/status payload.

    Looked up in order: top-level keys, the first item of ``data`` (a list), the
    ``output`` object, then the ``content`` object. Returns "" when no non-empty
    string URL is present.
    """
    top_keys = ("url", "video_url", "output_url", "download_url")

    def pick(source, keys):
        for name in keys:
            value = source.get(name)
            if isinstance(value, str) and value:
                return value
        return ""

    found = pick(data, top_keys)
    if found:
        return found

    items = data.get("data")
    nested = (
        (items[0] if isinstance(items, list) and items else None, top_keys),
        (data.get("output"), ("url", "video_url", "download_url")),
        (data.get("content"), ("video_url", "url", "download_url", "file_url")),
    )
    for source, keys in nested:
        if isinstance(source, dict):
            found = pick(source, keys)
            if found:
                return found
    return ""
|
||
|
||
|
||
def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path) -> None:
    """Download a finished provider video into *out_mp4*.

    When the provider returned *direct_url*, that URL is fetched (a relative path is
    resolved against *base*); otherwise the gateway content endpoint for *provider_id*
    is used. The API *headers* (bearer token) are sent only to URLs under *base*, never
    to third-party storage/CDN hosts.

    Raises:
        Whatever the response's ``raise_for_status()`` raises for a non-2xx reply.
    """
    if not direct_url:
        r = client.get(f"{base}{video_path(VIDEO_CONTENT_PATH, id=provider_id)}", headers=headers)
    else:
        if direct_url.startswith("http"):
            url = direct_url
        else:
            url = f"{base}{direct_url if direct_url.startswith('/') else '/' + direct_url}"
        api_root = base.rstrip("/")
        # Require a path boundary after the base: a bare prefix test would also match
        # hosts such as "https://api.example.com.evil.net" and leak the token there.
        on_api = bool(api_root) and (url == api_root or url.startswith(api_root + "/"))
        r = client.get(url, headers=headers if on_api else None)
    r.raise_for_status()
    out_mp4.write_bytes(r.content)
|
||
|
||
|
||
def size_to_video_ratio(size: str) -> str:
    """Return the supported aspect-ratio label closest to a "WIDTHxHEIGHT" size string.

    Spaces are ignored and the "x" separator is case-insensitive. Malformed or
    non-positive sizes fall back to portrait "9:16"; ties go to the earlier candidate.
    """
    default = "9:16"
    try:
        width_s, height_s = size.lower().replace(" ", "").split("x", 1)
        width, height = int(width_s), int(height_s)
    except Exception:
        return default
    if min(width, height) <= 0:
        return default
    target = width / height
    candidates = (
        ("16:9", 16 / 9),
        ("9:16", 9 / 16),
        ("1:1", 1),
        ("4:3", 4 / 3),
        ("3:4", 3 / 4),
        ("21:9", 21 / 9),
    )
    best_label, _ = min(candidates, key=lambda pair: abs(pair[1] - target))
    return best_label
|
||
|
||
|
||
def ark_reference_data_url(ref_img: Path) -> str:
    """Inline an image file as a base64 ``data:`` URL for Ark requests.

    Files with a ``.png`` suffix (any case) are labelled image/png; everything else image/jpeg.
    """
    is_png = ref_img.suffix.lower() == ".png"
    encoded = base64.b64encode(ref_img.read_bytes()).decode("ascii")
    return "data:{};base64,{}".format("image/png" if is_png else "image/jpeg", encoded)
|
||
|
||
|
||
def submit_video_create(
    client,
    url: str,
    headers: dict,
    ref_img: Path,
    payload: dict,
    source_ref: VideoSourceRef | None = None,
    last_img: Path | None = None,
    product_imgs: list[Path] | None = None,
):
    """POST one video-generation task to the active backend and return the raw response.

    - Ark: JSON body whose ``content`` list holds the prompt text, an optional
      reference video, the first frame, an optional last frame, and up to 6 existing
      product reference images (images inlined as data URLs).
    - Poe: JSON body with the first frame base64-encoded in ``input_image``.
    - Otherwise: multipart form with the first frame uploaded as ``input_reference``.
    """
    if video_uses_ark():
        def image_part(path: Path, role: str) -> dict:
            return {"type": "image_url", "image_url": {"url": ark_reference_data_url(path)}, "role": role}

        parts = [{"type": "text", "text": payload["prompt"]}]
        wants_video = bool(source_ref and source_ref.kind == "source_video" and source_ref.url)
        if wants_video:
            parts.append({"type": "video_url", "video_url": {"url": source_ref.url}, "role": "reference_video"})
        parts.append(image_part(ref_img, "first_frame"))
        if last_img and last_img.exists():
            parts.append(image_part(last_img, "last_frame"))
        parts.extend(image_part(p, "reference_image") for p in (product_imgs or [])[:6] if p.exists())
        body = {
            "model": payload["model"],
            "content": parts,
            "ratio": size_to_video_ratio(str(payload.get("size", ""))),
            "duration": int(float(str(payload.get(VIDEO_DURATION_FIELD, 5)))),
            "watermark": False,
            "resolution": "720p",
        }
        json_headers = {**headers, "Content-Type": "application/json"}
        return client.post(url, headers=json_headers, json=body)

    if video_uses_poe():
        body = dict(payload)
        body[VIDEO_DURATION_FIELD] = int(float(str(body.get(VIDEO_DURATION_FIELD, 4))))
        body["input_image"] = base64.b64encode(ref_img.read_bytes()).decode("ascii")
        return client.post(url, headers=headers, json=body)

    with ref_img.open("rb") as image_file:
        upload = {"input_reference": ("reference.jpg", image_file, "image/jpeg")}
        return client.post(url, headers=headers, data=payload, files=upload)
|
||
|
||
|
||
def _create_storyboard_video_task(
    client,
    base: str,
    headers: dict,
    ref_img: Path,
    payload: dict,
    source_ref: VideoSourceRef | None,
    last_img: Path | None,
    product_imgs: list[Path],
):
    """Submit the create request and return the first accepted (< 400) response.

    Each path in ``VIDEO_CREATE_PATHS`` is tried in order. On Ark, a 400/422 reply
    triggers retries that drop optional inputs (reference video, last frame, product
    references) and finally fall back to the first frame alone.

    Raises:
        RuntimeError: no create path accepted the request (message lists every attempt).
        Whatever ``raise_for_status()`` raises for errors other than 400/404/405.
    """
    use_ark = video_uses_ark()
    retryable = {400, 422}
    create_errors: list[str] = []

    def submit(endpoint, src_ref, last, products):
        return submit_video_create(client, endpoint, headers, ref_img, payload, src_ref, last, products)

    for create_path in VIDEO_CREATE_PATHS:
        path = video_path(create_path)
        endpoint = f"{base}{path}"
        resp = submit(endpoint, source_ref, last_img, product_imgs)
        if use_ark:
            if source_ref and resp.status_code in retryable:
                create_errors.append(f"{path} + reference_video -> HTTP {resp.status_code}: {resp.text[:160]}")
                resp = submit(endpoint, None, last_img, product_imgs)
            if last_img and resp.status_code in retryable:
                create_errors.append(f"{path} + last_frame -> HTTP {resp.status_code}: {resp.text[:160]}")
                resp = submit(endpoint, None, None, product_imgs)
            if product_imgs and resp.status_code in retryable:
                create_errors.append(f"{path} + product_reference -> HTTP {resp.status_code}: {resp.text[:160]}")
                resp = submit(endpoint, None, last_img, None)
            # Both optional inputs were rejected in some combination: try the first frame alone.
            if last_img and product_imgs and resp.status_code in retryable:
                create_errors.append(f"{path} + last_frame only -> HTTP {resp.status_code}: {resp.text[:160]}")
                resp = submit(endpoint, None, None, None)
        if resp.status_code < 400:
            return resp
        create_errors.append(f"{path} -> HTTP {resp.status_code}: {resp.text[:160]}")
        if resp.status_code not in {400, 404, 405}:
            resp.raise_for_status()
    raise RuntimeError("视频模型已选择,但当前网关视频生成入口不可用;已尝试 " + " | ".join(create_errors))


def render_storyboard_video(
    job_id: str,
    local_id: str,
    provider_id: str,
    ref_path: Path,
    prompt: str,
    model: str,
    seconds: str,
    size: str,
    source_ref: VideoSourceRef | None = None,
    last_ref_path: Path | None = None,
    product_ref_paths: list[Path] | None = None,
) -> None:
    """Background task: render one storyboard clip and record its progress on the job.

    Prepares reference images under ``storyboard_videos/<local_id>/``, submits the
    create request, polls the status endpoint every 8 s until the task leaves
    queued/in_progress or ``VIDEO_POLL_TIMEOUT_SECONDS`` elapses, then downloads
    ``video.mp4``. Exceptions are caught and recorded as a "failed" video.
    """
    out_dir = job_dir(job_id) / "storyboard_videos" / local_id
    ref_img = out_dir / "reference.jpg"
    last_img = out_dir / "last_reference.jpg"
    out_mp4 = out_dir / "video.mp4"

    try:
        # Resolved inside the try so a configuration error marks the video failed
        # instead of leaving it "queued" forever.
        base = video_api_base()
        headers = {"Authorization": f"Bearer {video_api_key()}"}

        prepare_video_reference(ref_path, ref_img)
        prepared_last_img: Path | None = None
        if last_ref_path and last_ref_path.exists():
            prepare_video_reference(last_ref_path, last_img)
            prepared_last_img = last_img
        prepared_product_imgs: list[Path] = []
        for i, product_ref_path in enumerate((product_ref_paths or [])[:6], start=1):
            if product_ref_path.exists():
                product_img = out_dir / f"product_reference_{i}.jpg"
                prepare_video_reference(product_ref_path, product_img)
                prepared_product_imgs.append(product_img)
        update_generated_video(job_id, local_id, status="in_progress", progress=5)

        with httpx.Client(timeout=120) as client:
            payload = {"model": model, "prompt": prompt, "size": size}
            payload[VIDEO_DURATION_FIELD] = seconds
            create = _create_storyboard_video_task(
                client, base, headers, ref_img, payload, source_ref, prepared_last_img, prepared_product_imgs,
            )
            data = create.json()
            video_api_id = data.get("id") or provider_id or local_id
            status = normalize_video_status(data.get("status"))
            progress = video_progress(data, 5)
            direct_url = video_url_from_response(data)
            update_generated_video(job_id, local_id, provider_id=video_api_id, status=status, progress=progress)

            deadline = time.time() + VIDEO_POLL_TIMEOUT_SECONDS
            while status in {"queued", "in_progress"} and time.time() < deadline:
                time.sleep(8)
                poll = client.get(f"{base}{video_path(VIDEO_STATUS_PATH, id=video_api_id)}", headers=headers)
                poll.raise_for_status()
                pdata = poll.json()
                status = normalize_video_status(pdata.get("status"))
                progress = video_progress(pdata, progress)
                direct_url = video_url_from_response(pdata) or direct_url
                update_generated_video(job_id, local_id, status=status, progress=progress)

            if status != "completed":
                update_generated_video(job_id, local_id, status="failed", error=f"video status: {status}", progress=progress)
                return

            download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4)
            update_generated_video(
                job_id,
                local_id,
                status="completed",
                progress=100,
                url=f"/jobs/{job_id}/storyboard-videos/{local_id}.mp4",
                error="",
            )
    except Exception as e:
        update_generated_video(job_id, local_id, status="failed", error=str(e)[:500])
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/storyboard/video", response_model=Job)
def generate_storyboard_video(job_id: str, idx: int, req: GenerateStoryboardVideoReq, bg: BackgroundTasks) -> Job:
    """Queue an image-to-video render for one storyboard frame.

    The first non-empty of first/subject/product/scene/action image becomes the first
    frame (the raw extracted frame if none resolves). A "queued" GeneratedVideo is
    recorded immediately and ``render_storyboard_video`` runs as a background task.

    Raises:
        HTTPException(404): unknown job/frame or missing reference image.
        HTTPException(400): empty prompt (or a disabled model, via resolve_video_model).
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")
    ensure_video_api_configured()
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(400, "prompt required")

    ref = req.first_image or req.subject_image or req.product_image or req.scene_image or req.action_image
    fallback_frame = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    ref_path = storyboard_ref_path(job_id, ref) or fallback_frame
    if not ref_path.exists():
        raise HTTPException(404, "reference image missing")
    poster = storyboard_ref_url(job_id, ref) or f"/jobs/{job_id}/frames/{idx}.jpg"
    last_ref_path = storyboard_ref_path(job_id, req.last_image)

    if req.product_images:
        raw_product_refs = req.product_images[:6]
    elif req.product_image:
        raw_product_refs = [req.product_image]
    else:
        raw_product_refs = []
    product_ref_paths = []
    for raw_ref in raw_product_refs:
        resolved = storyboard_ref_path(job_id, raw_ref)
        if resolved:
            product_ref_paths.append(resolved)

    local_id = uuid.uuid4().hex[:12]
    model = resolve_video_model(req.model)
    seconds = video_seconds(float(req.duration or 4))
    queued_item = GeneratedVideo(
        id=local_id,
        provider_id="",
        frame_idx=idx,
        prompt=prompt,
        model=model,
        status="queued",
        url="",
        poster_url=poster,
        duration=float(seconds),
        progress=0,
        created_at=time.time(),
    )
    update(job, generated_videos=[queued_item] + job.generated_videos, message=f"视频生成已提交 · 分镜 {idx + 1}")

    # A source-video reference without a URL is unusable; drop it.
    source_ref = req.source_ref
    if source_ref and source_ref.kind == "source_video" and not source_ref.url:
        source_ref = None
    bg.add_task(
        render_storyboard_video,
        job_id, local_id, "", ref_path, prompt, model, seconds, req.size,
        source_ref, last_ref_path, product_ref_paths,
    )
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4")
def get_storyboard_video(job_id: str, video_id: str):
    """Serve a rendered storyboard clip; 404 until ``video.mp4`` has been written."""
    video_file = job_dir(job_id).joinpath("storyboard_videos", video_id, "video.mp4")
    if video_file.exists():
        return FileResponse(video_file, media_type="video/mp4")
    raise HTTPException(404, "storyboard video not found")
|
||
|
||
|
||
class CopyProductLibraryAssetReq(BaseModel):
    # Body for POST /jobs/{job_id}/assets/product-library: id of the library item to copy.
    product_id: str = Field(...)
|
||
|
||
|
||
@app.get("/product-library/skg", response_model=list[ProductLibraryItem])
def list_skg_product_library() -> list[ProductLibraryItem]:
    """List the built-in SKG white-background product images (read from the local manifest)."""
    items = load_product_library_items()
    return items
|
||
|
||
|
||
@app.get("/product-library/skg/images/{filename}")
def get_skg_product_library_image(filename: str):
    """Serve one product-library image by file name; only names listed in the manifest are served."""
    match = None
    for candidate in load_product_library_items():
        if Path(candidate.filename).name == filename:
            match = candidate
            break
    if not match:
        raise HTTPException(404, "product library image not found")
    return FileResponse(product_library_file(match), media_type="image/jpeg")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets")
async def upload_storyboard_asset(job_id: str, file: UploadFile = File(...)) -> dict:
    """Store an uploaded image as a job asset (RGB JPEG, longest side at most 1600 px).

    Returns a storyboard reference dict (kind "asset") pointing at the new asset.

    Raises:
        HTTPException(404): unknown job.
        HTTPException(400): the upload could not be decoded or saved as an image.
    """
    from PIL import ImageOps

    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    asset_id = uuid.uuid4().hex[:12]
    out_dir = job_dir(job_id) / "assets"
    out_dir.mkdir(parents=True, exist_ok=True)
    tmp = out_dir / f"{asset_id}.upload"
    out = out_dir / f"{asset_id}.jpg"
    try:
        tmp.write_bytes(await file.read())
        with Image.open(tmp) as uploaded:
            # Phone photos store rotation in EXIF; the re-encoded JPEG drops that tag,
            # so bake the orientation into the pixels first.
            img = ImageOps.exif_transpose(uploaded).convert("RGB")
        img.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
        img.save(out, "JPEG", quality=94)
    except Exception as e:
        # Never leave a truncated asset behind for the GET endpoint to serve.
        try:
            out.unlink(missing_ok=True)
        except OSError:
            pass
        raise HTTPException(400, f"product image upload failed: {e}") from e
    finally:
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass
    return {
        "kind": "asset",
        "frame_idx": -1,
        "element_id": asset_id,
        "cutout_id": asset_id,
        "label": file.filename or "SKG 产品图",
    }
|
||
|
||
|
||
@app.post("/jobs/{job_id}/assets/product-library")
def copy_product_library_asset(job_id: str, req: CopyProductLibraryAssetReq) -> dict:
    """Copy a built-in SKG library image into the job's assets as an RGB JPEG (max 1600 px).

    Returns a storyboard reference dict (kind "asset") labelled with the product title
    and image index.

    Raises:
        HTTPException(404): unknown job.
        HTTPException(400): the library image could not be converted.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    item = find_product_library_item(req.product_id)
    source_file = product_library_file(item)
    asset_id = uuid.uuid4().hex[:12]
    assets_dir = job_dir(job_id) / "assets"
    assets_dir.mkdir(parents=True, exist_ok=True)
    target = assets_dir / f"{asset_id}.jpg"
    try:
        picture = Image.open(source_file).convert("RGB")
        picture.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
        picture.save(target, "JPEG", quality=94)
    except Exception as e:
        raise HTTPException(400, f"product library copy failed: {e}")
    return {
        "kind": "asset",
        "frame_idx": -1,
        "element_id": asset_id,
        "cutout_id": asset_id,
        "label": f"产品融合 · {item.title} #{item.image_index}",
    }
|
||
|
||
|
||
def product_image_alpha(img: Image.Image) -> Image.Image:
    """Return an RGBA copy of a white-background product image with the background made transparent.

    The alpha mask is the grey level of each pixel's per-channel difference from pure
    white. Values below 18 become fully transparent; the rest are scaled by 2.4 (capped
    at 255). The mask is blurred slightly (radius 0.7) to soften edges. Transparency
    already present in *img* is preserved; for opaque inputs the result is unchanged.
    """
    rgba = img.convert("RGBA")
    rgb = rgba.convert("RGB")
    diff = ImageChops.difference(rgb, Image.new("RGB", rgb.size, (255, 255, 255)))
    mask = diff.convert("L").point(lambda p: 0 if p < 18 else min(255, int(p * 2.4)))
    mask = mask.filter(ImageFilter.GaussianBlur(0.7))
    # Intersect with the source alpha. Converting to RGB discards it, and transparent
    # pixels often store black RGB, which would otherwise become an opaque dark box.
    # Multiplying by a fully opaque (255) channel leaves the mask unchanged.
    mask = ImageChops.multiply(mask, rgba.getchannel("A"))
    rgba.putalpha(mask)
    return rgba
|
||
|
||
|
||
@app.post("/jobs/{job_id}/product-fusion/guide")
def create_product_fusion_guide(job_id: str, req: ProductFusionShot) -> dict:
    """Paste the product image into the marked region of the person image and save it as a guide asset.

    ``product_region`` uses normalised 0-1 coordinates (x, y, w, h) on the person image.
    The product's white background is keyed out (``product_image_alpha``), and the
    product is shrunk to fit the region and centred in it.

    Returns a storyboard reference dict (kind "asset") for the new guide image.

    Raises:
        HTTPException(404): unknown job.
        HTTPException(400): missing person/product image, empty region, or compositing failed.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    person_path = storyboard_ref_path(job_id, req.person_image)
    product_path = storyboard_ref_path(job_id, req.product_image)
    if not person_path or not person_path.exists():
        raise HTTPException(400, "person image required")
    if not product_path or not product_path.exists():
        raise HTTPException(400, "product image required")
    if not req.product_region or req.product_region.w <= 0 or req.product_region.h <= 0:
        raise HTTPException(400, "product region required")

    min_side = 0.02
    region = req.product_region
    # Keep the origin far enough from the far edge that the minimum-size box still fits
    # on the canvas (clamping only to [0, 1] let x/y near 1 push the box off-image).
    x = max(0.0, min(1.0 - min_side, float(region.x)))
    y = max(0.0, min(1.0 - min_side, float(region.y)))
    w = max(min_side, min(1.0 - x, float(region.w)))
    h = max(min_side, min(1.0 - y, float(region.h)))

    try:
        with Image.open(person_path) as person_src:
            base = person_src.convert("RGB")
        base.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
        with Image.open(product_path) as product_src:
            product = product_image_alpha(product_src)
        bw, bh = base.size
        # box = (left, top, width, height) in pixels of the resized person image.
        box = (
            int(round(x * bw)),
            int(round(y * bh)),
            max(1, int(round(w * bw))),
            max(1, int(round(h * bh))),
        )
        product.thumbnail((box[2], box[3]), Image.Resampling.LANCZOS)
        px = box[0] + max(0, (box[2] - product.width) // 2)
        py = box[1] + max(0, (box[3] - product.height) // 2)
        guide = base.convert("RGBA")
        guide.alpha_composite(product, (px, py))
        out = guide.convert("RGB")
        asset_id = uuid.uuid4().hex[:12]
        out_dir = job_dir(job_id) / "assets"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{asset_id}.jpg"
        out.save(out_path, "JPEG", quality=94)
    except Exception as e:
        raise HTTPException(400, f"product fusion guide failed: {e}") from e

    return {
        "kind": "asset",
        "frame_idx": -1,
        "element_id": asset_id,
        "cutout_id": asset_id,
        "label": f"产品融合引导图 · {req.image_model or 'gpt-image-2'}",
    }
|
||
|
||
|
||
def fallback_product_fusion_descriptions() -> list[str]:
    """Return six canned Chinese action descriptions for the product-fusion shots.

    Used when no LLM is configured or its reply is unusable. A new list is built on
    every call, so callers may mutate the result.
    """
    canned = (
        "人物双手拿起 SKG 颈部按摩仪,准备戴到脖子上,镜头轻微推近产品。",
        "人物把 SKG 按摩仪贴合到肩颈位置,手部轻轻调整两侧机身角度。",
        "人物坐在场景中轻按侧边控制区,产品保持在画框指定区域内清晰可见。",
        "人物闭眼放松,肩颈从紧绷变舒展,产品佩戴位置稳定不漂移。",
        "镜头靠近展示 SKG 产品材质、按键和内侧触点,手部不要遮挡产品主体。",
        "使用后的放松状态收尾,人物自然抬头,产品仍保持白色 U 形外观和真实比例。",
    )
    return list(canned)
|
||
|
||
|
||
def _parse_fusion_descriptions(text: str) -> list[str]:
    """Extract non-empty description strings from an LLM reply.

    Accepts the requested ``{"descriptions": [...]}`` object or a bare JSON list, and
    tolerates a surrounding Markdown code fence. Returns [] when ``descriptions`` is
    not a list. Raises ``json.JSONDecodeError`` when the reply is not JSON.
    """
    body = text.strip()
    if body.startswith("```"):
        # Drop the opening fence line (``` or ```json) and everything after the closing fence.
        body = body.split("\n", 1)[1] if "\n" in body else ""
        body = body.rsplit("```", 1)[0]
    data = json.loads(body)
    items = data.get("descriptions", []) if isinstance(data, dict) else data
    if not isinstance(items, list):
        # e.g. a single string: iterating it would yield one "description" per character.
        return []
    return [str(x).strip() for x in items if str(x).strip()]


@app.post("/jobs/{job_id}/product-fusion/descriptions")
def generate_product_fusion_descriptions(job_id: str, req: ProductFusionDescriptionReq) -> dict:
    """Draft six Chinese action descriptions for the product-fusion shots.

    Uses the LLM when ``LLM_API_KEY`` is set, otherwise the canned fallback list.
    Returns ``{"descriptions": [6 strings], "mode": "llm" | "fallback"}``. When the LLM
    returns fewer than six, the gaps are filled from the fallback entry at the same
    position, so each shot keeps its matching default.

    Raises:
        HTTPException(404): unknown job.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    fallback = fallback_product_fusion_descriptions()
    shots = (req.shots or [])[:6]
    if not LLM_API_KEY:
        return {"descriptions": fallback, "mode": "fallback"}

    shot_lines = []
    for i, shot in enumerate(shots, start=1):
        product = (shot.product_image or {}).get("label") or "SKG 产品图"
        person = (shot.person_image or {}).get("label") or "白底人物姿态图"
        scene = (shot.scene_image or {}).get("label") or "场景图"
        region = shot.product_region
        region_text = f"x={region.x:.2f}, y={region.y:.2f}, w={region.w:.2f}, h={region.h:.2f}" if region else "未画区域"
        shot_lines.append(f"{i}. 产品={product};人物={person};区域={region_text};场景={scene};已有描述={shot.action_text or '空'}")
    prompt = (
        "你是 SKG 产品短视频分镜导演。请为 6 条产品融合镜头各写一条中文动作描述,"
        "每条 20-40 字,必须说明人物在做什么、产品如何佩戴/展示、动作如何自然连续。"
        "产品是 SKG 白色 U 形颈部/肩颈按摩仪,不要写医疗治疗承诺,不要出现竞品。"
        "输出 JSON:{\"descriptions\":[\"...\", \"...\"]}。\n\n"
        + "\n".join(shot_lines)
    )
    try:
        resp = llm().chat.completions.create(
            model=REWRITE_MODEL,
            messages=[
                {"role": "system", "content": "只输出合法 JSON,不要解释。"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.5,
        )
        text = resp.choices[0].message.content or ""
        descriptions = _parse_fusion_descriptions(text)[:6]
    except Exception as e:
        # Best effort: any LLM or parse failure degrades to the canned descriptions.
        print(f"[product fusion descriptions fallback] {e}", flush=True)
        return {"descriptions": fallback, "mode": "fallback"}

    if len(descriptions) < 6:
        descriptions = descriptions + fallback[len(descriptions):6]
    return {"descriptions": descriptions, "mode": "llm"}
|
||
|
||
|
||
@app.get("/jobs/{job_id}/assets/{asset_id}.jpg")
def get_storyboard_asset(job_id: str, asset_id: str):
    """Serve a job asset JPEG (uploads, library copies, guides, generated assets); 404 if absent."""
    asset_file = job_dir(job_id).joinpath("assets", f"{asset_id}.jpg")
    if asset_file.exists():
        return FileResponse(asset_file, media_type="image/jpeg")
    raise HTTPException(404, "asset not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-videos/{video_id}", response_model=Job)
def delete_storyboard_video(job_id: str, video_id: str) -> Job:
    """Delete one video task from the Video Gen node (succeeded, failed or queued).

    Drops every entry of ``job.generated_videos`` whose ``id`` equals
    ``video_id``. It then best-effort removes the
    ``storyboard_videos/<video_id>`` output directory.

    Raises:
        HTTPException(404): unknown job, or no generated video with that id.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    # Single pass: keep non-matching videos, remember the first match.
    remaining = []
    target = None
    for video in job.generated_videos:
        if video.id != video_id:
            remaining.append(video)
        elif target is None:
            target = video
    if target is None:
        raise HTTPException(404, "generated video not found")

    video_dir = job_dir(job_id) / "storyboard_videos" / video_id
    if video_dir.exists():
        try:
            shutil.rmtree(video_dir)
        except OSError:
            # Best effort: a leftover directory must not block removing the task.
            pass

    if target:
        msg = f"删除视频任务 · 分镜 {target.frame_idx + 1}"
    else:
        msg = "删除视频任务"
    update(job, generated_videos=remaining, message=msg)
    return job
|
||
|
||
|
||
@app.put("/jobs/{job_id}/frames/{idx}/storyboard", response_model=Job)
def update_storyboard(job_id: str, idx: int, req: UpdateStoryboardReq) -> Job:
    """Replace the storyboard composition of frame ``idx``.

    Builds a fresh ``StoryboardScene`` from ``req`` and assigns it to every
    frame whose ``index`` equals ``idx``. The duration is clamped to >= 0, text
    fields are stripped, and list fields are copied. The frame list is then
    persisted via ``update``.

    Raises:
        HTTPException(404): unknown job, or no frame with that index.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if all(f.index != idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    def build_scene() -> StoryboardScene:
        # Called once per matching frame so each gets its own scene and list copies.
        return StoryboardScene(
            duration=max(0.0, float(req.duration)),
            first_image=req.first_image,
            last_image=req.last_image,
            product_images=list(req.product_images),
            product_fusion_shots=list(req.product_fusion_shots),
            subject_image=req.subject_image,
            scene_image=req.scene_image,
            product_image=req.product_image,
            action_image=req.action_image,
            subject=req.subject.strip(),
            product=req.product.strip(),
            scene=req.scene.strip(),
            action=req.action.strip(),
            reference_ids=list(req.reference_ids),
        )

    for frame in job.frames:
        if frame.index == idx:
            frame.storyboard = build_scene()

    update(job, frames=list(job.frames), message=f"分镜 {idx + 1} 编排已更新")
    return job
|
||
|
||
|
||
class PushStoryboardImageReq(BaseModel):
    """Request body for pushing an image into the storyboard composition area."""

    # Which kind of image is being pushed.
    kind: Literal["keyframe", "cutout", "asset"]
    # Index of the keyframe the image belongs to.
    frame_idx: int
    # Optional element / cutout identifiers; both default to None.
    element_id: str | None = Field(default=None)
    cutout_id: str | None = Field(default=None)
    # Human-readable label; defaults to the empty string.
    label: str = Field(default="")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/storyboard-images", response_model=Job)
def push_storyboard_image(job_id: str, req: PushStoryboardImageReq) -> Job:
    """Push an image (a keyframe itself or an extracted element) into the storyboard area.

    Idempotent per ``(kind, frame_idx, element_id, cutout_id)``: if an entry with
    the same key already exists, the job is returned unchanged.

    Raises:
        HTTPException(404): unknown job.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    # Duplicate-push guard: skip if the same image is already in the storyboard.
    key = (req.kind, req.frame_idx, req.element_id, req.cutout_id)
    if any((x.kind, x.frame_idx, x.element_id, x.cutout_id) == key for x in job.storyboard_images):
        return job

    label = req.label.strip()
    img = StoryboardImage(
        ref_id=uuid.uuid4().hex[:8],
        kind=req.kind,
        frame_idx=req.frame_idx,
        element_id=req.element_id,
        cutout_id=req.cutout_id,
        label=label,
        created_at=_time.time(),
    )
    # Use the stripped label (same value that is stored) so a whitespace-only
    # label falls back to the kind in the status message.
    update(job, storyboard_images=job.storyboard_images + [img], message=f"上推到分镜头编排 · {label or req.kind}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-images/{ref_id}", response_model=Job)
def remove_storyboard_image(job_id: str, ref_id: str) -> Job:
    """Remove every storyboard-area image whose ``ref_id`` matches.

    Raises:
        HTTPException(404): unknown job, or no storyboard image with that ref_id.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    remaining = list(filter(lambda item: item.ref_id != ref_id, job.storyboard_images))
    if len(remaining) == len(job.storyboard_images):
        raise HTTPException(404, "storyboard image not found")

    update(job, storyboard_images=remaining, message="从分镜头编排移除一张图")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg")
def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str):
    """Serve one versioned cutout stored as ``elements/<idx:03d>_<element_id>_<cutout_id>.jpg``.

    Raises:
        HTTPException(404): the cutout file does not exist.
    """
    filename = "{:03d}_{}_{}.jpg".format(idx, element_id, cutout_id)
    cutout_path = job_dir(job_id) / "elements" / filename
    if not cutout_path.exists():
        raise HTTPException(404, "cutout not found")
    return FileResponse(cutout_path, media_type="image/jpeg")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
    """Legacy (v1, single-image) cutout route.

    Looks for ``elements/<idx:03d>_<element_id>.jpg`` first, then the older
    ``.png`` variant. Serves whichever exists with a Content-Type matching its
    extension; previously PNGs were mislabelled as ``image/jpeg``.

    Raises:
        HTTPException(404): neither file exists.
    """
    elements_dir = job_dir(job_id) / "elements"
    stem = f"{idx:03d}_{element_id}"
    for suffix, media_type in ((".jpg", "image/jpeg"), (".png", "image/png")):
        candidate = elements_dir / f"{stem}{suffix}"
        if candidate.exists():
            return FileResponse(candidate, media_type=media_type)
    raise HTTPException(404, "cutout not found")
|
||
|
||
|
||
# ---------- 删除:关键帧 / 单张生成图 ----------
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}", response_model=Job)
def delete_frame(job_id: str, idx: int) -> Job:
    """Delete a whole keyframe together with its files, then drop it from ``job.frames``.

    Best-effort removes:
    - the original frame ``frames/<idx:03d>.jpg``;
    - the cleaned version ``cleaned/<idx:03d>.jpg``;
    - every element cutout ``elements/<idx:03d>_*.png|jpg``;
    - every generated image ``gen/<idx:03d>_*.jpg``.

    Raises:
        HTTPException(404): unknown job, or no frame with that index.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    root = job_dir(job_id)
    prefix = f"{idx:03d}"

    def discard(path) -> None:
        # Files may already be gone; deletion errors are deliberately ignored.
        try:
            path.unlink()
        except OSError:
            pass

    for subdir in ("frames", "cleaned"):
        single = root / subdir / f"{prefix}.jpg"
        if single.exists():
            discard(single)

    # All element cutouts of this frame share the "<idx:03d>_" filename prefix.
    elements_dir = root / "elements"
    if elements_dir.exists():
        for ext in ("png", "jpg"):
            for cutout in elements_dir.glob(f"{prefix}_*.{ext}"):
                discard(cutout)

    gen_dir = root / "gen"
    if gen_dir.exists():
        for generated in gen_dir.glob(f"{prefix}_*.jpg"):
            discard(generated)

    update(job, frames=[f for f in job.frames if f.index != idx], message=f"删除分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/gen/{gen_id}", response_model=Job)
def delete_generated(job_id: str, idx: int, gen_id: str) -> Job:
    """Delete one generated image of frame ``idx`` (list entry and file).

    The id is validated against the frame's ``generated_images`` before the
    file ``gen/<idx:03d>_<gen_id>.jpg`` is touched, so a 404 response leaves the
    disk unchanged. File removal itself is best-effort.

    Raises:
        HTTPException(404): unknown job, unknown frame, or unknown ``gen_id``.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    # Validate before any side effect (list mutation or file deletion).
    if not any(g.id == gen_id for g in frame.generated_images):
        raise HTTPException(404, "generated image not found")

    for f in job.frames:
        if f.index == idx:
            f.generated_images = [g for g in f.generated_images if g.id != gen_id]

    image_path = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_path.exists():
        try:
            image_path.unlink()
        except OSError:
            pass  # best effort; the list entry is already removed

    update(job, frames=list(job.frames), message=f"删除生成图 · 分镜 {idx + 1}")
    return job
|