from __future__ import annotations import asyncio import base64 import json import os import shutil import subprocess import time import uuid from contextlib import asynccontextmanager from pathlib import Path from typing import Literal from dotenv import load_dotenv from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel, Field load_dotenv() JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve() JOBS_DIR.mkdir(parents=True, exist_ok=True) CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290").split(",") if o.strip()] LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip() LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip() ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1") TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash") REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro") VISION_MODEL = os.getenv("VISION_MODEL", "gemini-2.5-flash") IMAGE_MODEL = os.getenv("IMAGE_MODEL", "gemini-3-pro-image-preview") VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance" POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1" POE_API_KEY = os.getenv("POE_API_KEY", "").strip() def env_video_model(name: str, default: str) -> str: value = os.getenv(name, "").strip() if not value: return default # Older local envs used business aliases as model IDs. Keep those aliases usable # while mapping them to concrete Poe video model IDs by default. 
if value.lower() in {"seedance", "kling", "veo", "veo3", "voe"}: return default return value VIDEO_MODEL_ALIASES = { "seedance": env_video_model("VIDEO_MODEL_SEEDANCE", "seedance-2-fast"), "kling": env_video_model("VIDEO_MODEL_KLING", "kling-omni"), "veo3": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"), "veo": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"), "voe": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"), } VIDEO_API_BASE_URL = os.getenv("VIDEO_API_BASE_URL", "").strip() VIDEO_API_KEY = os.getenv("VIDEO_API_KEY", "").strip() VIDEO_CREATE_PATH = os.getenv("VIDEO_CREATE_PATH", "/videos").strip() or "/videos" VIDEO_CREATE_PATHS = [ p.strip() for p in os.getenv("VIDEO_CREATE_PATHS", f"{VIDEO_CREATE_PATH},/videos/generations,/video/generations").split(",") if p.strip() ] VIDEO_STATUS_PATH = os.getenv("VIDEO_STATUS_PATH", "/videos/{id}").strip() or "/videos/{id}" VIDEO_CONTENT_PATH = os.getenv("VIDEO_CONTENT_PATH", "/videos/{id}/content").strip() or "/videos/{id}/content" VIDEO_DURATION_FIELD = os.getenv("VIDEO_DURATION_FIELD", "seconds").strip() or "seconds" # OpenAI 客户端(OpenAI 兼容网关,含 SKG ezlink) from openai import OpenAI _llm_client: OpenAI | None = None def llm() -> OpenAI: global _llm_client if _llm_client is None: if not LLM_API_KEY: raise RuntimeError("LLM_API_KEY 未配置") _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY) return _llm_client # Pipeline 状态: # created → downloading → downloaded(停,等用户点解析)→ splitting → frames_extracted # → transcribing → transcribed | failed JobStatus = Literal[ "created", "downloading", "downloaded", "splitting", "frames_extracted", "transcribing", "transcribed", "failed", ] KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5")) FrameExtractTarget = Literal["balanced", "subject", "transition", "expression", "motion"] FRAME_TARGET_LABELS: dict[FrameExtractTarget, str] = { "balanced": "综合关键帧", "subject": "清晰主体", "transition": "转场变化", "expression": "表情瞬间", "motion": "动作峰值", } class 
GeneratedImage(BaseModel): id: str # uuid hex 12 prompt: str model: str mode: str = "edit" # "edit"(带参考图) | "text"(纯文字) url: str # /jobs/{job_id}/frames/{idx}/gen/{id}.jpg selected: bool = False created_at: float = 0.0 class GeneratedVideo(BaseModel): id: str provider_id: str = "" frame_idx: int prompt: str model: str = "" status: Literal["queued", "in_progress", "completed", "failed"] = "queued" url: str = "" poster_url: str = "" duration: float = 4.0 progress: int = 0 error: str = "" created_at: float = 0.0 class VideoSourceRef(BaseModel): kind: Literal["image", "source_video"] = "image" url: str = "" class StoryboardScene(BaseModel): """分镜头编排:每个 selected 分镜对应一个 scene 描述 v2: 4 图槽 + 时长(复制粘贴模式)— 主体 / 场景 / 产品 / 动作 各一张图 v1 字段保留兼容(subject/product/scene/action/reference_ids)""" duration: float = 0 first_image: dict | None = None last_image: dict | None = None product_images: list[dict] = Field(default_factory=list) # 4 图槽:dict 含 {kind, frame_idx, element_id?, cutout_id?, label} subject_image: dict | None = None scene_image: dict | None = None product_image: dict | None = None action_image: dict | None = None # v1 兼容 subject: str = "" product: str = "" scene: str = "" action: str = "" reference_ids: list[str] = [] class StoryboardImage(BaseModel): """用户从各处"上推"到分镜头编排区的图片""" ref_id: str # uuid hex 8 kind: Literal["keyframe", "cutout"] # keyframe = 关键帧本身 / cutout = 元素提取图 frame_idx: int element_id: str | None = None # cutout 时 cutout_id: str | None = None # cutout 时(versioned id;老数据可能 == element_id) label: str = "" # 显示用名字 created_at: float = 0.0 class KeyElement(BaseModel): """关键帧里识别 / 用户提取的元素 · 多次提取累积多张图,让用户挑选满意的""" id: str # uuid hex 8 name_zh: str name_en: str = "" position: str = "" source: Literal["auto", "manual", "region"] = "manual" region: dict | None = None # 多张提取图 id(每次 cutout 端点累积新 id)→ /jobs/.../elements/{element_id}/cutouts/{cutout_id}.jpg cutouts: list[str] = [] # 旧字段兼容(v1 单图)· 渲染时 fallback 用,新提取不再写入 cutout_id: str | None = None cutout_background: 
Literal["white", "black"] = "white" created_at: float = 0.0 class KeyFrame(BaseModel): index: int timestamp: float url: str description: dict | None = None # vision 模型识别结果 {scene, objects, style, suggested_prompt} cleaned_url: str | None = None # 清洗后干净版(待应用)→ /jobs/{id}/frames/{idx}/cleaned.jpg cleaned_applied: bool = False # 是否已用清洗版替换原图(替换后 cleaned_url=null) elements: list[KeyElement] = [] # 提取的元素清单(持久化) storyboard: StoryboardScene | None = None # 分镜头编排字段 generated_images: list[GeneratedImage] = [] class TranscriptSegment(BaseModel): index: int start: float end: float en: str zh: str = "" class Job(BaseModel): id: str url: str status: JobStatus = "created" progress: int = 0 message: str = "" video_url: str = "" duration: float = 0.0 width: int = 0 height: int = 0 frames: list[KeyFrame] = Field(default_factory=list) transcript: list[TranscriptSegment] = Field(default_factory=list) storyboard_images: list[StoryboardImage] = Field(default_factory=list) generated_videos: list[GeneratedVideo] = Field(default_factory=list) error: str = "" JOBS: dict[str, Job] = {} def job_dir(job_id: str) -> Path: d = JOBS_DIR / job_id d.mkdir(parents=True, exist_ok=True) return d def save_state(job: Job) -> None: (job_dir(job.id) / "state.json").write_text(job.model_dump_json(indent=2)) def update(job: Job, **kw) -> None: for k, v in kw.items(): setattr(job, k, v) save_state(job) def public_api_base() -> str: return (LLM_BASE_URL or "https://api.openai.com/v1").rstrip("/") def video_uses_poe() -> bool: if VIDEO_API_BASE_URL: return VIDEO_API_BASE_URL.rstrip("/") == POE_API_BASE_URL.rstrip("/") return bool(POE_API_KEY) def video_uses_ark() -> bool: return "ark.cn-beijing.volces.com" in video_api_base() def video_api_base() -> str: if VIDEO_API_BASE_URL: return VIDEO_API_BASE_URL.rstrip("/") if POE_API_KEY: return POE_API_BASE_URL.rstrip("/") return (LLM_BASE_URL or "https://api.openai.com/v1").rstrip("/") def video_api_key() -> str: if VIDEO_API_KEY: return VIDEO_API_KEY if 
video_uses_poe(): return POE_API_KEY return LLM_API_KEY def video_path(template: str, **values: str) -> str: path = template.format(**values) return path if path.startswith("/") else f"/{path}" def ensure_video_api_configured() -> None: if not video_api_key(): raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API") def storyboard_ref_path(job_id: str, ref: dict | None) -> Path | None: if not ref: return None try: kind = ref.get("kind") frame_idx = int(ref.get("frame_idx")) except Exception: return None if kind == "keyframe": clean = job_dir(job_id) / "cleaned" / f"{frame_idx:03d}.jpg" if clean.exists(): return clean p = job_dir(job_id) / "frames" / f"{frame_idx:03d}.jpg" return p if p.exists() else None if kind == "cutout": element_id = (ref.get("element_id") or "").strip() cutout_id = (ref.get("cutout_id") or "").strip() if not element_id: return None candidates = [] if cutout_id and cutout_id != element_id: candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}_{cutout_id}.jpg") candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}.jpg") candidates.append(job_dir(job_id) / "elements" / f"{frame_idx:03d}_{element_id}.png") for p in candidates: if p.exists(): return p if kind == "asset": asset_id = (ref.get("element_id") or ref.get("cutout_id") or "").strip() if not asset_id: return None p = job_dir(job_id) / "assets" / f"{asset_id}.jpg" return p if p.exists() else None return None def storyboard_ref_url(job_id: str, ref: dict | None) -> str: if not ref: return "" kind = ref.get("kind") frame_idx = ref.get("frame_idx") if kind == "keyframe" and frame_idx is not None: return f"/jobs/{job_id}/frames/{int(frame_idx)}.jpg" if kind == "cutout" and frame_idx is not None and ref.get("element_id"): element_id = ref.get("element_id") cutout_id = ref.get("cutout_id") if cutout_id and cutout_id != element_id: return 
f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutouts/{cutout_id}.jpg" return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutout.jpg" if kind == "asset" and ref.get("element_id"): return f"/jobs/{job_id}/assets/{ref.get('element_id')}.jpg" return "" def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None: dst.parent.mkdir(parents=True, exist_ok=True) img = Image.open(src).convert("RGB") img.thumbnail(size, Image.Resampling.LANCZOS) canvas = Image.new("RGB", size, (8, 8, 10)) x = (size[0] - img.width) // 2 y = (size[1] - img.height) // 2 canvas.paste(img, (x, y)) canvas.save(dst, "JPEG", quality=94) def update_generated_video(job_id: str, video_id: str, **kw) -> None: job = JOBS.get(job_id) if not job: return updated = [] for v in job.generated_videos: if v.id == video_id: data = v.model_dump() data.update(kw) updated.append(GeneratedVideo(**data)) else: updated.append(v) update(job, generated_videos=updated) @asynccontextmanager async def lifespan(_: FastAPI): # 启动时从磁盘恢复 jobs(简化版:只列目录) for p in JOBS_DIR.iterdir(): if p.is_dir() and (p / "state.json").exists(): try: JOBS[p.name] = Job.model_validate_json((p / "state.json").read_text()) except Exception: pass yield app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ---------- Pipeline 实现 ---------- def run(cmd: list[str], cwd: Path | None = None) -> str: res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) if res.returncode != 0: # ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾) tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:] raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... 
· {tail}") return res.stdout # ---- 启发式选帧工具 ---- import imagehash import numpy as np from PIL import Image, ImageEnhance, ImageFilter, ImageOps def _sharpness_from_gray(g: np.ndarray) -> float: """Laplacian variance:值越大越清晰,模糊/转场帧值低。""" lap = (-4 * g[1:-1, 1:-1] + g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:]) return float(lap.var()) def _frame_metrics(img_path: Path, idx: int, timestamp: float) -> dict | None: """低清候选帧的本地评分特征。只用于排序,最终仍从原视频抽原尺寸帧。""" try: with Image.open(img_path) as raw: img = raw.convert("RGB") h = imagehash.phash(img) small = img.resize((160, 90)) except Exception: return None arr = np.asarray(small, dtype=np.float32) # Rec. 601 luma,保留 0-255 范围,便于和清晰度 / 对比度阈值一起看。 gray = (0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2]).astype(np.float32) center = gray[22:68, 40:120] rg = arr[:, :, 0] - arr[:, :, 1] yb = 0.5 * (arr[:, :, 0] + arr[:, :, 1]) - arr[:, :, 2] colorfulness = float(np.sqrt(rg.var() + yb.var()) + 0.3 * np.sqrt(rg.mean() ** 2 + yb.mean() ** 2)) return { "path": img_path, "idx": idx, "timestamp": timestamp, "hash": h, "gray": gray, "sharp": _sharpness_from_gray(gray), "center_sharp": _sharpness_from_gray(center), "brightness": float(gray.mean()), "contrast": float(gray.std()), "colorfulness": colorfulness, "scene_score": 0.0, "motion": 0.0, } def _attach_temporal_metrics(items: list[dict]) -> None: """相邻低清帧差异:转场 / 动作目标依赖它,不需要逐帧高分辨率扫描。""" for i, it in enumerate(items): prev_delta = 0.0 next_delta = 0.0 if i > 0: prev_delta = float(np.mean(np.abs(it["gray"] - items[i - 1]["gray"])) / 255.0) if i + 1 < len(items): next_delta = float(np.mean(np.abs(items[i + 1]["gray"] - it["gray"])) / 255.0) it["scene_score"] = max(prev_delta, next_delta) it["motion"] = (prev_delta + next_delta) / 2.0 def _normalize_item_metrics(items: list[dict]) -> None: for key in ("sharp", "center_sharp", "contrast", "colorfulness", "scene_score", "motion"): vals = [float(it.get(key, 0.0)) for it in items if float(it.get(key, 0.0)) > 0] cap = 
float(np.percentile(vals, 95)) if vals else 1.0 if cap <= 0: cap = 1.0 for it in items: it[f"{key}_n"] = min(float(it.get(key, 0.0)) / cap, 1.0) def _target_score(item: dict, target: FrameExtractTarget) -> float: sharp = float(item.get("sharp_n", 0.0)) center = float(item.get("center_sharp_n", 0.0)) contrast = float(item.get("contrast_n", 0.0)) color = float(item.get("colorfulness_n", 0.0)) scene = float(item.get("scene_score_n", 0.0)) motion = float(item.get("motion_n", 0.0)) if target == "subject": score = center * 0.48 + sharp * 0.25 + contrast * 0.17 + color * 0.10 elif target == "transition": score = scene * 0.55 + sharp * 0.28 + contrast * 0.12 + color * 0.05 elif target == "expression": # 没有额外视觉模型时,表情/动物瞬间只能用中心细节 + 清晰 + 轻微动作变化做本地近似。 score = center * 0.40 + sharp * 0.24 + motion * 0.18 + contrast * 0.12 + color * 0.06 elif target == "motion": score = motion * 0.45 + sharp * 0.30 + center * 0.15 + contrast * 0.10 else: score = sharp * 0.45 + scene * 0.22 + center * 0.15 + contrast * 0.12 + color * 0.06 brightness = float(item.get("brightness", 0.0)) raw_contrast = float(item.get("contrast", 0.0)) if raw_contrast < 4 or brightness < 8 or brightness > 247: return score * 0.15 if raw_contrast < 9: return score * 0.65 return score def _select_keyframes(candidates: list[dict], n: int, target: FrameExtractTarget, dup_threshold: int = 8) -> list[dict]: """ candidates: 按时间排序的低清候选帧评分项 n: 目标帧数 dup_threshold: pHash 汉明距离 < 此值视为相似(默认 8,64bit hash 大致 ~12.5% 像素差) """ if len(candidates) <= n: return candidates _attach_temporal_metrics(candidates) _normalize_item_metrics(candidates) for it in candidates: it["score"] = _target_score(it, target) # 去重:相似帧保留当前目标下分数更高的 deduped: list[dict] = [] for it in candidates: dup = None for kept in deduped: if (it["hash"] - kept["hash"]) < dup_threshold: dup = kept break if dup is None: deduped.append(it) elif it["score"] > dup["score"]: deduped[deduped.index(dup)] = it # 时序分桶:把候选时间轴等分 n 段,每段取当前目标下最优的 total = len(candidates) buckets: 
list[list[dict]] = [[] for _ in range(n)] for it in deduped: b = min(int(it["idx"] * n / total), n - 1) buckets[b].append(it) selected: list[dict] = [] for b in buckets: if b: selected.append(max(b, key=lambda x: x["score"])) # 空桶补足:从未选的 deduped 里按目标分数补 chosen_paths = {it["path"] for it in selected} remaining = sorted([it for it in deduped if it["path"] not in chosen_paths], key=lambda x: -x["score"]) while len(selected) < n and remaining: selected.append(remaining.pop(0)) # 按时间排序输出 selected.sort(key=lambda x: x["idx"]) return selected def ffprobe_meta(mp4: Path) -> dict: out = run([ "ffprobe", "-v", "error", "-print_format", "json", "-show_streams", "-show_format", str(mp4), ]) return json.loads(out) async def pipeline_download(job_id: str) -> None: """阶段 1:仅下载(或上传跳过),落 source.mp4,停在 downloaded 等用户点解析。""" job = JOBS[job_id] d = job_dir(job_id) try: mp4 = d / "source.mp4" if mp4.exists(): update(job, status="downloading", message="本地上传 · 跳过下载", progress=15) else: update(job, status="downloading", message="yt-dlp 下载中…", progress=5) run([ "yt-dlp", "-f", "best[ext=mp4]/best", "-o", str(mp4), "--no-warnings", "--no-playlist", "--retries", "3", job.url, ]) if not mp4.exists(): raise RuntimeError("下载完成但找不到 source.mp4") meta = ffprobe_meta(mp4) v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None) duration = float(meta["format"]["duration"]) update( job, status="downloaded", video_url=f"/jobs/{job_id}/video.mp4", duration=duration, width=int(v_stream["width"]) if v_stream else 0, height=int(v_stream["height"]) if v_stream else 0, progress=25, message=f"视频就绪 · {duration:.1f}s · 等待解析", ) except Exception as e: update(job, status="failed", error=str(e), message="下载失败") async def pipeline_analyze( job_id: str, frame_count: int = KEYFRAME_COUNT, target: FrameExtractTarget = "balanced", ) -> None: """阶段 2:拆音轨 + 抽关键帧。ASR/翻译是独立文案轨,不阻塞视觉素材流。""" job = JOBS[job_id] d = job_dir(job_id) try: mp4 = d / "source.mp4" if not mp4.exists(): raise 
RuntimeError("source.mp4 不存在,先完成下载") update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35) wav = d / "audio.wav" run([ "ffmpeg", "-y", "-i", str(mp4), "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le", str(wav), ]) n = max(1, min(int(frame_count), 20)) target_label = FRAME_TARGET_LABELS.get(target, FRAME_TARGET_LABELS["balanced"]) duration = max(float(job.duration or 1.0), 0.1) scan_fps = min(2.0, max(0.02, 180.0 / duration)) estimated_scan_count = max(1, int(duration * scan_fps)) update(job, message=f"低清扫描候选 · {target_label} · 约 {estimated_scan_count} 帧…", progress=45) frames_dir = d / "frames" if frames_dir.exists(): shutil.rmtree(frames_dir) frames_dir.mkdir(parents=True) scan_dir = d / "frame_scan" if scan_dir.exists(): shutil.rmtree(scan_dir) scan_dir.mkdir(parents=True) # 1) 低分辨率、低帧率扫描。扫描图只用于候选评分,最终不直接作为关键帧。 run([ "ffmpeg", "-y", "-i", str(mp4), "-vf", f"fps={scan_fps:.4f},scale=360:-2", "-q:v", "4", str(scan_dir / "s_%05d.jpg"), ]) scan_paths = sorted(scan_dir.glob("s_*.jpg")) if not scan_paths: raise RuntimeError("低清扫描没有生成候选帧") candidates: list[dict] = [] for i, p in enumerate(scan_paths): t = min(i / scan_fps, max(duration - 0.05, 0.0)) item = _frame_metrics(p, i, t) if item: candidates.append(item) if not candidates: raise RuntimeError("候选帧评分失败") # 2) 目标化筛选:pHash 去重 + 清晰度 / 中心细节 / 转场变化 / 动作强度 + 时序分桶。 update(job, message=f"{target_label}筛选 {n} / {len(candidates)} 张…", progress=60) chosen = _select_keyframes(candidates, n, target) # 3) 只对最终选中的时间点,从原视频抽高质量关键帧。 renamed: list[KeyFrame] = [] chosen_sorted = sorted(chosen, key=lambda it: float(it["timestamp"])) for i, item in enumerate(chosen_sorted): dst = frames_dir / f"{i:03d}.jpg" t = float(item["timestamp"]) run([ "ffmpeg", "-y", "-ss", f"{t:.3f}", "-i", str(mp4), "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3", str(dst), ]) renamed.append(KeyFrame( index=i, timestamp=round(t, 2), url=f"/jobs/{job_id}/frames/{i}.jpg", )) # 4) 清理扫描目录 shutil.rmtree(scan_dir, ignore_errors=True) 
update( job, status="frames_extracted", frames=renamed, progress=70, message=f"已按「{target_label}」抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排", ) except Exception as e: update(job, status="failed", error=str(e), message="解析失败") # ---------- Gemini ASR + 翻译 ---------- def _transcribe_sync(wav: Path) -> list[dict]: """whisper-1 verbose_json → segments[{start, end, text}]""" with wav.open("rb") as f: resp = llm().audio.transcriptions.create( file=(wav.name, f, "audio/wav"), model=ASR_MODEL, response_format="verbose_json", timestamp_granularities=["segment"], ) raw = resp.model_dump() if hasattr(resp, "model_dump") else resp segments = raw.get("segments") or [] # 兜底:网关如果不返回 segments,把全文当一段 if not segments and raw.get("text"): segments = [{"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}] return segments def _translate_sync(segments: list[dict]) -> list[str]: """gemini-2.5-flash 批量翻译为中文,按段返回""" payload = [{"i": i, "en": s.get("text", "").strip()} for i, s in enumerate(segments)] prompt = ( "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。" "严格返回 JSON 数组,不要任何 markdown 或多余文字,schema: " '[{"i": 0, "zh": "..."}, ...]\n\n输入:\n' + json.dumps(payload, ensure_ascii=False) ) resp = llm().chat.completions.create( model=TRANSLATE_MODEL, messages=[{"role": "user", "content": prompt}], response_format={"type": "json_object"}, temperature=0.2, ) content = resp.choices[0].message.content or "[]" try: data = json.loads(content) if isinstance(data, dict): for k in ("data", "items", "result", "translations"): if k in data and isinstance(data[k], list): data = data[k] break if not isinstance(data, list): data = [] except json.JSONDecodeError: data = [] zh_by_idx: dict[int, str] = {} for it in data: if isinstance(it, dict) and "i" in it: zh_by_idx[int(it["i"])] = str(it.get("zh", "")) return [zh_by_idx.get(i, "") for i in range(len(segments))] async def pipeline_transcribe(job_id: str) -> None: job = JOBS[job_id] d = job_dir(job_id) wav = d / "audio.wav" try: if not 
wav.exists(): raise RuntimeError("audio.wav 不存在") if not LLM_API_KEY: # 无 key 模式:mock 数据 update(job, status="transcribing", message="ASR (mock) …", progress=75) await asyncio.sleep(1.0) mock = [ TranscriptSegment(index=0, start=0.0, end=3.5, en="Welcome back, today we're testing something new.", zh="欢迎回来,今天我们要测试一些新东西。"), TranscriptSegment(index=1, start=3.5, end=7.2, en="This device looks really sleek and minimal.", zh="这个设备看起来非常时尚和简约。"), ] update(job, transcript=mock, status="transcribed", progress=100, message="转录完成(MOCK · 未设 LLM_API_KEY)") return # 1) whisper ASR update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78) segments = await asyncio.to_thread(_transcribe_sync, wav) if not segments: raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)") # 先把英文段落落到 job 上(让 UI 提前看到,翻译再补 zh) en_only = [ TranscriptSegment( index=i, start=float(s.get("start", 0)), end=float(s.get("end", 0)), en=str(s.get("text", "")).strip(), zh="", ) for i, s in enumerate(segments) ] update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88) # 2) Gemini 翻译 zh_list = await asyncio.to_thread(_translate_sync, segments) full = [ TranscriptSegment( index=seg.index, start=seg.start, end=seg.end, en=seg.en, zh=zh_list[i] if i < len(zh_list) else "", ) for i, seg in enumerate(en_only) ] update(job, transcript=full, status="transcribed", progress=100, message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})") except Exception as e: update(job, status="failed", error=str(e), message="转录失败") def _image_edit_call( image_path: Path, prompt: str, model: str | None = None, models: list[str] | None = None, fallback_text: bool = False, max_attempts: int = 3, max_side: int = 1024, ) -> tuple[bytes, str]: """通用 image edit 调用 · 失败重试 + 可选 text fallback。 返回 (image_bytes, effective_mode) where effective_mode in {"edit","text"}。 失败 raise RuntimeError。 输入图自动 resize 到 max_side(默认 1024)边长后再 base64,避免大图把 Gemini function call 输入挤超阈值导致 incomplete_generation。 models: 
多模型轮换列表,重试时换 model;不传则单一 model 重试。""" import base64 as b64lib import io as _io import time as _time import httpx from PIL import Image as _PILImage if not LLM_API_KEY: raise RuntimeError("LLM_API_KEY 未配置") # model 优先级:models 列表 > 单个 model 参数 > IMAGE_MODEL if models and len(models) > 0: models_cycle = list(models) else: models_cycle = [model or IMAGE_MODEL] model = models_cycle[0] # 缩到 max_side 内 try: im = _PILImage.open(image_path) if max(im.size) > max_side: im.thumbnail((max_side, max_side), _PILImage.LANCZOS) buf = _io.BytesIO() im.convert("RGB").save(buf, format="JPEG", quality=88) img_bytes_in = buf.getvalue() except Exception: # PIL 失败兜底走原文件 img_bytes_in = image_path.read_bytes() img_b64 = b64lib.b64encode(img_bytes_in).decode("ascii") data_uri = f"data:image/jpeg;base64,{img_b64}" plan: list[str] = ["edit"] * max_attempts if fallback_text: plan.append("text") last_err = "" resp_data: dict = {} effective_mode = "edit" for attempt, current_mode in enumerate(plan): # 多模型轮换:第 N 次重试用第 N 个 model(不够时用最后一个) current_model = models_cycle[min(attempt, len(models_cycle) - 1)] try: if current_mode == "edit": with httpx.Client(timeout=120) as client: r = client.post( f"{LLM_BASE_URL}/images/generations", headers={ "Authorization": f"Bearer {LLM_API_KEY}", "Content-Type": "application/json", }, json={"model": current_model, "prompt": prompt, "image": data_uri, "n": 1}, ) r.raise_for_status() resp_data = r.json() else: resp = llm().images.generate(model=current_model, prompt=prompt, n=1) resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]} if resp_data.get("data"): effective_mode = current_mode model = current_model # 记录实际成功的 model break err_obj = resp_data.get("error") or {} last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]} · model={current_model}" except httpx.HTTPStatusError as e: body = e.response.text # 多模型轮换场景:除明确不可恢复(4xx 鉴权类)外都重试换 model sc = e.response.status_code 
fatal = sc in (401, 403) last_err = f"HTTP {sc}: {body[:200]} · model={current_model}" if fatal: raise RuntimeError(f"image edit HTTP {sc}: {body[:300]}") except Exception as e: last_err = f"{type(e).__name__}: {e} · model={current_model}" if attempt < len(plan) - 1: next_model = models_cycle[min(attempt + 1, len(models_cycle) - 1)] tag = f"retry {attempt + 1}/{len(plan)} → {next_model}" print(f"[image edit {tag}] {last_err}", flush=True) _time.sleep(1.0) data_arr = resp_data.get("data", []) if not data_arr: raise RuntimeError(f"image edit failed after {len(plan)} attempts: {last_err}") b64 = data_arr[0].get("b64_json") if not b64: raise RuntimeError("image edit returned no b64_json") return b64lib.b64decode(b64), effective_mode # ---------- API 路由 ---------- class CreateJobReq(BaseModel): url: str class TranslateReq(BaseModel): text: str target: Literal["en", "zh"] = "en" @app.post("/translate") def translate_text(req: TranslateReq) -> dict: """单条文本翻译(给生图自定义提取元素 zh→en 用)""" import re as _re text = req.text.strip() if not text: return {"text": ""} if not LLM_API_KEY: raise HTTPException(503, "LLM_API_KEY 未配置") target_label = "English" if req.target == "en" else "Simplified Chinese" prompt = ( f"Translate the following text into concise {target_label}, suitable as an element " "label in an image-generation prompt. 
Output only the translation itself — no quotes, " "no punctuation, no explanation, no markdown.\n\n" f"Input: {text}" ) try: resp = llm().chat.completions.create( model=TRANSLATE_MODEL, messages=[{"role": "user", "content": prompt}], temperature=0.2, max_tokens=200, ) out = (resp.choices[0].message.content or "").strip() if not out: rc = getattr(resp.choices[0].message, "reasoning_content", "") or "" if rc: out = rc.strip().splitlines()[-1].strip() out = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip() return {"text": out} except Exception as e: raise HTTPException(500, f"translate failed: {e}") @app.get("/health") def health() -> dict: return { "ok": True, "llm_configured": bool(LLM_API_KEY), "base_url": LLM_BASE_URL or "openai-default", "models": { "asr": ASR_MODEL, "translate": TRANSLATE_MODEL, "rewrite": REWRITE_MODEL, "video": VIDEO_MODEL, "video_aliases": VIDEO_MODEL_ALIASES, "video_provider": "poe" if video_uses_poe() else ("ark" if video_uses_ark() else "custom"), "video_base_url": video_api_base(), "video_configured": bool(video_api_key()), "video_create_paths": VIDEO_CREATE_PATHS, }, } class JobSummary(BaseModel): id: str url: str status: JobStatus progress: int = 0 message: str = "" duration: float = 0.0 width: int = 0 height: int = 0 video_url: str = "" frame_count: int = 0 video_count: int = 0 thumbnail: str = "" error: str = "" mtime: float = 0.0 @app.get("/jobs", response_model=list[JobSummary]) def list_jobs(limit: int | None = None) -> list[JobSummary]: """所有 job 的精简列表,按磁盘 state.json mtime 倒序(最新优先)。前端无 ?job= 时用它回填历史。""" items: list[JobSummary] = [] for job_id, job in JOBS.items(): state_path = JOBS_DIR / job_id / "state.json" mtime = state_path.stat().st_mtime if state_path.exists() else 0.0 thumb = f"/jobs/{job_id}/frames/{job.frames[0].index}.jpg" if job.frames else "" items.append(JobSummary( id=job.id, url=job.url, status=job.status, progress=job.progress, message=job.message, duration=job.duration, width=job.width, height=job.height, 
video_url=job.video_url, frame_count=len(job.frames), video_count=len(job.generated_videos), thumbnail=thumb, error=job.error, mtime=mtime, )) items.sort(key=lambda s: s.mtime, reverse=True) if limit is not None and limit > 0: items = items[:limit] return items @app.post("/jobs", response_model=Job) async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job: if not req.url.strip(): raise HTTPException(400, "url required") job_id = uuid.uuid4().hex[:12] job = Job(id=job_id, url=req.url.strip()) JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download, job_id) return job @app.post("/jobs/upload", response_model=Job) async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job: if not file.filename: raise HTTPException(400, "file required") ext = Path(file.filename).suffix.lower() if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}: raise HTTPException(400, f"unsupported video format: {ext}") job_id = uuid.uuid4().hex[:12] d = job_dir(job_id) mp4 = d / "source.mp4" with mp4.open("wb") as f: while chunk := await file.read(1024 * 1024): f.write(chunk) if not mp4.exists() or mp4.stat().st_size == 0: raise HTTPException(500, "upload failed") job = Job(id=job_id, url=f"upload://{file.filename}") JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download, job_id) return job @app.post("/jobs/{job_id}/analyze", response_model=Job) async def trigger_analyze( job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT, target: FrameExtractTarget = "balanced", ) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if job.status not in {"downloaded", "frames_extracted", "transcribed", "failed"}: raise HTTPException(409, f"status must be downloaded/failed, got {job.status}") bg.add_task(pipeline_analyze, job_id, frames, target) return job @app.post("/jobs/{job_id}/frames", response_model=Job) def add_manual_frame(job_id: str, t: float) -> Job: """从指定时间戳手动抽 1 帧追加到 job.frames""" job = 
JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if not job.video_url: raise HTTPException(400, "video not ready") d = job_dir(job_id) mp4 = d / "source.mp4" if not mp4.exists(): raise HTTPException(400, "source.mp4 missing") frames_dir = d / "frames" frames_dir.mkdir(parents=True, exist_ok=True) # 新 index:max(existing)+1(即使列表已按 ts 排序,文件名用 index 保持稳定) next_idx = max((f.index for f in job.frames), default=-1) + 1 out = frames_dir / f"{next_idx:03d}.jpg" try: run([ "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4), "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3", str(out), ]) except RuntimeError as e: raise HTTPException(500, f"ffmpeg failed: {e}") new_frame = KeyFrame( index=next_idx, timestamp=round(float(t), 2), url=f"/jobs/{job_id}/frames/{next_idx}.jpg", ) merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp) update(job, frames=merged, message=f"已手动加帧({t:.1f}s),共 {len(merged)} 张") return job @app.get("/jobs/{job_id}", response_model=Job) def get_job(job_id: str) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") return job @app.delete("/jobs/{job_id}") def delete_job(job_id: str) -> dict[str, bool | str]: d = (JOBS_DIR / job_id).resolve() if JOBS_DIR not in d.parents: raise HTTPException(400, "invalid job id") job = JOBS.pop(job_id, None) if not job and not d.exists(): raise HTTPException(404, "job not found") if d.exists(): shutil.rmtree(d) return {"ok": True, "id": job_id} @app.post("/jobs/{job_id}/transcribe", response_model=Job) async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if job.status != "frames_extracted": raise HTTPException(409, f"status must be frames_extracted, got {job.status}") bg.add_task(pipeline_transcribe, job_id) return job @app.get("/jobs/{job_id}/video.mp4") def get_video(job_id: str): p = job_dir(job_id) / "source.mp4" if not p.exists(): raise HTTPException(404, 
"video not found") return FileResponse(p, media_type="video/mp4") @app.get("/jobs/{job_id}/frames/{idx}.jpg") def get_frame(job_id: str, idx: int): p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg" if not p.exists(): raise HTTPException(404, "frame not found") return FileResponse(p, media_type="image/jpeg") class GenerateReq(BaseModel): prompt: str extra_prompt: str = "" # ✓ 需要的元素(正向) negative_prompt: str = "" # ✗ 不需要的元素(负向) model: str = "" # 留空用 IMAGE_MODEL 默认 mode: str = "edit" # "edit" 带参考图,"text" 纯文字 from_selected: bool = False # True 时优先用 frame.selected 的生成图作 reference(迭代),否则原关键帧 @app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job) def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job: """根据关键帧 + prompt 生成新图(image-to-image 或 text-to-image)""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") frame = next((f for f in job.frames if f.index == idx), None) if not frame: raise HTTPException(404, "frame not found") frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg" if not frame_path.exists(): raise HTTPException(404, "frame file missing") # 决定 i2i 参考图:from_selected=True 且存在 selected 生成图 → 用它(迭代);否则原关键帧 reference_path = frame_path reference_source = "keyframe" if req.from_selected: sel = next((g for g in frame.generated_images if g.selected), None) if sel: sel_path = job_dir(job_id) / "gen" / f"{idx:03d}_{sel.id}.jpg" if sel_path.exists(): reference_path = sel_path reference_source = f"gen:{sel.id[:6]}" full_prompt = req.prompt.strip() if req.extra_prompt.strip(): full_prompt = f"{full_prompt}. Include: {req.extra_prompt.strip()}" if req.negative_prompt.strip(): full_prompt = f"{full_prompt}. 
Avoid: {req.negative_prompt.strip()}" if not full_prompt: raise HTTPException(400, "prompt required") model = req.model or IMAGE_MODEL gen_id = uuid.uuid4().hex[:12] import base64 as b64lib import time as _time import httpx img_b64: str | None = None if req.mode == "edit": img_b64 = b64lib.b64encode(reference_path.read_bytes()).decode("ascii") # 尝试 i2i 最多 3 次,全失败时降级 text-only 再试 1 次 plan: list[str] = ([req.mode] * 3) if req.mode == "edit" else [req.mode] if req.mode == "edit": plan.append("text") # i2i 都失败时自动降级 resp_data: dict = {} last_err = "" effective_mode = req.mode for attempt, current_mode in enumerate(plan): try: if current_mode == "edit": data_uri = f"data:image/jpeg;base64,{img_b64}" # OpenAI SDK 不直接支持 image 参数,用底层 httpx with httpx.Client(timeout=120) as client: r = client.post( f"{LLM_BASE_URL}/images/generations", headers={ "Authorization": f"Bearer {LLM_API_KEY}", "Content-Type": "application/json", }, json={ "model": model, "prompt": full_prompt, "image": data_uri, "n": 1, }, ) r.raise_for_status() resp_data = r.json() else: # text-only resp = llm().images.generate(model=model, prompt=full_prompt, n=1) resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]} if resp_data.get("data"): effective_mode = current_mode break err_obj = resp_data.get("error") or {} last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]}" except httpx.HTTPStatusError as e: body = e.response.text transient = ( e.response.status_code >= 500 or "incomplete_generation" in body or "rate_limit" in body or "timeout" in body.lower() ) last_err = f"HTTP {e.response.status_code}: {body[:200]}" if not transient: raise HTTPException(500, f"image gen HTTP {e.response.status_code}: {body[:300]}") except Exception as e: last_err = f"{type(e).__name__}: {e}" if attempt < len(plan) - 1: next_mode = plan[attempt + 1] tag = f"fallback → {next_mode}" if next_mode != current_mode else f"retry {attempt + 
1}/{len(plan)}" print(f"[image gen {tag}] {last_err}", flush=True) _time.sleep(1.5 * (attempt + 1)) data_arr = resp_data.get("data", []) if not data_arr: raise HTTPException(500, f"image gen failed after {len(plan)} attempts: {last_err}") item = data_arr[0] b64 = item.get("b64_json") if not b64: raise HTTPException(500, "image gen returned no b64_json") # 保存到本地 jobs//gen/_.jpg gen_dir = job_dir(job_id) / "gen" gen_dir.mkdir(parents=True, exist_ok=True) out_path = gen_dir / f"{idx:03d}_{gen_id}.jpg" out_path.write_bytes(b64lib.b64decode(b64)) new_gen = GeneratedImage( id=gen_id, prompt=full_prompt, model=model, mode=effective_mode, url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg", selected=False, created_at=_time.time(), ) # 写回 job.frames for f in job.frames: if f.index == idx: f.generated_images = f.generated_images + [new_gen] update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}") return job @app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg") def get_generated_image(job_id: str, idx: int, gen_id: str): p = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg" if not p.exists(): raise HTTPException(404, "generated image not found") return FileResponse(p, media_type="image/jpeg") class SelectGenReq(BaseModel): selected: bool @app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job) def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") for f in job.frames: if f.index != idx: continue for g in f.generated_images: # 单选:该帧只能选一张 if g.id == gen_id: g.selected = req.selected else: g.selected = False break update(job, frames=job.frames) return job @app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job) def describe_frame(job_id: str, idx: int) -> Job: """调 vision 模型识别该关键帧,返回结构化描述。""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") frame = next((f for f in job.frames if f.index == idx), 
None)
    if not frame:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame file not found")
    import base64 as b64lib
    import re as _re
    img_b64 = b64lib.b64encode(p.read_bytes()).decode("ascii")
    # Instruction prompt (Chinese) asking for strict JSON: scene, 3-8 extractable objects
    # (each with an English extract_prompt for later image edits), style, and a
    # suggested English prompt for downstream generation.
    prompt = (
        "请识别这张图,输出严格 JSON(不要 markdown 不要解释,不要思考):\n"
        '{\n'
        ' "scene": "一句话描述场景",\n'
        ' "objects": [{"name": "物体名(中文)", "position": "在画面哪里", "color": "颜色", "extract_prompt": "用于提取该元素的英文 prompt"}],\n'
        ' "style": "整体风格 / 打光 / 色调(一句话)",\n'
        ' "suggested_prompt": "适合用作下游生图的完整英文 prompt"\n'
        '}\n'
        "要求:objects 列出 3-8 个画面里**可独立提取**的主要元素,extract_prompt 用于后续 image edit 模型。"
    )
    last_err = ""
    data = None
    for attempt in range(3):
        try:
            resp = llm().chat.completions.create(
                model=VISION_MODEL,
                messages=[{"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                ]}],
                response_format={"type": "json_object"},
                temperature=0.3,
                max_tokens=3000,
            )
            content = (resp.choices[0].message.content or "").strip()
            if not content:
                # Thinking models may return empty content; try to dig a JSON object
                # out of reasoning_content instead.
                rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
                m = _re.search(r"\{[\s\S]*\}", rc)
                content = m.group(0) if m else ""
            # Strip a surrounding ```json ... ``` fence.
            content = _re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()
            if not content:
                last_err = f"empty content (attempt {attempt + 1})"
                continue
            data = json.loads(content)
            break
        except json.JSONDecodeError as e:
            last_err = f"json decode (attempt {attempt + 1}): {e} · raw[:200]={content[:200]}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        except Exception as e:
            last_err = f"vision call (attempt {attempt + 1}): {e}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
    if data is None:
        raise HTTPException(500, last_err or "vision failed after 3 retries")
    # Write the description back into the job.
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.description = data
        new_frames.append(f)
    update(job, frames=new_frames, message=f"识别完成 · 分镜 {idx + 1}")
    return job


# ---------- Watermark cleanup / element extraction (second-stage keyframe processing) ----------

class CleanupReq(BaseModel):
    # Optional relative-coordinate (0-1) rectangles limiting where cleanup applies;
    # empty / None = clean the whole image.
    regions: list[dict] | None = None  # [{"x","y","w","h"}, ...]


def _region_to_phrase(r: dict) -> str:
    """Turn a relative-coordinate rectangle into a short position phrase for the prompt.

    Coarse words are used instead of percentages / brackets, which the original
    author found could make the model misbehave. The rectangle is clamped to the
    unit square; a zero-area rectangle yields "". Otherwise the rectangle's
    centre is bucketed per axis (< 0.4 / 0.4-0.6 / > 0.6) and the result is
    "center", a single word ("top", "left", ...), or "<vertical> <horizontal>"
    such as "top left".
    """
    x = max(0.0, min(1.0, float(r.get("x", 0))))
    y = max(0.0, min(1.0, float(r.get("y", 0))))
    w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
    h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
    if w <= 0 or h <= 0:
        return ""
    cx, cy = x + w / 2, y + h / 2
    hpos = "left" if cx < 0.4 else "right" if cx > 0.6 else "middle"
    vpos = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
    if hpos == "middle" and vpos == "middle":
        return "center"
    if hpos == "middle":
        return vpos
    if vpos == "middle":
        return hpos
    return f"{vpos} {hpos}"


@app.post("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
    """Clean a keyframe with the image-edit model: remove watermarks / @usernames / captions / platform logos.

    The clean version is written to cleaned/<idx:03d>.jpg in the job dir and
    exposed as ``frame.cleaned_url``; ``frame.cleaned_applied`` is reset to False.
    Optional ``req.regions`` restricts cleanup to the boxed areas.

    Raises:
        HTTPException 404: unknown job / frame, or keyframe file missing.
        HTTPException 500: the image-edit call failed.
    """
    import time as _time
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")
    region_phrases: list[str] = []
    if req and req.regions:
        for r in req.regions:
            p = _region_to_phrase(r)
            if p:
                region_phrases.append(p)
        # De-duplicate while keeping first-seen order.
        region_phrases = list(dict.fromkeys(region_phrases))
    # NOTE(review): if every supplied region is degenerate, region_phrases is empty and
    # the whole-image prompt below is used.
    # The prompt uses "recreate a clean copy" wording rather than "erase / remove only X"
    # so Gemini does not take the mask/inpainting function-call path (the author observed
    # that path always triggering incomplete_generation on the SKG gateway).
    if region_phrases:
        if len(region_phrases) == 1:
            zones = f"the {region_phrases[0]} area"
        else:
            zones = ", ".join(region_phrases) + " areas"
        prompt = (
            f"Recreate this image as a clean version: remove the text and graphics in {zones}, "
            "keep the rest of the scene identical."
        )
    else:
        prompt = (
            "Recreate this image as a clean version without watermarks, captions, "
            "hashtags, usernames, or platform logos. Keep the composition and style."
        )
    # Model rotation: if nano-banana-pro fails, move on to the flash models
    # (presumably tried in list order by _image_edit_call — confirm).
    models = [
        IMAGE_MODEL,  # default: gemini-3-pro-image-preview (nano-banana-pro)
        "gemini-3.1-flash-image-preview",
        "gemini-2.5-flash-image",
    ]
    try:
        img_bytes, _mode = _image_edit_call(
            frame_path, prompt,
            models=models,
            fallback_text=False,
            max_attempts=3,
        )
    except RuntimeError as e:
        raise HTTPException(500, f"cleanup failed: {e}")
    out_dir = job_dir(job_id) / "cleaned"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{idx:03d}.jpg"
    out_path.write_bytes(img_bytes)
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            # ?t=<unix time> changes the URL on every re-clean even though the file
            # name is reused (presumably to defeat browser caching).
            f.cleaned_url = f"/jobs/{job_id}/frames/{idx}/cleaned.jpg?t={int(_time.time())}"
            f.cleaned_applied = False  # re-cleaning resets the "applied" state
        new_frames.append(f)
    update(job, frames=new_frames, message=f"清洗完成 · 分镜 {idx + 1}")
    return job


@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve the pending cleaned version of keyframe `idx`, or 404 if none exists."""
    p = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "cleaned frame not found")
    return FileResponse(p,
media_type="image/jpeg")


@app.delete("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def discard_cleaned(job_id: str, idx: int) -> Job:
    """Discard the pending cleaned version (an already-applied one is unaffected).

    Deletes cleaned/<idx:03d>.jpg (best-effort) and clears ``frame.cleaned_url``.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if p.exists():
        try:
            p.unlink()
        except OSError:
            pass  # best-effort; cleaned_url is cleared below regardless
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
        new_frames.append(f)
    update(job, frames=new_frames, message=f"丢弃清洗版 · 分镜 {idx + 1}")
    return job


@app.post("/jobs/{job_id}/frames/{idx}/cleanup/apply", response_model=Job)
def apply_cleaned(job_id: str, idx: int) -> Job:
    """Replace the original keyframe with its cleaned version.

    Physically overwrites frames/<idx>.jpg with cleaned/<idx>.jpg. The original is
    backed up to orig/<idx>.jpg on the first replacement only (later replacements
    skip the backup). Afterwards the cleaned file is removed, ``frame.cleaned_url``
    is cleared (no pending cleaned version remains) and ``frame.cleaned_applied``
    is set to True.

    Raises:
        HTTPException 404: unknown job / frame, or no cleaned version to apply.
    """
    import shutil as _shutil
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    cleaned_path = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if not cleaned_path.exists():
        raise HTTPException(404, "no cleaned version to apply")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    # First replacement: back the original up to orig/<idx>.jpg.
    orig_dir = job_dir(job_id) / "orig"
    orig_dir.mkdir(parents=True, exist_ok=True)
    orig_backup = orig_dir / f"{idx:03d}.jpg"
    if not orig_backup.exists() and frame_path.exists():
        _shutil.copy2(frame_path, orig_backup)
    # Overwrite frames/ with the cleaned version.
    _shutil.copy2(cleaned_path, frame_path)
    # Delete the cleaned file: it has been applied and is no longer a separate candidate.
    try:
        cleaned_path.unlink()
    except OSError:
        pass
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
            f.cleaned_applied = True
        new_frames.append(f)
    update(job, frames=new_frames, message=f"已替换分镜 {idx + 1} 为清洗版")
    return job


# Request body for adding an element to a frame.
class AddElementReq(BaseModel):
    name_zh: str
    name_en: str = ""  # empty → add_element tries a zh→en translation
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    region: dict | None = None  # relative box; cutout_element reads its x / y / w / h keys


# Partial update of an element; fields left as None are not changed.
class UpdateElementReq(BaseModel):
    name_zh: str | None = None
    name_en: str | None = None
    position: str | None = None


@app.post("/jobs/{job_id}/frames/{idx}/elements", response_model=Job)
def add_element(job_id: str, idx: int, req: AddElementReq) -> Job:
    """Add one element to frame `idx`; if name_en is missing, translate name_zh to English.

    Translation runs only when LLM_API_KEY is set and is best-effort: any failure
    is logged and leaves name_en empty instead of failing the request.

    Raises:
        HTTPException 404: unknown job / frame.
        HTTPException 400: name_zh is blank.
    """
    import time as _time
    import re as _re
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    name_zh = req.name_zh.strip()
    if not name_zh:
        raise HTTPException(400, "name_zh required")
    name_en = req.name_en.strip()
    if not name_en and LLM_API_KEY:
        try:
            prompt = (
                "Translate the following text into concise English, suitable as an element label "
                "in an image-generation prompt. Output only the translation — no quotes, no punctuation, "
                f"no explanation.\n\nInput: {name_zh}"
            )
            resp = llm().chat.completions.create(
                model=TRANSLATE_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                max_tokens=200,
            )
            out = (resp.choices[0].message.content or "").strip()
            if not out:
                # Thinking models may leave content empty; fall back to the last line
                # of reasoning_content.
                rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
                if rc:
                    out = rc.strip().splitlines()[-1].strip()
            # Strip leading / trailing ASCII and CJK quote marks.
            name_en = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip()
        except Exception as e:
            print(f"[add_element translate failed] {e}", flush=True)
            name_en = ""
    el = KeyElement(
        id=uuid.uuid4().hex[:8],
        name_zh=name_zh,
        name_en=name_en,
        position=req.position.strip(),
        source=req.source,
        region=req.region,
        created_at=_time.time(),
    )
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.elements = f.elements + [el]
        new_frames.append(f)
    update(job, frames=new_frames, message=f"加入元素 · 分镜 {idx + 1} · {name_zh}")
    return job


@app.patch("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def update_element(job_id: str, idx: int, element_id: str, req:
UpdateElementReq) -> Job: """更新元素标签 / 英文提示。提取不准时允许用户修正,不强制重建元素。""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") changed_name = "" found = False new_frames = [] for f in job.frames: if f.index == idx: for e in f.elements: if e.id == element_id: found = True if req.name_zh is not None: name_zh = req.name_zh.strip() if not name_zh: raise HTTPException(400, "name_zh required") e.name_zh = name_zh changed_name = name_zh if req.name_en is not None: e.name_en = req.name_en.strip() if req.position is not None: e.position = req.position.strip() new_frames.append(f) if not found: raise HTTPException(404, "element not found") update(job, frames=new_frames, message=f"更新元素 · 分镜 {idx + 1} · {changed_name or element_id}") return job @app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job) def delete_element(job_id: str, idx: int, element_id: str) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") new_frames = [] removed = False for f in job.frames: if f.index == idx: before = len(f.elements) f.elements = [e for e in f.elements if e.id != element_id] removed = len(f.elements) < before # 若有提取图也删(含多版本) if removed: elements_dir = job_dir(job_id) / "elements" if elements_dir.exists(): for pat in (f"{idx:03d}_{element_id}.jpg", f"{idx:03d}_{element_id}.png", f"{idx:03d}_{element_id}_*.jpg"): for p in elements_dir.glob(pat): try: p.unlink() except OSError: pass new_frames.append(f) if not removed: raise HTTPException(404, "element not found") update(job, frames=new_frames, message=f"删除元素 · 分镜 {idx + 1}") return job @app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job) def cutout_element(job_id: str, idx: int, element_id: str) -> Job: """AI 提取元素 · 每次累积一张新图: 调 nano-banana 模型生成**完整、清晰**的元素图(即使原图只露出部分也补全)。 region 元素:先把 region + 30% padding 区域裁出作为 focus,再发给模型聚焦补全。""" from PIL import Image as _PILImage import io as _io import tempfile as _tempfile job = 
JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") frame = next((f for f in job.frames if f.index == idx), None) if not frame: raise HTTPException(404, "frame not found") el = next((e for e in frame.elements if e.id == element_id), None) if not el: raise HTTPException(404, "element not found") cleaned_path = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg" src = cleaned_path if cleaned_path.exists() else job_dir(job_id) / "frames" / f"{idx:03d}.jpg" if not src.exists(): raise HTTPException(404, "source frame file missing") out_dir = job_dir(job_id) / "elements" out_dir.mkdir(parents=True, exist_ok=True) new_cutout_id = uuid.uuid4().hex[:8] out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg" # region 元素:先 PIL 裁出 region + 30% padding 作为 focus 给模型(让它聚焦在该元素) tmp_focus: Path | None = None model_src = src if el.region: try: im = _PILImage.open(src).convert("RGB") W, H = im.size r = el.region x = max(0.0, min(1.0, float(r.get("x", 0)))) y = max(0.0, min(1.0, float(r.get("y", 0)))) w = max(0.0, min(1.0 - x, float(r.get("w", 0)))) h = max(0.0, min(1.0 - y, float(r.get("h", 0)))) cx, cy = x + w / 2, y + h / 2 # 扩大 30% 给上下文(避免裁到正好边界丢失补全 hint) ew, eh = w * 1.6, h * 1.6 x0 = max(0.0, cx - ew / 2); y0 = max(0.0, cy - eh / 2) x1 = min(1.0, cx + ew / 2); y1 = min(1.0, cy + eh / 2) left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H) if right - left > 8 and bottom - top > 8: cropped = im.crop((left, top, right, bottom)) tmp = _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) cropped.save(tmp.name, format="JPEG", quality=92) tmp.close() tmp_focus = Path(tmp.name) model_src = tmp_focus except Exception as e: print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True) target = (el.name_en or el.name_zh).strip() prompt = ( f"Identify the {target} in this image. " f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. 
" f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), " "intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. " "Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. " "Preserve the element's original color palette, style, lighting character, and proportions. " "Output must be a clean, high-quality asset image suitable for downstream composition." ) models = [IMAGE_MODEL, "gemini-2.5-flash-image"] img_bytes: bytes try: try: img_bytes, _mode = _image_edit_call( model_src, prompt, models=models, fallback_text=False, max_attempts=3, ) except RuntimeError as e: raise HTTPException(500, f"extract failed: {e}") finally: if tmp_focus and tmp_focus.exists(): try: tmp_focus.unlink() except OSError: pass out_path.write_bytes(img_bytes) new_frames = [] for f in job.frames: if f.index == idx: for e in f.elements: if e.id == element_id: e.cutouts = (e.cutouts or []) + [new_cutout_id] if not e.cutout_id: e.cutout_id = new_cutout_id new_frames.append(f) update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}") return job @app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}", response_model=Job) def delete_cutout(job_id: str, idx: int, element_id: str, cutout_id: str) -> Job: """删除该元素的某张提取图""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg" if p.exists(): try: p.unlink() except OSError: pass removed = False new_frames = [] for f in job.frames: if f.index == idx: for e in f.elements: if e.id == element_id: if cutout_id in (e.cutouts or []): e.cutouts = [c for c in e.cutouts if c != cutout_id] removed = True # cutout_id 兼容字段:若指向被删的就清空 / 移到 cutouts 第一个 if e.cutout_id == cutout_id: e.cutout_id = e.cutouts[0] if e.cutouts else None 
new_frames.append(f) if not removed: raise HTTPException(404, "cutout not found in element") update(job, frames=new_frames, message=f"删除提取图") return job class UpdateStoryboardReq(BaseModel): duration: float = 0 subject_image: dict | None = None scene_image: dict | None = None product_image: dict | None = None action_image: dict | None = None # v1 字段(前端可不传) subject: str = "" product: str = "" scene: str = "" action: str = "" reference_ids: list[str] = [] class GenerateStoryboardVideoReq(BaseModel): prompt: str duration: float = 4 first_image: dict | None = None last_image: dict | None = None product_images: list[dict] = Field(default_factory=list) subject_image: dict | None = None scene_image: dict | None = None product_image: dict | None = None action_image: dict | None = None source_ref: VideoSourceRef | None = None model: str = "" size: str = "720x1280" def video_seconds(duration: float) -> str: if video_uses_ark(): if duration <= 0: return "5" return str(max(4, min(15, round(duration)))) if duration <= 6: return "4" if duration <= 10: return "8" return "12" def resolve_video_model(raw: str | None) -> str: requested = (raw or VIDEO_MODEL or "seedance").strip() lowered = requested.lower() if lowered in {"sora", "sora-2", "sora_2"}: raise HTTPException(400, "Sora 已停用,请选择 Seedance / Kling / Veo 3") return VIDEO_MODEL_ALIASES.get(lowered, requested) def normalize_video_status(status: str | None) -> Literal["queued", "in_progress", "completed", "failed"]: s = (status or "queued").lower() if s in {"completed", "complete", "succeeded", "success", "done"}: return "completed" if s in {"failed", "failure", "error", "cancelled", "canceled", "expired"}: return "failed" if s in {"running", "processing", "in_progress", "generating", "started"}: return "in_progress" return "queued" def video_progress(data: dict, fallback: int) -> int: raw = data.get("progress", data.get("percentage", data.get("percent", fallback))) try: value = int(float(raw)) except Exception: value = fallback 
return max(0, min(100, value)) def video_url_from_response(data: dict) -> str: for key in ("url", "video_url", "output_url", "download_url"): v = data.get(key) if isinstance(v, str) and v: return v arr = data.get("data") if isinstance(arr, list) and arr: first = arr[0] if isinstance(first, dict): for key in ("url", "video_url", "output_url", "download_url"): v = first.get(key) if isinstance(v, str) and v: return v output = data.get("output") if isinstance(output, dict): for key in ("url", "video_url", "download_url"): v = output.get(key) if isinstance(v, str) and v: return v content = data.get("content") if isinstance(content, dict): for key in ("video_url", "url", "download_url", "file_url"): v = content.get(key) if isinstance(v, str) and v: return v return "" def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path) -> None: if direct_url: url = direct_url if direct_url.startswith("http") else f"{base}{direct_url if direct_url.startswith('/') else '/' + direct_url}" r = client.get(url, headers=headers if url.startswith(base) else None) else: r = client.get(f"{base}{video_path(VIDEO_CONTENT_PATH, id=provider_id)}", headers=headers) r.raise_for_status() out_mp4.write_bytes(r.content) def size_to_video_ratio(size: str) -> str: try: w, h = [int(x) for x in size.lower().replace(" ", "").split("x", 1)] except Exception: return "9:16" if w <= 0 or h <= 0: return "9:16" ratio = w / h known = { "16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1, "4:3": 4 / 3, "3:4": 3 / 4, "21:9": 21 / 9, } return min(known, key=lambda key: abs(known[key] - ratio)) def ark_reference_data_url(ref_img: Path) -> str: mime = "image/png" if ref_img.suffix.lower() == ".png" else "image/jpeg" return f"data:{mime};base64,{base64.b64encode(ref_img.read_bytes()).decode('ascii')}" def submit_video_create( client, url: str, headers: dict, ref_img: Path, payload: dict, source_ref: VideoSourceRef | None = None, last_img: Path | None = None, product_imgs: 
list[Path] | None = None,
):
    """Submit one video-creation request in the configured provider's wire format.

    - video_uses_ark(): JSON body with a ``content`` list holding the prompt text, an
      optional reference video (only for a "source_video" ref that has a URL),
      ``ref_img`` as first_frame, optional ``last_img`` as last_frame and up to 6
      product images as reference_image, all inlined as data URLs. The ratio is
      derived from payload["size"]; watermark is off and resolution is "720p".
    - video_uses_poe(): ``payload`` as JSON with the duration field coerced to int
      and ``ref_img`` added as base64 ``input_image``.
    - otherwise: multipart form with ``payload`` as fields and ``ref_img`` uploaded
      as ``input_reference``.

    Returns the raw ``client.post`` response without checking its status.
    """
    if video_uses_ark():
        content = [{"type": "text", "text": payload["prompt"]}]
        if source_ref and source_ref.kind == "source_video" and source_ref.url:
            content.append(
                {
                    "type": "video_url",
                    "video_url": {"url": source_ref.url},
                    "role": "reference_video",
                }
            )
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": ark_reference_data_url(ref_img)},
                "role": "first_frame",
            }
        )
        if last_img and last_img.exists():
            content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": ark_reference_data_url(last_img)},
                    "role": "last_frame",
                }
            )
        # At most 6 product reference images; files that no longer exist are skipped.
        for product_img in (product_imgs or [])[:6]:
            if product_img.exists():
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": ark_reference_data_url(product_img)},
                        "role": "reference_image",
                    }
                )
        data = {
            "model": payload["model"],
            "content": content,
            "ratio": size_to_video_ratio(str(payload.get("size", ""))),
            "duration": int(float(str(payload.get(VIDEO_DURATION_FIELD, 5)))),
            "watermark": False,
            "resolution": "720p",
        }
        return client.post(url, headers={**headers, "Content-Type": "application/json"}, json=data)
    if video_uses_poe():
        data = dict(payload)
        data[VIDEO_DURATION_FIELD] = int(float(str(data.get(VIDEO_DURATION_FIELD, 4))))
        data["input_image"] = base64.b64encode(ref_img.read_bytes()).decode("ascii")
        return client.post(url, headers=headers, json=data)
    with ref_img.open("rb") as fh:
        return client.post(
            url,
            headers=headers,
            data=payload,
            files={"input_reference": ("reference.jpg", fh, "image/jpeg")},
        )


def render_storyboard_video(
    job_id: str,
    local_id: str,
    provider_id: str,
    ref_path: Path,
    prompt: str,
    model: str,
    seconds: str,
    size: str,
    source_ref: VideoSourceRef | None = None,
    last_ref_path: Path | None = None,
    product_ref_paths: list[Path] | None = None,
) -> None:
    """Background task: create a storyboard video, poll until it finishes, download it.

    Reference images are prepared into storyboard_videos/<local_id>/ in the job dir.
    Progress, the final URL or the error are reported via update_generated_video
    for ``local_id``; any exception is caught and recorded as status "failed".
    """
    import httpx
    out_dir = job_dir(job_id) / "storyboard_videos" / local_id
    ref_img = out_dir / "reference.jpg"
    last_img = out_dir / "last_reference.jpg"
    out_mp4 = out_dir / "video.mp4"
    base = video_api_base()
    headers = {"Authorization":
f"Bearer {video_api_key()}"} try: prepare_video_reference(ref_path, ref_img) prepared_last_img: Path | None = None if last_ref_path and last_ref_path.exists(): prepare_video_reference(last_ref_path, last_img) prepared_last_img = last_img prepared_product_imgs: list[Path] = [] for i, product_ref_path in enumerate((product_ref_paths or [])[:6], start=1): if product_ref_path.exists(): product_img = out_dir / f"product_reference_{i}.jpg" prepare_video_reference(product_ref_path, product_img) prepared_product_imgs.append(product_img) update_generated_video(job_id, local_id, status="in_progress", progress=5) with httpx.Client(timeout=120) as client: payload = {"model": model, "prompt": prompt, "size": size} payload[VIDEO_DURATION_FIELD] = seconds create = None create_errors: list[str] = [] for create_path in VIDEO_CREATE_PATHS: resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, source_ref, prepared_last_img, prepared_product_imgs) if video_uses_ark() and source_ref and resp.status_code in {400, 422}: create_errors.append(f"{video_path(create_path)} + reference_video -> HTTP {resp.status_code}: {resp.text[:160]}") resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, prepared_last_img, prepared_product_imgs) if video_uses_ark() and prepared_last_img and resp.status_code in {400, 422}: create_errors.append(f"{video_path(create_path)} + last_frame -> HTTP {resp.status_code}: {resp.text[:160]}") resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, None, prepared_product_imgs) if video_uses_ark() and prepared_product_imgs and resp.status_code in {400, 422}: create_errors.append(f"{video_path(create_path)} + product_reference -> HTTP {resp.status_code}: {resp.text[:160]}") resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, prepared_last_img, None) if resp.status_code < 400: 
create = resp break create_errors.append(f"{video_path(create_path)} -> HTTP {resp.status_code}: {resp.text[:160]}") if resp.status_code not in {400, 404, 405}: resp.raise_for_status() if create is None: raise RuntimeError("视频模型已选择,但当前网关视频生成入口不可用;已尝试 " + " | ".join(create_errors)) data = create.json() video_api_id = data.get("id") or provider_id or local_id status = normalize_video_status(data.get("status")) progress = video_progress(data, 5) direct_url = video_url_from_response(data) update_generated_video(job_id, local_id, provider_id=video_api_id, status=status, progress=progress) deadline = time.time() + 420 while status in {"queued", "in_progress"} and time.time() < deadline: time.sleep(8) poll = client.get(f"{base}{video_path(VIDEO_STATUS_PATH, id=video_api_id)}", headers=headers) poll.raise_for_status() pdata = poll.json() status = normalize_video_status(pdata.get("status")) progress = video_progress(pdata, progress) direct_url = video_url_from_response(pdata) or direct_url update_generated_video(job_id, local_id, status=status, progress=progress) if status != "completed": update_generated_video(job_id, local_id, status="failed", error=f"video status: {status}", progress=progress) return download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4) update_generated_video( job_id, local_id, status="completed", progress=100, url=f"/jobs/{job_id}/storyboard-videos/{local_id}.mp4", error="", ) except Exception as e: update_generated_video(job_id, local_id, status="failed", error=str(e)[:500]) @app.post("/jobs/{job_id}/frames/{idx}/storyboard/video", response_model=Job) def generate_storyboard_video(job_id: str, idx: int, req: GenerateStoryboardVideoReq, bg: BackgroundTasks) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") frame = next((f for f in job.frames if f.index == idx), None) if not frame: raise HTTPException(404, "frame not found") ensure_video_api_configured() prompt = req.prompt.strip() if not 
prompt: raise HTTPException(400, "prompt required") ref = req.first_image or req.subject_image or req.product_image or req.scene_image or req.action_image ref_path = storyboard_ref_path(job_id, ref) or (job_dir(job_id) / "frames" / f"{idx:03d}.jpg") if not ref_path.exists(): raise HTTPException(404, "reference image missing") poster = storyboard_ref_url(job_id, ref) or f"/jobs/{job_id}/frames/{idx}.jpg" last_ref_path = storyboard_ref_path(job_id, req.last_image) raw_product_refs = req.product_images[:6] if req.product_images else ([req.product_image] if req.product_image else []) product_ref_paths = [p for p in (storyboard_ref_path(job_id, r) for r in raw_product_refs) if p] local_id = uuid.uuid4().hex[:12] model = resolve_video_model(req.model) seconds = video_seconds(float(req.duration or 4)) item = GeneratedVideo( id=local_id, provider_id="", frame_idx=idx, prompt=prompt, model=model, status="queued", url="", poster_url=poster, duration=float(seconds), progress=0, created_at=time.time(), ) update(job, generated_videos=[item] + job.generated_videos, message=f"视频生成已提交 · 分镜 {idx + 1}") source_ref = req.source_ref if source_ref and source_ref.kind == "source_video" and not source_ref.url: source_ref = None bg.add_task(render_storyboard_video, job_id, local_id, "", ref_path, prompt, model, seconds, req.size, source_ref, last_ref_path, product_ref_paths) return job @app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4") def get_storyboard_video(job_id: str, video_id: str): p = job_dir(job_id) / "storyboard_videos" / video_id / "video.mp4" if not p.exists(): raise HTTPException(404, "storyboard video not found") return FileResponse(p, media_type="video/mp4") @app.post("/jobs/{job_id}/assets") async def upload_storyboard_asset(job_id: str, file: UploadFile = File(...)) -> dict: if job_id not in JOBS: raise HTTPException(404, "job not found") asset_id = uuid.uuid4().hex[:12] out_dir = job_dir(job_id) / "assets" out_dir.mkdir(parents=True, exist_ok=True) tmp = out_dir 
/ f"{asset_id}.upload" out = out_dir / f"{asset_id}.jpg" try: tmp.write_bytes(await file.read()) img = Image.open(tmp).convert("RGB") img.thumbnail((1600, 1600), Image.Resampling.LANCZOS) img.save(out, "JPEG", quality=94) except Exception as e: raise HTTPException(400, f"product image upload failed: {e}") finally: try: tmp.unlink() except Exception: pass return { "kind": "asset", "frame_idx": -1, "element_id": asset_id, "cutout_id": asset_id, "label": file.filename or "SKG 产品图", } @app.get("/jobs/{job_id}/assets/{asset_id}.jpg") def get_storyboard_asset(job_id: str, asset_id: str): p = job_dir(job_id) / "assets" / f"{asset_id}.jpg" if not p.exists(): raise HTTPException(404, "asset not found") return FileResponse(p, media_type="image/jpeg") @app.delete("/jobs/{job_id}/storyboard-videos/{video_id}", response_model=Job) def delete_storyboard_video(job_id: str, video_id: str) -> Job: """删除 Video Gen 节点里的一个视频任务(成功/失败/排队都可删)。""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") before = len(job.generated_videos) removed = next((v for v in job.generated_videos if v.id == video_id), None) kept = [v for v in job.generated_videos if v.id != video_id] if len(kept) == before: raise HTTPException(404, "generated video not found") out_dir = job_dir(job_id) / "storyboard_videos" / video_id if out_dir.exists(): try: shutil.rmtree(out_dir) except OSError: pass msg = f"删除视频任务 · 分镜 {removed.frame_idx + 1}" if removed else "删除视频任务" update(job, generated_videos=kept, message=msg) return job @app.put("/jobs/{job_id}/frames/{idx}/storyboard", response_model=Job) def update_storyboard(job_id: str, idx: int, req: UpdateStoryboardReq) -> Job: """更新分镜的编排字段(subject / product / scene / action / duration / reference_ids)""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") frame = next((f for f in job.frames if f.index == idx), None) if not frame: raise HTTPException(404, "frame not found") new_frames = [] for f in job.frames: if f.index == 
idx: f.storyboard = StoryboardScene( duration=max(0.0, float(req.duration)), subject_image=req.subject_image, scene_image=req.scene_image, product_image=req.product_image, action_image=req.action_image, subject=req.subject.strip(), product=req.product.strip(), scene=req.scene.strip(), action=req.action.strip(), reference_ids=list(req.reference_ids), ) new_frames.append(f) update(job, frames=new_frames, message=f"分镜 {idx + 1} 编排已更新") return job class PushStoryboardImageReq(BaseModel): kind: Literal["keyframe", "cutout"] frame_idx: int element_id: str | None = None cutout_id: str | None = None label: str = "" @app.post("/jobs/{job_id}/storyboard-images", response_model=Job) def push_storyboard_image(job_id: str, req: PushStoryboardImageReq) -> Job: """把一张图(关键帧本身或元素提取图)推送到分镜头编排区""" import time as _time job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") # 防重复推送:相同 frame_idx + element_id + cutout_id 已存在就跳过 for existing in job.storyboard_images: if (existing.kind == req.kind and existing.frame_idx == req.frame_idx and existing.element_id == req.element_id and existing.cutout_id == req.cutout_id): return job img = StoryboardImage( ref_id=uuid.uuid4().hex[:8], kind=req.kind, frame_idx=req.frame_idx, element_id=req.element_id, cutout_id=req.cutout_id, label=req.label.strip(), created_at=_time.time(), ) update(job, storyboard_images=job.storyboard_images + [img], message=f"上推到分镜头编排 · {req.label or req.kind}") return job @app.delete("/jobs/{job_id}/storyboard-images/{ref_id}", response_model=Job) def remove_storyboard_image(job_id: str, ref_id: str) -> Job: """从分镜头编排区移除一张图""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") before = len(job.storyboard_images) new_list = [x for x in job.storyboard_images if x.ref_id != ref_id] if len(new_list) == before: raise HTTPException(404, "storyboard image not found") update(job, storyboard_images=new_list, message="从分镜头编排移除一张图") return job 
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg")
def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str):
    """Serve one versioned element cutout: ``elements/{idx:03d}_{element_id}_{cutout_id}.jpg``.

    Raises:
        HTTPException: 404 when the cutout file does not exist.
    """
    p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    if not p.exists():
        raise HTTPException(404, "cutout not found")
    return FileResponse(p, media_type="image/jpeg")


@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
    """Legacy (v1, one cutout per element) path.

    Serves ``elements/{idx:03d}_{element_id}.jpg``, falling back to the
    ``.png`` variant.

    Raises:
        HTTPException: 404 when neither file exists.
    """
    elements_dir = job_dir(job_id) / "elements"
    jpg = elements_dir / f"{idx:03d}_{element_id}.jpg"
    if jpg.exists():
        return FileResponse(jpg, media_type="image/jpeg")
    png = elements_dir / f"{idx:03d}_{element_id}.png"
    if png.exists():
        # The legacy file is a PNG; label it as such so clients decode it correctly.
        return FileResponse(png, media_type="image/png")
    raise HTTPException(404, "cutout not found")


def _unlink_best_effort(*paths: Path) -> None:
    """Delete each path; files that are already gone or cannot be removed are skipped."""
    for p in paths:
        try:
            p.unlink(missing_ok=True)
        except OSError:
            pass


# ---------- Deletion: keyframes / individual generated images ----------
@app.delete("/jobs/{job_id}/frames/{idx}", response_model=Job)
def delete_frame(job_id: str, idx: int) -> Job:
    """Delete a whole keyframe and its files.

    Removes the original frame, the cleaned version, every element cutout
    (``.png``/``.jpg`` with prefix ``{idx:03d}_``) and every generated image
    for the frame, then drops the frame from the job.

    Raises:
        HTTPException: 404 for an unknown job or frame index.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")
    d = job_dir(job_id)
    prefix = f"{idx:03d}_"
    # File cleanup is best-effort: any of these may legitimately not exist.
    stale = [d / "frames" / f"{idx:03d}.jpg", d / "cleaned" / f"{idx:03d}.jpg"]
    elements_dir = d / "elements"
    if elements_dir.is_dir():
        # Materialize the globs before deleting, rather than unlinking mid-iteration.
        stale += [*elements_dir.glob(f"{prefix}*.png"), *elements_dir.glob(f"{prefix}*.jpg")]
    gen_dir = d / "gen"
    if gen_dir.is_dir():
        stale += list(gen_dir.glob(f"{prefix}*.jpg"))
    _unlink_best_effort(*stale)
    # NOTE(review): storyboard_images / generated_videos entries that reference
    # this frame_idx are not touched here; confirm that is intended.
    update(job, frames=[f for f in job.frames if f.index != idx], message=f"删除分镜 {idx + 1}")
    return job


@app.delete("/jobs/{job_id}/frames/{idx}/gen/{gen_id}", response_model=Job)
def delete_generated(job_id: str, idx: int, gen_id: str) -> Job:
    """Delete one generated image of a frame (list entry plus ``gen/{idx:03d}_{gen_id}.jpg``).

    The record is validated first, so a 404 leaves the filesystem untouched.

    Raises:
        HTTPException: 404 for an unknown job, frame index or generated image id.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    matching = [f for f in job.frames if f.index == idx]
    if not matching:
        raise HTTPException(404, "frame not found")
    found = False
    for f in matching:
        remaining = [g for g in f.generated_images if g.id != gen_id]
        if len(remaining) < len(f.generated_images):
            f.generated_images = remaining
            found = True
    if not found:
        raise HTTPException(404, "generated image not found")
    _unlink_best_effort(job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg")
    update(job, frames=list(job.frames), message=f"删除生成图 · 分镜 {idx + 1}")
    return job