2084 lines
78 KiB
Python
2084 lines
78 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import json
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import time
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Literal
|
||
|
||
from dotenv import load_dotenv
|
||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Pull variables from a local .env file (if present) before any os.getenv call below.
load_dotenv()

# Root directory holding one sub-directory per job; created at import time.
JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)

# Comma-separated list of allowed CORS origins; blank entries are dropped.
CORS_ORIGINS = [
    origin
    for origin in map(str.strip, os.getenv("CORS_ORIGINS", "http://localhost:4290").split(","))
    if origin
]

# OpenAI-compatible gateway used for ASR / chat / image calls.
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()

# Model ids per task, each overridable through the environment.
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
VISION_MODEL = os.getenv("VISION_MODEL", "gemini-2.5-flash")
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "gemini-3-pro-image-preview")
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"

# Poe API endpoint and key; a blank base URL falls back to the public default.
POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1"
POE_API_KEY = os.getenv("POE_API_KEY", "").strip()
|
||
|
||
|
||
def env_video_model(name: str, default: str) -> str:
    """Return the video model id configured in environment variable ``name``.

    ``default`` is returned when the variable is unset or blank, and also when
    it holds one of the legacy business aliases: older local envs stored those
    aliases as model ids, so they keep working by mapping to the concrete
    default model id.
    """
    configured = os.getenv(name, "").strip()
    legacy_aliases = ("seedance", "kling", "veo", "veo3", "voe")
    if not configured or configured.lower() in legacy_aliases:
        return default
    return configured
|
||
|
||
|
||
# Business alias -> concrete video model id. "veo3", "veo" and "voe" all read the
# same VIDEO_MODEL_VEO3 variable, so they always resolve to one shared value.
VIDEO_MODEL_ALIASES = {
    "seedance": env_video_model("VIDEO_MODEL_SEEDANCE", "seedance-2-fast"),
    "kling": env_video_model("VIDEO_MODEL_KLING", "kling-omni"),
    **dict.fromkeys(
        ("veo3", "veo", "voe"),
        env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
    ),
}
|
||
# Optional dedicated video API endpoint and key (blank means "not configured").
VIDEO_API_BASE_URL = os.getenv("VIDEO_API_BASE_URL", "").strip()
VIDEO_API_KEY = os.getenv("VIDEO_API_KEY", "").strip()

# Primary create path, plus the ordered list of create paths (comma-separated in
# the environment; blank entries are dropped).
VIDEO_CREATE_PATH = os.getenv("VIDEO_CREATE_PATH", "/videos").strip() or "/videos"
VIDEO_CREATE_PATHS = [
    path
    for path in map(
        str.strip,
        os.getenv(
            "VIDEO_CREATE_PATHS",
            f"{VIDEO_CREATE_PATH},/videos/generations,/video/generations",
        ).split(","),
    )
    if path
]

# Path templates; "{id}" is filled in by video_path().
VIDEO_STATUS_PATH = os.getenv("VIDEO_STATUS_PATH", "/videos/{id}").strip() or "/videos/{id}"
VIDEO_CONTENT_PATH = os.getenv("VIDEO_CONTENT_PATH", "/videos/{id}/content").strip() or "/videos/{id}/content"

# Name of the request field that carries the clip duration.
VIDEO_DURATION_FIELD = os.getenv("VIDEO_DURATION_FIELD", "seconds").strip() or "seconds"
|
||
|
||
# OpenAI 客户端(OpenAI 兼容网关,含 SKG ezlink)
|
||
from openai import OpenAI
|
||
# Lazily created, process-wide OpenAI client (see llm()).
_llm_client: OpenAI | None = None


def llm() -> OpenAI:
    """Return the shared OpenAI client, creating it on first use.

    A blank LLM_BASE_URL is passed as ``None`` so the client uses its own
    default endpoint.

    Raises:
        RuntimeError: if LLM_API_KEY is not configured.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY)
    return _llm_client
|
||
|
||
# Pipeline states:
#   created -> downloading -> downloaded (pauses until the user triggers analysis)
#   -> splitting -> frames_extracted -> transcribing -> transcribed | failed
JobStatus = Literal[
    "created",
    "downloading",
    "downloaded",
    "splitting",
    "frames_extracted",
    "transcribing",
    "transcribed",
    "failed",
]

# Default number of keyframes per analysis run (overridable via the environment).
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5"))
|
||
|
||
|
||
# One image generated for a keyframe.
class GeneratedImage(BaseModel):
    # 12-character uuid hex
    id: str
    prompt: str
    model: str
    # "edit" (with a reference image) | "text" (text only)
    mode: str = "edit"
    # /jobs/{job_id}/frames/{idx}/gen/{id}.jpg
    url: str
    selected: bool = False
    created_at: float = 0.0
|
||
|
||
|
||
# One video generation task and its result, stored on Job.generated_videos.
class GeneratedVideo(BaseModel):
    id: str
    provider_id: str = ""
    frame_idx: int
    prompt: str
    model: str = ""
    status: Literal["queued", "in_progress", "completed", "failed"] = "queued"
    url: str = ""
    poster_url: str = ""
    duration: float = 4.0
    progress: int = 0
    error: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class StoryboardScene(BaseModel):
    """Storyboard arrangement: each selected shot maps to one scene description.

    v2: four image slots plus a duration (copy/paste mode) - one image each for
    subject / scene / product / action.
    v1 fields are kept for compatibility
    (subject/product/scene/action/reference_ids).
    """

    duration: float = 0

    # Four image slots; each dict holds {kind, frame_idx, element_id?, cutout_id?, label}.
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None

    # v1 compatibility fields.
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    reference_ids: list[str] = []
|
||
|
||
|
||
class StoryboardImage(BaseModel):
    """An image the user has pushed up into the storyboard arrangement area."""

    # 8-character uuid hex
    ref_id: str
    # "keyframe" = the keyframe itself / "cutout" = an extracted element image
    kind: Literal["keyframe", "cutout"]
    frame_idx: int
    # Set when kind is "cutout".
    element_id: str | None = None
    # Set when kind is "cutout" (versioned id; legacy data may have it equal to element_id).
    cutout_id: str | None = None
    # Display name.
    label: str = ""
    created_at: float = 0.0
|
||
|
||
|
||
class KeyElement(BaseModel):
    """An element recognised in a keyframe or extracted by the user.

    Repeated extractions accumulate several images so the user can pick the
    one they are satisfied with.
    """

    # 8-character uuid hex
    id: str
    name_zh: str
    name_en: str = ""
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    region: dict | None = None

    # Ids of the extracted images (each cutout call appends a new id)
    # -> /jobs/.../elements/{element_id}/cutouts/{cutout_id}.jpg
    cutouts: list[str] = []

    # Legacy v1 single-image field: used as a rendering fallback, no longer
    # written by new extractions.
    cutout_id: str | None = None
    cutout_background: Literal["white", "black"] = "white"
    created_at: float = 0.0
|
||
|
||
|
||
# One keyframe of a job's source video plus everything derived from it.
class KeyFrame(BaseModel):
    index: int
    timestamp: float
    url: str
    # Vision-model recognition result: {scene, objects, style, suggested_prompt}
    description: dict | None = None
    # Cleaned version awaiting application -> /jobs/{id}/frames/{idx}/cleaned.jpg
    cleaned_url: str | None = None
    # Whether the cleaned version has replaced the original (cleaned_url is null afterwards).
    cleaned_applied: bool = False
    # Extracted elements (persisted).
    elements: list[KeyElement] = []
    # Storyboard arrangement fields.
    storyboard: StoryboardScene | None = None
    generated_images: list[GeneratedImage] = []
|
||
|
||
|
||
# One timed transcript segment: English source text plus its Chinese translation.
class TranscriptSegment(BaseModel):
    index: int
    start: float
    end: float
    en: str
    # Empty until the translation step has filled it in.
    zh: str = ""
|
||
|
||
|
||
# Full persisted state of one job; serialized to state.json by save_state().
class Job(BaseModel):
    id: str
    url: str

    # Pipeline progress reporting.
    status: JobStatus = "created"
    progress: int = 0
    message: str = ""

    # Source video metadata, filled in once the video is available.
    video_url: str = ""
    duration: float = 0.0
    width: int = 0
    height: int = 0

    # Derived artifacts.
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
    storyboard_images: list[StoryboardImage] = Field(default_factory=list)
    generated_videos: list[GeneratedVideo] = Field(default_factory=list)

    error: str = ""
|
||
|
||
|
||
# In-memory registry of jobs keyed by job id; lifespan() repopulates it from the
# state.json files under JOBS_DIR at startup.
JOBS: dict[str, Job] = {}
|
||
|
||
|
||
def job_dir(job_id: str) -> Path:
    """Return the directory for ``job_id`` under JOBS_DIR, creating it if missing."""
    target = JOBS_DIR / job_id
    target.mkdir(parents=True, exist_ok=True)
    return target
|
||
|
||
|
||
def save_state(job: Job) -> None:
    """Write ``job`` as indented JSON to ``state.json`` in its job directory."""
    state_file = job_dir(job.id) / "state.json"
    serialized = job.model_dump_json(indent=2)
    state_file.write_text(serialized)
|
||
|
||
|
||
def update(job: Job, **kw) -> None:
    """Assign each keyword argument onto ``job`` as an attribute, then persist it."""
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    save_state(job)
|
||
|
||
|
||
def public_api_base() -> str:
    """Return the LLM API base URL without trailing slashes.

    Falls back to the public OpenAI endpoint when LLM_BASE_URL is blank.
    """
    base = LLM_BASE_URL if LLM_BASE_URL else "https://api.openai.com/v1"
    return base.rstrip("/")
|
||
|
||
|
||
def video_uses_poe() -> bool:
    """Report whether video generation goes through the Poe API.

    With an explicit VIDEO_API_BASE_URL this is true only when it equals the
    Poe base URL (trailing slashes ignored); otherwise it is true whenever a
    POE_API_KEY is configured.
    """
    if not VIDEO_API_BASE_URL:
        return bool(POE_API_KEY)
    return VIDEO_API_BASE_URL.rstrip("/") == POE_API_BASE_URL.rstrip("/")
|
||
|
||
|
||
def video_uses_ark() -> bool:
    """Report whether the effective video API base URL contains the Ark host name."""
    ark_host = "ark.cn-beijing.volces.com"
    return ark_host in video_api_base()
|
||
|
||
|
||
def video_api_base() -> str:
    """Return the effective video API base URL without trailing slashes.

    Precedence: VIDEO_API_BASE_URL, then the Poe base URL when POE_API_KEY is
    set, then LLM_BASE_URL, then the public OpenAI endpoint.
    """
    if VIDEO_API_BASE_URL:
        chosen = VIDEO_API_BASE_URL
    elif POE_API_KEY:
        chosen = POE_API_BASE_URL
    else:
        chosen = LLM_BASE_URL or "https://api.openai.com/v1"
    return chosen.rstrip("/")
|
||
|
||
|
||
def video_api_key() -> str:
    """Return the API key to use for video generation (may be empty).

    Precedence: VIDEO_API_KEY, then POE_API_KEY when video_uses_poe() is true,
    otherwise LLM_API_KEY.
    """
    return VIDEO_API_KEY or (POE_API_KEY if video_uses_poe() else LLM_API_KEY)
|
||
|
||
|
||
def video_path(template: str, **values: str) -> str:
    """Fill ``template`` with ``values`` via str.format; the result starts with "/"."""
    rendered = template.format(**values)
    if rendered.startswith("/"):
        return rendered
    return "/" + rendered
|
||
|
||
|
||
def ensure_video_api_configured() -> None:
    """Raise HTTP 503 when no API key is available for video generation."""
    if video_api_key():
        return
    raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API")
|
||
|
||
|
||
def storyboard_ref_path(job_id: str, ref: dict | None) -> Path | None:
    """Resolve a storyboard image reference to an existing file on disk.

    Args:
        job_id: Job whose directory is searched.
        ref: Reference dict with "kind" and "frame_idx", plus "element_id" and
            optionally "cutout_id" for cutouts.

    Returns:
        The first existing candidate file, or None when ``ref`` is empty or
        malformed, the kind is unknown, or no candidate file exists.
    """
    if not ref:
        return None
    try:
        kind = ref.get("kind")
        frame_idx = int(ref.get("frame_idx"))
    except Exception:
        return None

    if kind == "keyframe":
        frame_file = job_dir(job_id) / "frames" / f"{frame_idx:03d}.jpg"
        return frame_file if frame_file.exists() else None

    if kind != "cutout":
        return None

    element_id = (ref.get("element_id") or "").strip()
    cutout_id = (ref.get("cutout_id") or "").strip()
    if not element_id:
        return None

    elements_dir = job_dir(job_id) / "elements"
    stem = f"{frame_idx:03d}_{element_id}"
    # Versioned cutout first, then the legacy single-image names.
    names = []
    if cutout_id and cutout_id != element_id:
        names.append(f"{stem}_{cutout_id}.jpg")
    names.extend((f"{stem}.jpg", f"{stem}.png"))
    return next(
        (elements_dir / name for name in names if (elements_dir / name).exists()),
        None,
    )
|
||
|
||
|
||
def storyboard_ref_url(job_id: str, ref: dict | None) -> str:
|
||
if not ref:
|
||
return ""
|
||
kind = ref.get("kind")
|
||
frame_idx = ref.get("frame_idx")
|
||
if kind == "keyframe" and frame_idx is not None:
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}.jpg"
|
||
if kind == "cutout" and frame_idx is not None and ref.get("element_id"):
|
||
element_id = ref.get("element_id")
|
||
cutout_id = ref.get("cutout_id")
|
||
if cutout_id and cutout_id != element_id:
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutouts/{cutout_id}.jpg"
|
||
return f"/jobs/{job_id}/frames/{int(frame_idx)}/elements/{element_id}/cutout.jpg"
|
||
return ""
|
||
|
||
|
||
def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None:
    """Write ``src`` to ``dst`` as a JPEG letterboxed onto a ``size`` canvas.

    The image is converted to RGB, shrunk in place to fit inside ``size``
    (``Image.thumbnail``), centred on a dark canvas and saved at quality 94.
    The parent directory of ``dst`` is created when missing.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)

    fitted = Image.open(src).convert("RGB")
    fitted.thumbnail(size, Image.Resampling.LANCZOS)

    # Centre the fitted image; the remaining border keeps the dark fill colour.
    offset = ((size[0] - fitted.width) // 2, (size[1] - fitted.height) // 2)
    background = Image.new("RGB", size, (8, 8, 10))
    background.paste(fitted, offset)
    background.save(dst, "JPEG", quality=94)
|
||
|
||
|
||
def update_generated_video(job_id: str, video_id: str, **kw) -> None:
    """Apply ``kw`` to the generated video ``video_id`` of job ``job_id`` and persist.

    The matching entry is rebuilt as a new GeneratedVideo from its dumped
    fields overlaid with ``kw``; other entries are kept as they are. Does
    nothing when the job is unknown.
    """
    job = JOBS.get(job_id)
    if not job:
        return
    refreshed = [
        GeneratedVideo(**{**video.model_dump(), **kw}) if video.id == video_id else video
        for video in job.generated_videos
    ]
    update(job, generated_videos=refreshed)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Restore jobs from disk into JOBS at startup (simplified: directory listing only).

    Every sub-directory of JOBS_DIR that contains a ``state.json`` is loaded.
    A state file that cannot be read or validated is skipped so one corrupt
    job cannot block startup, but the skip is reported instead of being
    silently swallowed.
    """
    for entry in JOBS_DIR.iterdir():
        state_file = entry / "state.json"
        if not (entry.is_dir() and state_file.exists()):
            continue
        try:
            JOBS[entry.name] = Job.model_validate_json(state_file.read_text())
        except Exception as exc:
            # Best effort: keep starting up, but leave a trace of what was dropped.
            print(
                f"[lifespan] skipped unreadable job state {state_file}: "
                f"{type(exc).__name__}: {exc}",
                flush=True,
            )
    yield
|
||
|
||
|
||
# Application instance; lifespan() restores persisted jobs at startup.
app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan)

# CORS: only the configured origins are allowed, with credentials and any
# method or header.
_cors_options = {
    "allow_origins": CORS_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
||
|
||
|
||
# ---------- Pipeline 实现 ----------
|
||
|
||
def run(cmd: list[str], cwd: Path | None = None) -> str:
|
||
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
|
||
if res.returncode != 0:
|
||
# ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾)
|
||
tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:]
|
||
raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}")
|
||
return res.stdout
|
||
|
||
|
||
# ---- 启发式选帧工具 ----
|
||
import imagehash
|
||
import numpy as np
|
||
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
|
||
|
||
|
||
def _sharpness(img_path: Path) -> float:
    """Return the Laplacian variance of the image.

    Larger values mean a sharper image; blurry or transition frames score low.
    The image is converted to grayscale and resized to 320x180 first.
    """
    gray = Image.open(img_path).convert("L").resize((320, 180))
    g = np.asarray(gray, dtype=np.float32)

    # 4-neighbour discrete Laplacian over the interior pixels.
    center = g[1:-1, 1:-1]
    up, down = g[:-2, 1:-1], g[2:, 1:-1]
    left, right = g[1:-1, :-2], g[1:-1, 2:]
    laplacian = -4 * center + up + down + left + right
    return float(laplacian.var())
|
||
|
||
|
||
def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) -> list[Path]:
    """Pick up to ``n`` keyframes from time-ordered candidate frames.

    Heuristic: perceptual-hash de-duplication, then time bucketing with the
    sharpest frame per bucket, then filling empty buckets with the sharpest
    remaining frames.

    Args:
        candidates: Candidate frame paths, sorted by time.
        n: Target number of frames.
        dup_threshold: Two frames whose pHash Hamming distance is below this
            value are treated as duplicates (default 8).

    Returns:
        The selected paths in time order. ``candidates`` itself is returned
        when it already has at most ``n`` entries; an empty list when ``n`` is
        not positive.
    """
    if n <= 0:
        # No frames requested; without this guard the bucketing below would
        # index an empty bucket list.
        return []
    if len(candidates) <= n:
        return candidates

    # Compute pHash and sharpness; unreadable candidates are skipped.
    items: list[dict] = []
    for i, p in enumerate(candidates):
        try:
            with Image.open(p) as img:
                h = imagehash.phash(img)
            s = _sharpness(p)
        except Exception:
            continue
        items.append({"path": p, "idx": i, "hash": h, "sharp": s})

    # De-duplicate: among similar frames keep the sharper one.
    deduped: list[dict] = []
    for it in items:
        dup_pos = next(
            (
                pos
                for pos, kept in enumerate(deduped)
                if (it["hash"] - kept["hash"]) < dup_threshold
            ),
            None,
        )
        if dup_pos is None:
            deduped.append(it)
        elif it["sharp"] > deduped[dup_pos]["sharp"]:
            deduped[dup_pos] = it

    # Time bucketing: split the candidate timeline into n equal segments and
    # take the sharpest de-duplicated frame from each.
    total = len(candidates)
    buckets: list[list[dict]] = [[] for _ in range(n)]
    for it in deduped:
        buckets[min(int(it["idx"] * n / total), n - 1)].append(it)
    selected = [max(bucket, key=lambda x: x["sharp"]) for bucket in buckets if bucket]

    # Fill up for empty buckets from the unselected frames, sharpest first.
    chosen_paths = {it["path"] for it in selected}
    remaining = sorted(
        (it for it in deduped if it["path"] not in chosen_paths),
        key=lambda x: -x["sharp"],
    )
    selected.extend(remaining[: n - len(selected)])

    # Output in time order.
    selected.sort(key=lambda x: x["idx"])
    return [it["path"] for it in selected]
|
||
|
||
|
||
def ffprobe_meta(mp4: Path) -> dict:
    """Return ffprobe's JSON report (streams and format sections) for ``mp4``."""
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-print_format", "json",
        "-show_streams",
        "-show_format",
        str(mp4),
    ]
    return json.loads(run(probe_cmd))
|
||
|
||
|
||
async def pipeline_download(job_id: str) -> None:
    """Stage 1: obtain ``source.mp4`` and stop at ``downloaded``.

    When ``source.mp4`` already exists (local upload) the download is skipped;
    otherwise ``job.url`` is fetched with yt-dlp. The file is then probed with
    ffprobe and the job receives its video URL, duration and dimensions. The
    pipeline waits there until the user triggers analysis.

    Any failure is recorded on the job (status ``failed``); nothing is raised.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if mp4.exists():
            update(job, status="downloading", message="本地上传 · 跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            # run() blocks on a subprocess; use a worker thread so the event
            # loop keeps serving requests during the download.
            await asyncio.to_thread(run, [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                # "--" ends option parsing, so a user-supplied URL that starts
                # with "-" cannot be read as a yt-dlp option.
                "--", job.url,
            ])
            if not mp4.exists():
                raise RuntimeError("下载完成但找不到 source.mp4")

        meta = await asyncio.to_thread(ffprobe_meta, mp4)
        v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None)
        duration = float(meta["format"]["duration"])
        update(
            job,
            status="downloaded",
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            width=int(v_stream["width"]) if v_stream else 0,
            height=int(v_stream["height"]) if v_stream else 0,
            progress=25,
            message=f"视频就绪 · {duration:.1f}s · 等待解析",
        )
    except Exception as e:
        update(job, status="failed", error=str(e), message="下载失败")
|
||
|
||
|
||
async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None:
    """Stage 2: split the audio track and extract keyframes.

    ASR and translation form a separate transcript track and do not block the
    visual asset flow.

    Steps: extract ``audio.wav`` (mono, 16 kHz, PCM), sample evenly spaced
    candidate frames, select ``frame_count`` of them (clamped to 1..20) with
    ``_select_keyframes``, copy them to ``frames/<idx>.jpg`` and store the
    resulting KeyFrame list on the job.

    Any failure is recorded on the job (status ``failed``); nothing is raised.
    All subprocess and image work runs in worker threads so the event loop is
    not blocked.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if not mp4.exists():
            raise RuntimeError("source.mp4 不存在,先完成下载")

        update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35)
        wav = d / "audio.wav"
        await asyncio.to_thread(run, [
            "ffmpeg", "-y", "-i", str(mp4),
            "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
            str(wav),
        ])

        n = max(1, min(int(frame_count), 20))
        # Candidate count: 6x the target, at least 24, capped at 60.
        candidate_count = max(24, min(60, n * 6))

        update(job, message=f"抽取候选 {candidate_count} 张…", progress=45)
        # Both working directories start empty on every run.
        frames_dir = d / "frames"
        if frames_dir.exists():
            shutil.rmtree(frames_dir)
        frames_dir.mkdir(parents=True)
        cand_dir = d / "candidates"
        if cand_dir.exists():
            shutil.rmtree(cand_dir)
        cand_dir.mkdir(parents=True)

        # 1) Evenly sample a large batch of candidates (fast seek before -i).
        duration = max(float(job.duration or 1.0), 0.1)
        step = duration / (candidate_count + 1)
        candidate_meta: list[tuple[Path, float]] = []  # (path, timestamp)
        for i in range(candidate_count):
            t = step * (i + 1)
            out = cand_dir / f"c_{i:03d}.jpg"
            await asyncio.to_thread(run, [
                "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
                "-frames:v", "1",
                "-pix_fmt", "yuvj420p", "-q:v", "3",
                str(out),
            ])
            if out.exists():
                candidate_meta.append((out, t))

        # 2) Select n frames: pHash de-duplication + Laplacian sharpness +
        #    time bucketing.
        update(job, message=f"启发式筛选 {n} / {len(candidate_meta)} 张…", progress=60)
        cand_paths = [m[0] for m in candidate_meta]
        ts_by_path = {m[0]: m[1] for m in candidate_meta}
        chosen = await asyncio.to_thread(_select_keyframes, cand_paths, n)

        # 3) Persist as frames/<idx>.jpg in timestamp order.
        renamed: list[KeyFrame] = []
        chosen_sorted = sorted(chosen, key=lambda p: ts_by_path[p])
        for i, src in enumerate(chosen_sorted):
            dst = frames_dir / f"{i:03d}.jpg"
            shutil.copyfile(src, dst)
            renamed.append(KeyFrame(
                index=i,
                timestamp=round(ts_by_path[src], 2),
                url=f"/jobs/{job_id}/frames/{i}.jpg",
            ))

        # 4) Remove the candidate directory.
        shutil.rmtree(cand_dir, ignore_errors=True)

        update(
            job,
            status="frames_extracted",
            frames=renamed,
            progress=70,
            message=f"已抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排",
        )

    except Exception as e:
        update(job, status="failed", error=str(e), message="解析失败")
|
||
|
||
|
||
# ---------- ASR (ASR_MODEL) + translation (TRANSLATE_MODEL) ----------
|
||
|
||
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe ``wav`` with ASR_MODEL and return its segments.

    Requests ``verbose_json`` with segment-level timestamps. When the response
    has no segments but does have a full text, that text is returned as one
    segment. The original docstring describes each segment as
    ``{start, end, text}``.
    """
    with wav.open("rb") as audio:
        resp = llm().audio.transcriptions.create(
            file=(wav.name, audio, "audio/wav"),
            model=ASR_MODEL,
            response_format="verbose_json",
            timestamp_granularities=["segment"],
        )
    raw = resp.model_dump() if hasattr(resp, "model_dump") else resp

    segments = raw.get("segments") or []
    if segments or not raw.get("text"):
        return segments

    # Fallback: the gateway returned no segments, so treat the full text as one.
    whole = {
        "start": 0.0,
        "end": float(raw.get("duration", 0) or 0),
        "text": raw["text"],
    }
    return [whole]
|
||
|
||
|
||
def _translate_sync(segments: list[dict]) -> list[str]:
    """Batch-translate English segments to Simplified Chinese with TRANSLATE_MODEL.

    Args:
        segments: ASR segments; the "text" of each one is translated.

    Returns:
        One translation per input segment, in input order. A segment the model
        did not return, or returned in an unusable shape, gets "". An empty
        input returns an empty list without calling the model.
    """
    if not segments:
        return []

    payload = [{"i": i, "en": str(s.get("text") or "").strip()} for i, s in enumerate(segments)]
    prompt = (
        "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
        "严格返回 JSON 数组,不要任何 markdown 或多余文字,schema: "
        '[{"i": 0, "zh": "..."}, ...]\n\n输入:\n'
        + json.dumps(payload, ensure_ascii=False)
    )
    resp = llm().chat.completions.create(
        model=TRANSLATE_MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        temperature=0.2,
    )
    content = resp.choices[0].message.content or "[]"

    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        data = []

    if isinstance(data, dict):
        # json_object mode wraps the array in an object. Prefer the known
        # wrapper keys, then fall back to the first list-valued entry.
        known = [data[k] for k in ("data", "items", "result", "translations") if isinstance(data.get(k), list)]
        any_list = [v for v in data.values() if isinstance(v, list)]
        wrapped = known or any_list
        data = wrapped[0] if wrapped else []
    if not isinstance(data, list):
        data = []

    zh_by_idx: dict[int, str] = {}
    for it in data:
        if not isinstance(it, dict) or "i" not in it:
            continue
        try:
            idx = int(it["i"])
        except (TypeError, ValueError):
            # One malformed entry must not discard the whole translation.
            continue
        zh = it.get("zh")
        zh_by_idx[idx] = "" if zh is None else str(zh)
    return [zh_by_idx.get(i, "") for i in range(len(segments))]
|
||
|
||
|
||
async def pipeline_transcribe(job_id: str) -> None:
    """Transcribe ``audio.wav`` and translate the segments, updating the job.

    Without an LLM_API_KEY a fixed two-segment mock transcript is stored.
    Otherwise the English segments are stored as soon as ASR finishes (so the
    UI can show them early) and the Chinese text is filled in after
    translation. Any failure is recorded on the job (status ``failed``);
    nothing is raised.
    """
    job = JOBS[job_id]
    wav = job_dir(job_id) / "audio.wav"
    try:
        if not wav.exists():
            raise RuntimeError("audio.wav 不存在")

        if not LLM_API_KEY:
            # No-key mode: mock data.
            update(job, status="transcribing", message="ASR (mock) …", progress=75)
            await asyncio.sleep(1.0)
            mock_rows = (
                (0.0, 3.5,
                 "Welcome back, today we're testing something new.",
                 "欢迎回来,今天我们要测试一些新东西。"),
                (3.5, 7.2,
                 "This device looks really sleek and minimal.",
                 "这个设备看起来非常时尚和简约。"),
            )
            mock = [
                TranscriptSegment(index=pos, start=begin, end=finish, en=en_text, zh=zh_text)
                for pos, (begin, finish, en_text, zh_text) in enumerate(mock_rows)
            ]
            update(job, transcript=mock, status="transcribed", progress=100,
                   message="转录完成(MOCK · 未设 LLM_API_KEY)")
            return

        # 1) ASR in a worker thread.
        update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78)
        segments = await asyncio.to_thread(_transcribe_sync, wav)
        if not segments:
            raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)")

        # Publish the English segments first; zh is filled in after translation.
        en_only = []
        for pos, raw_seg in enumerate(segments):
            en_only.append(
                TranscriptSegment(
                    index=pos,
                    start=float(raw_seg.get("start", 0)),
                    end=float(raw_seg.get("end", 0)),
                    en=str(raw_seg.get("text", "")).strip(),
                    zh="",
                )
            )
        update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88)

        # 2) Translation in a worker thread.
        zh_list = await asyncio.to_thread(_translate_sync, segments)
        translated_count = len(zh_list)
        full = []
        for pos, seg in enumerate(en_only):
            zh_text = zh_list[pos] if pos < translated_count else ""
            full.append(
                TranscriptSegment(index=seg.index, start=seg.start, end=seg.end, en=seg.en, zh=zh_text)
            )
        update(job, transcript=full, status="transcribed", progress=100,
               message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})")

    except Exception as e:
        update(job, status="failed", error=str(e), message="转录失败")
|
||
|
||
|
||
def _image_edit_call(
    image_path: Path,
    prompt: str,
    model: str | None = None,
    models: list[str] | None = None,
    fallback_text: bool = False,
    max_attempts: int = 3,
    max_side: int = 1024,
) -> tuple[bytes, str]:
    """Run an image edit with retries and an optional text-only fallback.

    The input image is shrunk to at most ``max_side`` per side and re-encoded
    as JPEG before base64-encoding. The original author's note gives the
    reason: large inputs pushed the Gemini function-call input over a
    threshold and caused ``incomplete_generation``. If that preprocessing
    fails, the raw file bytes are sent instead.

    Args:
        image_path: Reference image.
        prompt: Generation prompt.
        model: Single model to use when ``models`` is not given; defaults to
            IMAGE_MODEL.
        models: Rotation list; attempt N uses the N-th model, and the last
            model is reused once the list is exhausted.
        fallback_text: When true, one final text-to-image attempt (without the
            reference image) is appended after the edit attempts.
        max_attempts: Number of edit attempts.
        max_side: Maximum side length of the uploaded reference image.

    Returns:
        ``(image_bytes, effective_mode)`` where ``effective_mode`` is "edit"
        or "text".

    Raises:
        RuntimeError: if LLM_API_KEY is missing, the gateway answers 401/403,
            every attempt fails, or the response has no ``b64_json``.
    """
    import base64 as b64lib
    import io as _io
    import time as _time
    import httpx
    from PIL import Image as _PILImage

    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")

    # Model priority: models list > single model argument > IMAGE_MODEL.
    models_cycle = list(models) if models else [model or IMAGE_MODEL]

    # Shrink to fit within max_side.
    try:
        with _PILImage.open(image_path) as im:
            if max(im.size) > max_side:
                im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
            buf = _io.BytesIO()
            im.convert("RGB").save(buf, format="JPEG", quality=88)
        img_bytes_in = buf.getvalue()
    except Exception:
        # Preprocessing failed: fall back to the original file bytes.
        img_bytes_in = image_path.read_bytes()
    img_b64 = b64lib.b64encode(img_bytes_in).decode("ascii")
    data_uri = f"data:image/jpeg;base64,{img_b64}"

    # public_api_base() applies the same default endpoint the SDK client uses
    # and strips trailing slashes, so the URL is valid even when LLM_BASE_URL
    # is blank or ends with "/".
    endpoint = f"{public_api_base()}/images/generations"

    plan: list[str] = ["edit"] * max_attempts
    if fallback_text:
        plan.append("text")

    last_err = ""
    resp_data: dict = {}
    effective_mode = "edit"
    for attempt, current_mode in enumerate(plan):
        # Model rotation: attempt N uses model N (the last one once exhausted).
        current_model = models_cycle[min(attempt, len(models_cycle) - 1)]
        try:
            if current_mode == "edit":
                with httpx.Client(timeout=120) as client:
                    r = client.post(
                        endpoint,
                        headers={
                            "Authorization": f"Bearer {LLM_API_KEY}",
                            "Content-Type": "application/json",
                        },
                        json={"model": current_model, "prompt": prompt, "image": data_uri, "n": 1},
                    )
                    r.raise_for_status()
                    resp_data = r.json()
            else:
                resp = llm().images.generate(model=current_model, prompt=prompt, n=1)
                resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}
            if resp_data.get("data"):
                effective_mode = current_mode
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]} · model={current_model}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            # With model rotation everything is retried on the next model
            # except authentication failures, which cannot recover.
            sc = e.response.status_code
            last_err = f"HTTP {sc}: {body[:200]} · model={current_model}"
            if sc in (401, 403):
                raise RuntimeError(f"image edit HTTP {sc}: {body[:300]}")
        except Exception as e:
            last_err = f"{type(e).__name__}: {e} · model={current_model}"

        if attempt < len(plan) - 1:
            next_model = models_cycle[min(attempt + 1, len(models_cycle) - 1)]
            tag = f"retry {attempt + 1}/{len(plan)} → {next_model}"
            print(f"[image edit {tag}] {last_err}", flush=True)
            _time.sleep(1.0)

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise RuntimeError(f"image edit failed after {len(plan)} attempts: {last_err}")
    b64 = data_arr[0].get("b64_json")
    if not b64:
        raise RuntimeError("image edit returned no b64_json")
    return b64lib.b64decode(b64), effective_mode
|
||
|
||
|
||
# ---------- API 路由 ----------
|
||
|
||
# Request body for POST /jobs.
class CreateJobReq(BaseModel):
    # Source video URL; create_job rejects it when blank after stripping.
    url: str
|
||
|
||
|
||
# Request body for POST /translate.
class TranslateReq(BaseModel):
    text: str
    # Target language: "en" (English) or "zh" (Simplified Chinese).
    target: Literal["en", "zh"] = "en"
|
||
|
||
|
||
@app.post("/translate")
def translate_text(req: TranslateReq) -> dict:
    """Translate one short text (used for custom element labels, zh -> en).

    Returns ``{"text": <translation>}``; blank input returns ``{"text": ""}``
    without calling the model. Surrounding quotes are stripped from the
    result.

    Raises:
        HTTPException: 503 when LLM_API_KEY is not configured, 500 when the
            model call fails.
    """
    import re as _re

    text = req.text.strip()
    if not text:
        return {"text": ""}
    if not LLM_API_KEY:
        raise HTTPException(503, "LLM_API_KEY 未配置")

    target_label = "English" if req.target == "en" else "Simplified Chinese"
    prompt = (
        f"Translate the following text into concise {target_label}, suitable as an element "
        "label in an image-generation prompt. Output only the translation itself — no quotes, "
        "no punctuation, no explanation, no markdown.\n\n"
        f"Input: {text}"
    )
    try:
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=200,
        )
        out = (resp.choices[0].message.content or "").strip()
        if not out:
            # The answer may arrive in reasoning_content with an empty content;
            # use its last non-blank line.
            rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
            rc_lines = rc.strip().splitlines()
            if rc_lines:  # whitespace-only reasoning yields no lines
                out = rc_lines[-1].strip()
        out = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip()
        return {"text": out}
    except Exception as e:
        raise HTTPException(500, f"translate failed: {e}") from e
|
||
|
||
|
||
@app.get("/health")
def health() -> dict:
    """Report service health and the effective model / video API configuration."""
    if video_uses_poe():
        provider = "poe"
    elif video_uses_ark():
        provider = "ark"
    else:
        provider = "custom"

    model_info = {
        "asr": ASR_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
        "video": VIDEO_MODEL,
        "video_aliases": VIDEO_MODEL_ALIASES,
        "video_provider": provider,
        "video_base_url": video_api_base(),
        "video_configured": bool(video_api_key()),
        "video_create_paths": VIDEO_CREATE_PATHS,
    }
    return {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "base_url": LLM_BASE_URL or "openai-default",
        "models": model_info,
    }
|
||
|
||
|
||
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Create a job for a video URL and schedule its download in the background.

    Raises:
        HTTPException: 400 when the URL is blank.
    """
    source_url = req.url.strip()
    if not source_url:
        raise HTTPException(400, "url required")

    job = Job(id=uuid.uuid4().hex[:12], url=source_url)
    JOBS[job.id] = job
    save_state(job)
    bg.add_task(pipeline_download, job.id)
    return job
|
||
|
||
|
||
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Create a job from an uploaded video file.

    The upload is streamed to ``source.mp4`` in a new job directory, after
    which ``pipeline_download`` is scheduled; it finds the existing file and
    skips the download.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            500 when the stored file is empty.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")

    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    try:
        with mp4.open("wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
        if not mp4.exists() or mp4.stat().st_size == 0:
            raise HTTPException(500, "upload failed")
    except BaseException:
        # No job was registered for this directory, so remove it instead of
        # leaving an orphan behind. BaseException also covers a cancelled
        # request (client disconnect).
        shutil.rmtree(d, ignore_errors=True)
        raise

    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT) -> Job:
    """Schedule audio splitting and keyframe extraction for a job.

    Allowed from ``downloaded``, from the terminal states of a previous run
    (``frames_extracted``, ``transcribed``) and from ``failed``.

    Raises:
        HTTPException: 404 when the job is unknown, 409 when its status does
            not allow analysis.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    allowed = ("downloaded", "frames_extracted", "transcribed", "failed")
    if job.status not in allowed:
        # The message lists exactly the statuses the check accepts.
        raise HTTPException(409, f"status must be one of {'/'.join(allowed)}, got {job.status}")
    bg.add_task(pipeline_analyze, job_id, frames)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames", response_model=Job)
def add_manual_frame(job_id: str, t: float) -> Job:
    """Extract one frame at timestamp ``t`` (seconds) and append it to ``job.frames``.

    Raises:
        HTTPException: 404 when the job is unknown; 400 when the video is not
            ready, ``source.mp4`` is missing, ``t`` is negative or not finite,
            or no frame could be extracted at ``t``; 500 when ffmpeg fails.
    """
    import math

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not job.video_url:
        raise HTTPException(400, "video not ready")
    if not math.isfinite(t) or t < 0:
        raise HTTPException(400, "t must be a non-negative, finite number of seconds")
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    if not mp4.exists():
        raise HTTPException(400, "source.mp4 missing")
    frames_dir = d / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)

    # New index is max(existing) + 1: the list is sorted by timestamp, but file
    # names use the index so they stay stable.
    next_idx = max((f.index for f in job.frames), default=-1) + 1
    out = frames_dir / f"{next_idx:03d}.jpg"
    try:
        run([
            "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
            "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3",
            str(out),
        ])
    except (RuntimeError, OSError) as e:
        # OSError covers an ffmpeg binary that cannot be started.
        raise HTTPException(500, f"ffmpeg failed: {e}") from e
    if not out.exists():
        # ffmpeg wrote no image (pipeline_analyze guards the same case with
        # out.exists()); do not register a frame without a file.
        raise HTTPException(400, f"no frame could be extracted at t={t}")

    new_frame = KeyFrame(
        index=next_idx,
        timestamp=round(float(t), 2),
        url=f"/jobs/{job_id}/frames/{next_idx}.jpg",
    )
    merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp)
    update(job, frames=merged, message=f"已手动加帧({t:.1f}s),共 {len(merged)} 张")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    """Return the in-memory state of a job, or HTTP 404 when it is unknown."""
    found = JOBS.get(job_id)
    if found:
        return found
    raise HTTPException(404, "job not found")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    """Schedule transcription for a job whose status is ``frames_extracted``.

    Raises:
        HTTPException: 404 when the job is unknown, 409 for any other status.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if job.status == "frames_extracted":
        bg.add_task(pipeline_transcribe, job_id)
        return job
    raise HTTPException(409, f"status must be frames_extracted, got {job.status}")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the source video of a job.

    Raises:
        HTTPException: 404 when the job is unknown or has no ``source.mp4``.
    """
    # Check the registry first: job_dir() creates the directory it returns, so
    # calling it with an arbitrary id would litter JOBS_DIR with empty folders
    # (and resolve ids such as ".." outside of it).
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "source.mp4"
    if not p.exists():
        raise HTTPException(404, "video not found")
    return FileResponse(p, media_type="video/mp4")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve keyframe ``idx`` of a job as JPEG.

    Raises:
        HTTPException: 404 when the job is unknown or the frame file is missing.
    """
    # Check the registry first: job_dir() creates the directory it returns, so
    # calling it with an arbitrary id would litter JOBS_DIR with empty folders
    # (and resolve ids such as ".." outside of it).
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame not found")
    return FileResponse(p, media_type="image/jpeg")
|
||
|
||
|
||
class GenerateReq(BaseModel):
    """Request body for generating an image from a key frame."""

    # Base prompt text.
    prompt: str
    # Elements that should appear (positive); appended as "Include: ...".
    extra_prompt: str = ""
    # Elements that should not appear (negative); appended as "Avoid: ...".
    negative_prompt: str = ""
    # Model id; leave empty to use the IMAGE_MODEL default.
    model: str = ""
    # "edit" sends a reference image, "text" is prompt-only.
    mode: str = "edit"
    # When True, prefer the frame's selected generated image as the reference
    # (iterative editing); otherwise the original key frame is used.
    from_selected: bool = False
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job)
def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
    """Generate a new image for a key frame (image-to-image or text-to-image).

    In ``"edit"`` mode the reference image (the key frame, or the frame's
    selected generated image when ``req.from_selected`` is set and its file
    exists) is sent along with the prompt. Edit mode is tried up to three
    times and then falls back to one text-only attempt. The resulting image
    is written to ``gen/<idx>_<gen_id>.jpg`` under the job directory and
    appended to the frame's ``generated_images``.

    Raises:
        HTTPException: 404 for an unknown job / frame / frame file, 400 when
            no prompt text is supplied, 500 when generation fails.
    """
    import httpx

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Pick the i2i reference: the selected generated image when requested and
    # present on disk (iteration), otherwise the original key frame.
    reference_path = frame_path
    if req.from_selected:
        sel = next((g for g in frame.generated_images if g.selected), None)
        if sel:
            sel_path = job_dir(job_id) / "gen" / f"{idx:03d}_{sel.id}.jpg"
            if sel_path.exists():
                reference_path = sel_path

    # Join only the non-empty parts so an empty base prompt never yields a
    # prompt that starts with a dangling ". Include: ..." fragment.
    prompt_parts = [req.prompt.strip()]
    extra = req.extra_prompt.strip()
    if extra:
        prompt_parts.append(f"Include: {extra}")
    negative = req.negative_prompt.strip()
    if negative:
        prompt_parts.append(f"Avoid: {negative}")
    full_prompt = ". ".join(part for part in prompt_parts if part)
    if not full_prompt:
        raise HTTPException(400, "prompt required")

    model = req.model or IMAGE_MODEL
    gen_id = uuid.uuid4().hex[:12]

    img_b64: str | None = None
    if req.mode == "edit":
        img_b64 = base64.b64encode(reference_path.read_bytes()).decode("ascii")

    # Try i2i up to 3 times; if all fail, downgrade to one text-only attempt.
    plan: list[str] = ([req.mode] * 3) if req.mode == "edit" else [req.mode]
    if req.mode == "edit":
        plan.append("text")
    resp_data: dict = {}
    last_err = ""
    effective_mode = req.mode
    for attempt, current_mode in enumerate(plan):
        try:
            if current_mode == "edit":
                data_uri = f"data:image/jpeg;base64,{img_b64}"
                # The OpenAI SDK has no direct `image` parameter here, so the
                # request goes through httpx.
                with httpx.Client(timeout=120) as client:
                    r = client.post(
                        f"{LLM_BASE_URL}/images/generations",
                        headers={
                            "Authorization": f"Bearer {LLM_API_KEY}",
                            "Content-Type": "application/json",
                        },
                        json={
                            "model": model,
                            "prompt": full_prompt,
                            "image": data_uri,
                            "n": 1,
                        },
                    )
                    r.raise_for_status()
                    resp_data = r.json()
            else:
                # text-only
                resp = llm().images.generate(model=model, prompt=full_prompt, n=1)
                resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}

            if resp_data.get("data"):
                effective_mode = current_mode
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            transient = (
                e.response.status_code >= 500
                or "incomplete_generation" in body
                or "rate_limit" in body
                or "timeout" in body.lower()
            )
            last_err = f"HTTP {e.response.status_code}: {body[:200]}"
            if not transient:
                # Non-retryable provider error: stop immediately.
                raise HTTPException(500, f"image gen HTTP {e.response.status_code}: {body[:300]}")
        except Exception as e:
            last_err = f"{type(e).__name__}: {e}"

        if attempt < len(plan) - 1:
            next_mode = plan[attempt + 1]
            tag = f"fallback → {next_mode}" if next_mode != current_mode else f"retry {attempt + 1}/{len(plan)}"
            print(f"[image gen {tag}] {last_err}", flush=True)
            # Linear back-off between attempts.
            time.sleep(1.5 * (attempt + 1))

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise HTTPException(500, f"image gen failed after {len(plan)} attempts: {last_err}")

    item = data_arr[0]
    b64 = item.get("b64_json")
    if not b64:
        raise HTTPException(500, "image gen returned no b64_json")

    # Persist to jobs/<id>/gen/<idx>_<gen_id>.jpg
    gen_dir = job_dir(job_id) / "gen"
    gen_dir.mkdir(parents=True, exist_ok=True)
    out_path = gen_dir / f"{idx:03d}_{gen_id}.jpg"
    out_path.write_bytes(base64.b64decode(b64))

    new_gen = GeneratedImage(
        id=gen_id,
        prompt=full_prompt,
        model=model,
        mode=effective_mode,
        url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg",
        selected=False,
        created_at=time.time(),
    )

    # Write the new image back onto the matching frame.
    for f in job.frames:
        if f.index == idx:
            f.generated_images = f.generated_images + [new_gen]
    update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve a generated image stored as ``gen/<idx:03d>_<gen_id>.jpg``."""
    image_file = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_file.exists():
        return FileResponse(image_file, media_type="image/jpeg")
    raise HTTPException(404, "generated image not found")
|
||
|
||
|
||
class SelectGenReq(BaseModel):
    """Request body for toggling the selection of a generated image."""

    # Desired selection state of the addressed generated image.
    selected: bool
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job)
def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job:
    """Set the selection flag of one generated image on a frame.

    Selection is single-choice per frame: the addressed image receives
    ``req.selected`` and every other image on that frame is deselected.

    Raises:
        HTTPException: 404 when the job, the frame, or the generated image is
            unknown. The image check prevents a mistyped ``gen_id`` from
            silently clearing the frame's current selection.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    if not any(g.id == gen_id for g in frame.generated_images):
        raise HTTPException(404, "generated image not found")
    for g in frame.generated_images:
        g.selected = req.selected if g.id == gen_id else False
    update(job, frames=job.frames)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job)
def describe_frame(job_id: str, idx: int) -> Job:
    """Ask the vision model for a structured JSON description of a key frame.

    The call is attempted up to three times; empty answers, invalid JSON and
    call errors all trigger a retry. The parsed JSON is stored on the frame's
    ``description``.

    Raises:
        HTTPException: 404 for an unknown job / frame / frame file, 500 when
            no attempt produced parseable JSON.
    """
    import re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if next((f for f in job.frames if f.index == idx), None) is None:
        raise HTTPException(404, "frame not found")
    frame_file = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_file.exists():
        raise HTTPException(404, "frame file not found")

    encoded = base64.b64encode(frame_file.read_bytes()).decode("ascii")
    image_part = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}

    prompt = (
        "请识别这张图,输出严格 JSON(不要 markdown 不要解释,不要思考):\n"
        '{\n'
        ' "scene": "一句话描述场景",\n'
        ' "objects": [{"name": "物体名(中文)", "position": "在画面哪里", "color": "颜色", "extract_prompt": "用于提取该元素的英文 prompt"}],\n'
        ' "style": "整体风格 / 打光 / 色调(一句话)",\n'
        ' "suggested_prompt": "适合用作下游生图的完整英文 prompt"\n'
        '}\n'
        "要求:objects 列出 3-8 个画面里**可独立提取**的主要元素,extract_prompt 用于后续 image edit 模型。"
    )

    last_err = ""
    data = None
    for attempt_no in range(1, 4):
        try:
            resp = llm().chat.completions.create(
                model=VISION_MODEL,
                messages=[{"role": "user", "content": [{"type": "text", "text": prompt}, image_part]}],
                response_format={"type": "json_object"},
                temperature=0.3,
                max_tokens=3000,
            )
            message = resp.choices[0].message
            content = (message.content or "").strip()
            if not content:
                # Thinking models may leave `content` empty; try to dig the
                # JSON object out of `reasoning_content` instead.
                reasoning = getattr(message, "reasoning_content", "") or ""
                found = re.search(r"\{[\s\S]*\}", reasoning)
                content = found.group(0) if found else ""
            # Strip a ```json ... ``` wrapper if present.
            content = re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()
            if not content:
                last_err = f"empty content (attempt {attempt_no})"
                continue
            data = json.loads(content)
            break
        except json.JSONDecodeError as e:
            last_err = f"json decode (attempt {attempt_no}): {e} · raw[:200]={content[:200]}"
            print(f"[vision retry] {last_err}", flush=True)
        except Exception as e:
            last_err = f"vision call (attempt {attempt_no}): {e}"
            print(f"[vision retry] {last_err}", flush=True)

    if data is None:
        raise HTTPException(500, last_err or "vision failed after 3 retries")

    # Write the description back onto the matching frame(s).
    for f in job.frames:
        if f.index == idx:
            f.description = data
    update(job, frames=list(job.frames), message=f"识别完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
# ---------- 清洗水印 / 元素提取(关键帧二阶段加工) ----------
|
||
|
||
class CleanupReq(BaseModel):
    """Request body for the key-frame cleanup endpoint."""

    # Rectangles in relative 0-1 coordinates that limit where cleanup applies,
    # e.g. [{"x", "y", "w", "h"}, ...]; empty / None means clean the whole image.
    regions: list[dict] | None = None
|
||
|
||
|
||
def _region_to_phrase(r: dict) -> str:
    """Turn a relative-coordinate rectangle into a short position phrase.

    The phrase (e.g. ``"top left"``, ``"bottom"``, ``"center"``) is meant for
    use inside a prompt, avoiding percent signs and brackets.

    Args:
        r: Mapping with ``x``, ``y``, ``w``, ``h`` in relative 0-1 units;
            missing keys default to 0.

    Returns:
        The position phrase, or ``""`` when the rectangle has no area or any
        coordinate is missing / null / not convertible to a number.
    """
    try:
        x = max(0.0, min(1.0, float(r.get("x", 0))))
        y = max(0.0, min(1.0, float(r.get("y", 0))))
        # Width / height are clamped so the rectangle stays inside the image.
        w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
        h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
    except (TypeError, ValueError):
        # Regions come straight from the request body; treat malformed values
        # like an empty region instead of failing the whole request.
        return ""
    if w <= 0 or h <= 0:
        return ""
    cx, cy = x + w / 2, y + h / 2
    # The centre point decides the bucket: <0.4, >0.6, or the middle band.
    hpos = "left" if cx < 0.4 else "right" if cx > 0.6 else "middle"
    vpos = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
    if hpos == "middle" and vpos == "middle":
        return "center"
    if hpos == "middle":
        return vpos
    if vpos == "middle":
        return hpos
    return f"{vpos} {hpos}"
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
    """Produce a cleaned copy of a key frame via the image-edit model.

    The goal is to drop watermarks, @usernames, captions and platform logos.
    The result is written to ``cleaned/<idx:03d>.jpg`` under the job
    directory and exposed through ``frame.cleaned_url``. Optional
    ``req.regions`` restrict the cleanup to the boxed areas.

    Raises:
        HTTPException: 404 for an unknown job / frame / frame file, 500 when
            the image-edit call fails.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if next((f for f in job.frames if f.index == idx), None) is None:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    region_phrases: list[str] = []
    if req and req.regions:
        phrases = (_region_to_phrase(region) for region in req.regions)
        # Drop empty phrases and de-duplicate while keeping first-seen order.
        region_phrases = list(dict.fromkeys(ph for ph in phrases if ph))

    # The prompt asks to "recreate a copy" rather than "erase / remove only X"
    # so Gemini does not take the mask / inpainting function-call path (that
    # path was observed to always trigger incomplete_generation on the SKG
    # gateway).
    if not region_phrases:
        prompt = (
            "Recreate this image as a clean version without watermarks, captions, "
            "hashtags, usernames, or platform logos. Keep the composition and style."
        )
    else:
        if len(region_phrases) == 1:
            zones = f"the {region_phrases[0]} area"
        else:
            zones = ", ".join(region_phrases) + " areas"
        prompt = (
            f"Recreate this image as a clean version: remove the text and graphics in {zones}, "
            "keep the rest of the scene identical."
        )

    # Model rotation: fall back to the flash family when the default fails.
    models = [
        IMAGE_MODEL,  # gemini-3-pro-image-preview (nano-banana-pro)
        "gemini-3.1-flash-image-preview",
        "gemini-2.5-flash-image",
    ]
    try:
        img_bytes, _mode = _image_edit_call(
            frame_path, prompt, models=models, fallback_text=False, max_attempts=3,
        )
    except RuntimeError as e:
        raise HTTPException(500, f"cleanup failed: {e}")

    cleaned_dir = job_dir(job_id) / "cleaned"
    cleaned_dir.mkdir(parents=True, exist_ok=True)
    (cleaned_dir / f"{idx:03d}.jpg").write_bytes(img_bytes)

    for f in job.frames:
        if f.index == idx:
            # The timestamp query string makes the URL change on every cleanup.
            f.cleaned_url = f"/jobs/{job_id}/frames/{idx}/cleaned.jpg?t={int(time.time())}"
            # A fresh cleanup resets the "already applied" state.
            f.cleaned_applied = False
    update(job, frames=list(job.frames), message=f"清洗完成 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve the pending cleaned version of a frame, or 404 when absent."""
    cleaned_file = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if cleaned_file.exists():
        return FileResponse(cleaned_file, media_type="image/jpeg")
    raise HTTPException(404, "cleaned frame not found")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def discard_cleaned(job_id: str, idx: int) -> Job:
    """Discard the pending cleaned version of a frame.

    Deletes ``cleaned/<idx:03d>.jpg`` if present (best effort) and clears
    ``frame.cleaned_url``. An already-applied cleanup is not touched.

    Raises:
        HTTPException: 404 for an unknown job or frame.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if next((f for f in job.frames if f.index == idx), None) is None:
        raise HTTPException(404, "frame not found")

    pending = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if pending.exists():
        try:
            pending.unlink()
        except OSError:
            # Best effort: a leftover file must not fail the request.
            pass

    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
    update(job, frames=list(job.frames), message=f"丢弃清洗版 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/cleanup/apply", response_model=Job)
def apply_cleaned(job_id: str, idx: int) -> Job:
    """Replace the original key frame with its cleaned version.

    ``frames/<idx>.jpg`` is physically overwritten by ``cleaned/<idx>.jpg``.
    The original is backed up to ``orig/<idx>.jpg`` on the first replacement
    only; later replacements keep that first backup. Afterwards the cleaned
    file is removed and ``frame.cleaned_url`` is cleared, since there is no
    longer a pending cleaned version.

    Raises:
        HTTPException: 404 for an unknown job / frame, or when there is no
            cleaned file to apply.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if next((f for f in job.frames if f.index == idx), None) is None:
        raise HTTPException(404, "frame not found")

    file_name = f"{idx:03d}.jpg"
    cleaned_path = job_dir(job_id) / "cleaned" / file_name
    if not cleaned_path.exists():
        raise HTTPException(404, "no cleaned version to apply")
    frame_path = job_dir(job_id) / "frames" / file_name

    # First replacement: back the original up to orig/<idx>.jpg.
    backup_dir = job_dir(job_id) / "orig"
    backup_dir.mkdir(parents=True, exist_ok=True)
    backup_path = backup_dir / file_name
    if frame_path.exists() and not backup_path.exists():
        shutil.copy2(frame_path, backup_path)

    # Overwrite the frame with the cleaned image.
    shutil.copy2(cleaned_path, frame_path)
    # The cleaned file has been applied, so it is no longer a separate
    # candidate; removal is best effort.
    try:
        cleaned_path.unlink()
    except OSError:
        pass

    for f in job.frames:
        if f.index == idx:
            f.cleaned_url = None
            f.cleaned_applied = True
    update(job, frames=list(job.frames), message=f"已替换分镜 {idx + 1} 为清洗版")
    return job
|
||
|
||
|
||
class AddElementReq(BaseModel):
    """Request body for adding an element to a frame."""

    # Chinese label; required and must be non-blank.
    name_zh: str
    # English label; when blank the endpoint tries to translate name_zh.
    name_en: str = ""
    # Free-text position of the element.
    position: str = ""
    # Origin tag of the element.
    source: Literal["auto", "manual", "region"] = "manual"
    # Optional rectangle; the cutout endpoint reads x / y / w / h from it as
    # relative 0-1 values.
    region: dict | None = None
|
||
|
||
|
||
class UpdateElementReq(BaseModel):
    """Partial update for an element; ``None`` leaves a field unchanged."""

    # New Chinese label (must not be blank when provided).
    name_zh: str | None = None
    # New English label.
    name_en: str | None = None
    # New position text.
    position: str | None = None
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements", response_model=Job)
def add_element(job_id: str, idx: int, req: AddElementReq) -> Job:
    """Add an element to a frame, translating zh→en when ``name_en`` is blank.

    The translation is only attempted when ``LLM_API_KEY`` is set, and any
    failure leaves ``name_en`` empty instead of failing the request.

    Raises:
        HTTPException: 404 for an unknown job / frame, 400 when ``name_zh``
            is blank.
    """
    import re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if next((f for f in job.frames if f.index == idx), None) is None:
        raise HTTPException(404, "frame not found")
    name_zh = req.name_zh.strip()
    if not name_zh:
        raise HTTPException(400, "name_zh required")

    name_en = req.name_en.strip()
    if LLM_API_KEY and not name_en:
        try:
            prompt = (
                "Translate the following text into concise English, suitable as an element label "
                "in an image-generation prompt. Output only the translation — no quotes, no punctuation, "
                f"no explanation.\n\nInput: {name_zh}"
            )
            resp = llm().chat.completions.create(
                model=TRANSLATE_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                max_tokens=200,
            )
            message = resp.choices[0].message
            translated = (message.content or "").strip()
            if not translated:
                # Fall back to the last line of reasoning_content when the
                # regular content is empty.
                reasoning = getattr(message, "reasoning_content", "") or ""
                if reasoning:
                    translated = reasoning.strip().splitlines()[-1].strip()
            # Strip surrounding ASCII / CJK quote characters.
            name_en = re.sub(r'^[\'"「『]+|[\'"」』]+$', "", translated).strip()
        except Exception as e:
            print(f"[add_element translate failed] {e}", flush=True)
            name_en = ""

    element = KeyElement(
        id=uuid.uuid4().hex[:8],
        name_zh=name_zh,
        name_en=name_en,
        position=req.position.strip(),
        source=req.source,
        region=req.region,
        created_at=time.time(),
    )
    for f in job.frames:
        if f.index == idx:
            f.elements = f.elements + [element]
    update(job, frames=list(job.frames), message=f"加入元素 · 分镜 {idx + 1} · {name_zh}")
    return job
|
||
|
||
|
||
@app.patch("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def update_element(job_id: str, idx: int, element_id: str, req: UpdateElementReq) -> Job:
    """Update an element's labels / position in place.

    Lets the user correct an inaccurate extraction without recreating the
    element. Only fields that are not ``None`` in the request are changed.

    Raises:
        HTTPException: 404 for an unknown job or element, 400 when a blank
            ``name_zh`` is supplied.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    changed_name = ""
    found = False
    for f in job.frames:
        if f.index != idx:
            continue
        for element in f.elements:
            if element.id != element_id:
                continue
            found = True
            if req.name_zh is not None:
                new_zh = req.name_zh.strip()
                if not new_zh:
                    raise HTTPException(400, "name_zh required")
                element.name_zh = new_zh
                changed_name = new_zh
            if req.name_en is not None:
                element.name_en = req.name_en.strip()
            if req.position is not None:
                element.position = req.position.strip()
    if not found:
        raise HTTPException(404, "element not found")
    update(job, frames=list(job.frames), message=f"更新元素 · 分镜 {idx + 1} · {changed_name or element_id}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def delete_element(job_id: str, idx: int, element_id: str) -> Job:
    """Remove an element from a frame together with its extracted images.

    Raises:
        HTTPException: 404 for an unknown job, or when no element with
            ``element_id`` was removed.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    stem = f"{idx:03d}_{element_id}"
    removed = False
    for f in job.frames:
        if f.index != idx:
            continue
        count_before = len(f.elements)
        f.elements = [e for e in f.elements if e.id != element_id]
        removed = len(f.elements) < count_before
        if not removed:
            continue
        # Also delete any extracted images, including multi-version cutouts.
        elements_dir = job_dir(job_id) / "elements"
        if not elements_dir.exists():
            continue
        for pattern in (f"{stem}.jpg", f"{stem}.png", f"{stem}_*.jpg"):
            for image_file in elements_dir.glob(pattern):
                try:
                    image_file.unlink()
                except OSError:
                    # Best effort: leftover files do not fail the request.
                    pass
    if not removed:
        raise HTTPException(404, "element not found")
    update(job, frames=list(job.frames), message=f"删除元素 · 分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
    """Extract an element as a standalone image; each call adds one cutout.

    The image-edit model is asked for a complete, sharp image of the element
    (reconstructing parts that are cropped or occluded in the source). The
    cleaned frame is used as the source when it exists, otherwise the key
    frame. For elements that carry a ``region``, the region plus 30% padding
    on each side is cropped first and sent to the model as the focus image;
    if cropping fails the full frame is used instead.

    Raises:
        HTTPException: 404 for an unknown job / frame / element or a missing
            source image, 500 when extraction fails.
    """
    from PIL import Image as _PILImage
    import tempfile as _tempfile

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    cleaned_path = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    src = cleaned_path if cleaned_path.exists() else job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    out_dir = job_dir(job_id) / "elements"
    out_dir.mkdir(parents=True, exist_ok=True)
    new_cutout_id = uuid.uuid4().hex[:8]
    out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg"

    tmp_focus: Path | None = None
    model_src = src
    try:
        if el.region:
            try:
                # Close the source file as soon as the pixels are converted.
                with _PILImage.open(src) as raw:
                    im = raw.convert("RGB")
                W, H = im.size
                r = el.region
                x = max(0.0, min(1.0, float(r.get("x", 0))))
                y = max(0.0, min(1.0, float(r.get("y", 0))))
                w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
                h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
                cx, cy = x + w / 2, y + h / 2
                # Grow the box by 30% per side (1.6x overall) for context, so
                # a tight crop does not lose the hints needed for completion.
                ew, eh = w * 1.6, h * 1.6
                x0 = max(0.0, cx - ew / 2)
                y0 = max(0.0, cy - eh / 2)
                x1 = min(1.0, cx + ew / 2)
                y1 = min(1.0, cy + eh / 2)
                left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
                # Skip crops too small to be useful.
                if right - left > 8 and bottom - top > 8:
                    cropped = im.crop((left, top, right, bottom))
                    fd, tmp_name = _tempfile.mkstemp(suffix=".jpg")
                    os.close(fd)
                    # Record the path before saving so the `finally` below
                    # removes the file even when the save fails.
                    tmp_focus = Path(tmp_name)
                    cropped.save(tmp_focus, format="JPEG", quality=92)
                    model_src = tmp_focus
            except Exception as e:
                print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True)
                model_src = src

        target = (el.name_en or el.name_zh).strip()
        prompt = (
            f"Identify the {target} in this image. "
            f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. "
            f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), "
            "intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. "
            "Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. "
            "Preserve the element's original color palette, style, lighting character, and proportions. "
            "Output must be a clean, high-quality asset image suitable for downstream composition."
        )
        models = [IMAGE_MODEL, "gemini-2.5-flash-image"]
        try:
            img_bytes, _mode = _image_edit_call(
                model_src, prompt, models=models, fallback_text=False, max_attempts=3,
            )
        except RuntimeError as e:
            raise HTTPException(500, f"extract failed: {e}") from e
    finally:
        if tmp_focus is not None:
            try:
                tmp_focus.unlink()
            except OSError:
                # Best effort: a leftover temp file must not mask the result.
                pass
    out_path.write_bytes(img_bytes)

    new_frames = []
    for f in job.frames:
        if f.index == idx:
            for e in f.elements:
                if e.id == element_id:
                    e.cutouts = (e.cutouts or []) + [new_cutout_id]
                    # Keep the single-cutout field pointing at the first one.
                    if not e.cutout_id:
                        e.cutout_id = new_cutout_id
        new_frames.append(f)
    update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}", response_model=Job)
def delete_cutout(job_id: str, idx: int, element_id: str, cutout_id: str) -> Job:
    """Delete one extracted image of an element.

    The image file is removed first (best effort), then the id is dropped
    from the element's ``cutouts`` list.

    Raises:
        HTTPException: 404 for an unknown job, or when ``cutout_id`` is not
            listed on the element.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    image_file = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    if image_file.exists():
        try:
            image_file.unlink()
        except OSError:
            pass

    removed = False
    for f in job.frames:
        if f.index != idx:
            continue
        for element in f.elements:
            if element.id != element_id:
                continue
            if cutout_id in (element.cutouts or []):
                element.cutouts = [c for c in element.cutouts if c != cutout_id]
                removed = True
            # `cutout_id` is a compatibility field: when it points at the
            # deleted image, move it to the first remaining cutout or clear it.
            if element.cutout_id == cutout_id:
                element.cutout_id = element.cutouts[0] if element.cutouts else None
    if not removed:
        raise HTTPException(404, "cutout not found in element")
    update(job, frames=list(job.frames), message="删除提取图")
    return job
|
||
|
||
|
||
class UpdateStoryboardReq(BaseModel):
    """Request body for updating a frame's storyboard."""

    duration: float = 0
    # Optional image references, one per storyboard slot.
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # v1 fields (the frontend may omit them).
    subject: str = ""
    product: str = ""
    scene: str = ""
    action: str = ""
    reference_ids: list[str] = []
|
||
|
||
|
||
class GenerateStoryboardVideoReq(BaseModel):
    """Request body for generating a storyboard video from a frame."""

    # Prompt text; must be non-blank.
    prompt: str
    # Requested length; the endpoint maps it through video_seconds().
    duration: float = 4
    # Optional image references; the endpoint uses the first one set, in the
    # order product, subject, scene, action.
    subject_image: dict | None = None
    scene_image: dict | None = None
    product_image: dict | None = None
    action_image: dict | None = None
    # Model id or alias; resolved by resolve_video_model().
    model: str = ""
    # Output size as "<width>x<height>".
    size: str = "720x1280"
|
||
|
||
|
||
def video_seconds(duration: float) -> str:
    """Map a requested duration to the seconds value sent to the provider.

    When ``video_uses_ark()`` is true the duration is rounded and clamped to
    4-15 seconds (``"5"`` for non-positive input). Otherwise it is snapped to
    one of the fixed buckets ``"4"``, ``"8"`` or ``"12"``.
    """
    if not video_uses_ark():
        for ceiling, bucket in ((6, "4"), (10, "8")):
            if duration <= ceiling:
                return bucket
        return "12"
    if duration <= 0:
        return "5"
    capped = min(15, round(duration))
    return str(max(4, capped))
|
||
|
||
|
||
def resolve_video_model(raw: str | None) -> str:
    """Resolve a requested video model name to a concrete model id.

    Falls back to ``VIDEO_MODEL`` and then ``"seedance"`` when ``raw`` is
    empty. Aliases are looked up case-insensitively in
    ``VIDEO_MODEL_ALIASES``; unknown names are returned as given (stripped).

    Raises:
        HTTPException: 400 when a Sora model is requested.
    """
    requested = (raw or VIDEO_MODEL or "seedance").strip()
    alias_key = requested.lower()
    disabled = ("sora", "sora-2", "sora_2")
    if alias_key in disabled:
        raise HTTPException(400, "Sora 已停用,请选择 Seedance / Kling / Veo 3")
    return VIDEO_MODEL_ALIASES.get(alias_key, requested)
|
||
|
||
|
||
def normalize_video_status(status: str | None) -> Literal["queued", "in_progress", "completed", "failed"]:
|
||
s = (status or "queued").lower()
|
||
if s in {"completed", "complete", "succeeded", "success", "done"}:
|
||
return "completed"
|
||
if s in {"failed", "failure", "error", "cancelled", "canceled", "expired"}:
|
||
return "failed"
|
||
if s in {"running", "processing", "in_progress", "generating", "started"}:
|
||
return "in_progress"
|
||
return "queued"
|
||
|
||
|
||
def video_progress(data: dict, fallback: int) -> int:
    """Read a progress percentage from a provider response.

    The keys ``progress``, ``percentage`` and ``percent`` are tried in that
    order and the first value that converts to a number wins, so a null or
    unparsable ``progress`` no longer hides a usable ``percentage`` /
    ``percent``. Without a usable value ``fallback`` is used.

    Args:
        data: Parsed provider response.
        fallback: Value to use when the response carries no usable progress.

    Returns:
        The progress truncated to an int and clamped to 0-100.
    """
    for key in ("progress", "percentage", "percent"):
        try:
            value = int(float(data.get(key)))
        except (TypeError, ValueError, OverflowError):
            # Missing / null / non-numeric / NaN / infinite: try the next key.
            continue
        return max(0, min(100, value))
    try:
        value = int(float(fallback))
    except (TypeError, ValueError, OverflowError):
        value = fallback
    return max(0, min(100, value))
|
||
|
||
|
||
def video_url_from_response(data: dict) -> str:
    """Find a video URL in a provider response, or return ``""``.

    Lookup order: the top level, the first entry of a ``data`` list, the
    ``output`` dict, then the ``content`` dict. Each container has its own
    ordered list of candidate keys, and only non-empty strings count.
    """

    def first_url(container: dict, keys: tuple) -> str:
        for key in keys:
            candidate = container.get(key)
            if isinstance(candidate, str) and candidate:
                return candidate
        return ""

    common_keys = ("url", "video_url", "output_url", "download_url")
    searches = [(data, common_keys)]

    items = data.get("data")
    if isinstance(items, list) and items and isinstance(items[0], dict):
        searches.append((items[0], common_keys))

    output = data.get("output")
    if isinstance(output, dict):
        searches.append((output, ("url", "video_url", "download_url")))

    content = data.get("content")
    if isinstance(content, dict):
        searches.append((content, ("video_url", "url", "download_url", "file_url")))

    for container, keys in searches:
        hit = first_url(container, keys)
        if hit:
            return hit
    return ""
|
||
|
||
|
||
def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path) -> None:
    """Download a finished video and write it to ``out_mp4``.

    Args:
        client: HTTP client exposing ``get(url, headers=...)``.
        base: Base URL of the video API.
        headers: Auth headers for the video API.
        provider_id: Provider-side video id, used for the content endpoint
            when no direct URL is available.
        direct_url: URL reported by the provider (absolute, or relative to
            ``base``); empty to use the content endpoint.
        out_mp4: Destination file.

    Raises:
        Whatever ``raise_for_status()`` raises on an HTTP error response.
    """
    if direct_url:
        url = direct_url if direct_url.startswith("http") else f"{base}{direct_url if direct_url.startswith('/') else '/' + direct_url}"
        # Only send credentials to URLs that live under the API base. The
        # trailing slash matters: a bare prefix match on "https://host/v1"
        # would also match "https://host/v1.evil.example/..." and leak the
        # bearer token to a different host.
        api_prefix = base.rstrip("/") + "/"
        r = client.get(url, headers=headers if url.startswith(api_prefix) else None)
    else:
        r = client.get(f"{base}{video_path(VIDEO_CONTENT_PATH, id=provider_id)}", headers=headers)
    r.raise_for_status()
    out_mp4.write_bytes(r.content)
|
||
|
||
|
||
def size_to_video_ratio(size: str) -> str:
    """Map a ``"<width>x<height>"`` size to the nearest known aspect ratio.

    Returns ``"9:16"`` when the size cannot be parsed or has a non-positive
    dimension. On equal distance the ratio listed first wins.
    """
    default = "9:16"
    try:
        width_text, height_text = size.lower().replace(" ", "").split("x", 1)
        width, height = int(width_text), int(height_text)
    except Exception:
        return default
    if width <= 0 or height <= 0:
        return default

    actual = width / height
    candidates = (
        ("16:9", 16 / 9),
        ("9:16", 9 / 16),
        ("1:1", 1),
        ("4:3", 4 / 3),
        ("3:4", 3 / 4),
        ("21:9", 21 / 9),
    )
    best_label, best_gap = candidates[0][0], abs(candidates[0][1] - actual)
    for label, value in candidates[1:]:
        gap = abs(value - actual)
        # Strict comparison keeps the earliest candidate on ties.
        if gap < best_gap:
            best_label, best_gap = label, gap
    return best_label
|
||
|
||
|
||
def ark_reference_data_url(ref_img: Path) -> str:
    """Encode an image file as a base64 ``data:`` URL.

    The MIME type is ``image/png`` for a ``.png`` suffix (case-insensitive)
    and ``image/jpeg`` for everything else.
    """
    is_png = ref_img.suffix.lower() == ".png"
    mime = "image/png" if is_png else "image/jpeg"
    payload = base64.b64encode(ref_img.read_bytes()).decode("ascii")
    return f"data:{mime};base64,{payload}"
|
||
|
||
|
||
def submit_video_create(client, url: str, headers: dict, ref_img: Path, payload: dict):
    """Send the video-creation request in the shape the active provider expects.

    Three request shapes exist, checked in this order:

    * ``video_uses_ark()``: JSON body with a text part and the reference
      image inlined as a ``first_frame`` data URL.
    * ``video_uses_poe()``: JSON copy of ``payload`` with an integer duration
      and the reference image as base64 in ``input_image``.
    * otherwise: multipart form with ``payload`` as form fields and the
      reference image uploaded as ``input_reference``.

    Returns:
        The response object returned by ``client.post``.
    """
    if video_uses_ark():
        first_frame = {
            "type": "image_url",
            "image_url": {"url": ark_reference_data_url(ref_img)},
            "role": "first_frame",
        }
        body = {
            "model": payload["model"],
            "content": [{"type": "text", "text": payload["prompt"]}, first_frame],
            "ratio": size_to_video_ratio(str(payload.get("size", ""))),
            "duration": int(float(str(payload.get(VIDEO_DURATION_FIELD, 5)))),
            "watermark": False,
            "resolution": "720p",
        }
        json_headers = {**headers, "Content-Type": "application/json"}
        return client.post(url, headers=json_headers, json=body)

    if video_uses_poe():
        body = dict(payload)
        raw_duration = body.get(VIDEO_DURATION_FIELD, 4)
        # The duration arrives as a string such as "4"; send it as an integer.
        body[VIDEO_DURATION_FIELD] = int(float(str(raw_duration)))
        body["input_image"] = base64.b64encode(ref_img.read_bytes()).decode("ascii")
        return client.post(url, headers=headers, json=body)

    with ref_img.open("rb") as image_handle:
        upload = {"input_reference": ("reference.jpg", image_handle, "image/jpeg")}
        return client.post(url, headers=headers, data=payload, files=upload)
|
||
|
||
|
||
def render_storyboard_video(job_id: str, local_id: str, provider_id: str, ref_path: Path, prompt: str, model: str, seconds: str, size: str) -> None:
    """Background task: create a video at the provider, poll it, download it.

    Progress and the final outcome are recorded through
    ``update_generated_video``. Exceptions raised inside the workflow are
    caught and stored as a ``failed`` status with the error text truncated to
    500 characters.

    Args:
        job_id: Owning job.
        local_id: Local id of the generated-video record; also the name of
            the output folder under ``storyboard_videos``.
        provider_id: Fallback provider id, used when the create response
            carries no ``id``.
        ref_path: Source reference image.
        prompt: Prompt text sent to the provider.
        model: Resolved model id.
        seconds: Duration value stored under ``VIDEO_DURATION_FIELD``.
        size: Output size as ``"<width>x<height>"``.
    """
    import httpx

    work_dir = job_dir(job_id) / "storyboard_videos" / local_id
    ref_img = work_dir / "reference.jpg"
    out_mp4 = work_dir / "video.mp4"
    base = video_api_base()
    headers = {"Authorization": f"Bearer {video_api_key()}"}

    try:
        prepare_video_reference(ref_path, ref_img)
        update_generated_video(job_id, local_id, status="in_progress", progress=5)
        with httpx.Client(timeout=120) as client:
            payload = {"model": model, "prompt": prompt, "size": size}
            payload[VIDEO_DURATION_FIELD] = seconds

            # Try each candidate create endpoint; 400 / 404 / 405 mean "not
            # this one", any other error status aborts immediately.
            create_errors: list[str] = []
            for create_path in VIDEO_CREATE_PATHS:
                endpoint = video_path(create_path)
                resp = submit_video_create(client, f"{base}{endpoint}", headers, ref_img, payload)
                if resp.status_code < 400:
                    created = resp
                    break
                create_errors.append(f"{video_path(create_path)} -> HTTP {resp.status_code}: {resp.text[:160]}")
                if resp.status_code not in {400, 404, 405}:
                    resp.raise_for_status()
            else:
                raise RuntimeError("视频模型已选择,但当前网关视频生成入口不可用;已尝试 " + " | ".join(create_errors))

            created_data = created.json()
            video_api_id = created_data.get("id") or provider_id or local_id
            status = normalize_video_status(created_data.get("status"))
            progress = video_progress(created_data, 5)
            direct_url = video_url_from_response(created_data)
            update_generated_video(job_id, local_id, provider_id=video_api_id, status=status, progress=progress)

            # Poll every 8 seconds for up to 420 seconds.
            deadline = time.time() + 420
            while status in {"queued", "in_progress"} and time.time() < deadline:
                time.sleep(8)
                status_url = f"{base}{video_path(VIDEO_STATUS_PATH, id=video_api_id)}"
                poll = client.get(status_url, headers=headers)
                poll.raise_for_status()
                polled = poll.json()
                status = normalize_video_status(polled.get("status"))
                progress = video_progress(polled, progress)
                # Keep the last URL seen if this poll does not report one.
                direct_url = video_url_from_response(polled) or direct_url
                update_generated_video(job_id, local_id, status=status, progress=progress)

            if status != "completed":
                # Covers provider-side failure as well as running out of time.
                update_generated_video(job_id, local_id, status="failed", error=f"video status: {status}", progress=progress)
                return

            download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4)
            update_generated_video(
                job_id,
                local_id,
                status="completed",
                progress=100,
                url=f"/jobs/{job_id}/storyboard-videos/{local_id}.mp4",
                error="",
            )
    except Exception as e:
        update_generated_video(job_id, local_id, status="failed", error=str(e)[:500])
|
||
|
||
|
||
@app.post("/jobs/{job_id}/frames/{idx}/storyboard/video", response_model=Job)
def generate_storyboard_video(job_id: str, idx: int, req: GenerateStoryboardVideoReq, bg: BackgroundTasks) -> Job:
    """Queue a video-generation task for one storyboard frame.

    Validates the job, frame, video API configuration and prompt, picks a
    reference image, records a ``queued`` GeneratedVideo at the front of the
    job's list and schedules ``render_storyboard_video`` as a background task.

    Raises:
        HTTPException: 404 if the job, the frame or the reference image is
            missing; 400 if the prompt is empty after stripping.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")
    ensure_video_api_configured()

    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(400, "prompt required")

    # First non-empty reference wins: product, then subject, scene, action.
    ref = (
        req.product_image
        or req.subject_image
        or req.scene_image
        or req.action_image
    )
    # Fall back to the frame's own JPEG when the reference does not resolve.
    frame_jpg = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    ref_path = storyboard_ref_path(job_id, ref) or frame_jpg
    if not ref_path.exists():
        raise HTTPException(404, "reference image missing")
    poster = storyboard_ref_url(job_id, ref) or f"/jobs/{job_id}/frames/{idx}.jpg"

    local_id = uuid.uuid4().hex[:12]
    model = resolve_video_model(req.model)
    seconds = video_seconds(float(req.duration or 4))

    # provider_id stays empty here; the entry starts as queued with 0 progress.
    queued = GeneratedVideo(
        id=local_id,
        provider_id="",
        frame_idx=idx,
        prompt=prompt,
        model=model,
        status="queued",
        url="",
        poster_url=poster,
        duration=float(seconds),
        progress=0,
        created_at=time.time(),
    )
    update(
        job,
        generated_videos=[queued, *job.generated_videos],
        message=f"视频生成已提交 · 分镜 {idx + 1}",
    )
    bg.add_task(render_storyboard_video, job_id, local_id, "", ref_path, prompt, model, seconds, req.size)
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4")
def get_storyboard_video(job_id: str, video_id: str):
    """Serve the MP4 file of one generated storyboard video.

    Args:
        job_id: Job whose directory holds the video.
        video_id: Id of the generated video; used as a directory name under
            ``storyboard_videos``.

    Returns:
        A ``FileResponse`` for ``storyboard_videos/<video_id>/video.mp4``.

    Raises:
        HTTPException: 404 when ``video_id`` is not a plain directory name or
            the file does not exist.
    """
    # video_id becomes a path component. Reject "." / ".." and anything that
    # contains a path separator so the lookup cannot leave storyboard_videos/.
    if video_id in {"", ".", ".."} or Path(video_id).name != video_id:
        raise HTTPException(404, "storyboard video not found")
    p = job_dir(job_id) / "storyboard_videos" / video_id / "video.mp4"
    if not p.exists():
        raise HTTPException(404, "storyboard video not found")
    return FileResponse(p, media_type="video/mp4")
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-videos/{video_id}", response_model=Job)
def delete_storyboard_video(job_id: str, video_id: str) -> Job:
    """Delete one video task from the Video Gen node, whatever its status.

    Removes the entry from ``job.generated_videos`` and, on a best-effort
    basis, the ``storyboard_videos/<video_id>`` directory.

    Raises:
        HTTPException: 404 if the job or the generated video is not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    remaining = []
    removed = None
    for video in job.generated_videos:
        if video.id != video_id:
            remaining.append(video)
        elif removed is None:
            # Remember the first match so the message can name its frame.
            removed = video
    if len(remaining) == len(job.generated_videos):
        raise HTTPException(404, "generated video not found")

    target_dir = job_dir(job_id) / "storyboard_videos" / video_id
    if target_dir.exists():
        try:
            shutil.rmtree(target_dir)
        except OSError:
            # Best effort: the list entry is removed even if files remain.
            pass

    if removed:
        msg = f"删除视频任务 · 分镜 {removed.frame_idx + 1}"
    else:
        msg = "删除视频任务"
    update(job, generated_videos=remaining, message=msg)
    return job
|
||
|
||
|
||
@app.put("/jobs/{job_id}/frames/{idx}/storyboard", response_model=Job)
def update_storyboard(job_id: str, idx: int, req: UpdateStoryboardReq) -> Job:
    """Replace the storyboard fields of one frame.

    Covers subject / product / scene / action (text and image), duration and
    reference_ids. Text fields are stripped and duration is clamped to >= 0.

    Raises:
        HTTPException: 404 if the job or the frame is not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    def _build_scene() -> StoryboardScene:
        # Builds a fresh scene from the request on every call.
        return StoryboardScene(
            duration=max(0.0, float(req.duration)),
            subject_image=req.subject_image,
            scene_image=req.scene_image,
            product_image=req.product_image,
            action_image=req.action_image,
            subject=req.subject.strip(),
            product=req.product.strip(),
            scene=req.scene.strip(),
            action=req.action.strip(),
            reference_ids=list(req.reference_ids),
        )

    for current in job.frames:
        if current.index == idx:
            current.storyboard = _build_scene()
    # Pass a new list holding the same frame objects, in their original order.
    update(job, frames=list(job.frames), message=f"分镜 {idx + 1} 编排已更新")
    return job
|
||
|
||
|
||
# Request body for pushing an image into the storyboard arrangement area.
class PushStoryboardImageReq(BaseModel):
    # Either the key frame itself or an extracted element image.
    kind: Literal["keyframe", "cutout"]
    frame_idx: int
    # Both ids are optional and default to None.
    element_id: str | None = Field(default=None)
    cutout_id: str | None = Field(default=None)
    # Optional display label; defaults to the empty string.
    label: str = Field(default="")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/storyboard-images", response_model=Job)
def push_storyboard_image(job_id: str, req: PushStoryboardImageReq) -> Job:
    """Push one image (a key frame or an element cutout) to the storyboard area.

    The push is idempotent: if an entry with the same kind, frame_idx,
    element_id and cutout_id already exists, the job is returned unchanged.

    Args:
        job_id: Target job.
        req: Describes the image; ``label`` is stripped before it is stored.

    Returns:
        The job, with the new ``StoryboardImage`` appended when one was added.

    Raises:
        HTTPException: 404 if the job is not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    # Duplicate guard: skip when the same (kind, frame, element, cutout) exists.
    key = (req.kind, req.frame_idx, req.element_id, req.cutout_id)
    for existing in job.storyboard_images:
        if (existing.kind, existing.frame_idx, existing.element_id, existing.cutout_id) == key:
            return job

    # Strip once and reuse, so a whitespace-only label falls back to the kind
    # in the message instead of producing a blank one.
    label = req.label.strip()
    img = StoryboardImage(
        ref_id=uuid.uuid4().hex[:8],
        kind=req.kind,
        frame_idx=req.frame_idx,
        element_id=req.element_id,
        cutout_id=req.cutout_id,
        label=label,
        created_at=time.time(),
    )
    update(
        job,
        storyboard_images=job.storyboard_images + [img],
        message=f"上推到分镜头编排 · {label or req.kind}",
    )
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/storyboard-images/{ref_id}", response_model=Job)
def remove_storyboard_image(job_id: str, ref_id: str) -> Job:
    """Remove one image from the storyboard arrangement area.

    Raises:
        HTTPException: 404 if the job is not found or no image has ``ref_id``.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if all(entry.ref_id != ref_id for entry in job.storyboard_images):
        raise HTTPException(404, "storyboard image not found")
    remaining = [entry for entry in job.storyboard_images if entry.ref_id != ref_id]
    update(job, storyboard_images=remaining, message="从分镜头编排移除一张图")
    return job
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg")
def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str):
    """Serve one cutout image, addressed by frame, element and cutout id.

    Raises:
        HTTPException: 404 if the file does not exist.
    """
    filename = f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    cutout_path = job_dir(job_id) / "elements" / filename
    if cutout_path.exists():
        return FileResponse(cutout_path, media_type="image/jpeg")
    raise HTTPException(404, "cutout not found")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
    """Legacy (v1, single-image) cutout route.

    Looks for ``elements/{idx:03d}_{element_id}.jpg`` first and falls back to
    the ``.png`` file of the same stem. Each file is served with the media
    type matching its extension.

    Raises:
        HTTPException: 404 if neither file exists.
    """
    elements = job_dir(job_id) / "elements"
    stem = f"{idx:03d}_{element_id}"
    candidates = (
        (elements / f"{stem}.jpg", "image/jpeg"),
        # The legacy PNG was previously labelled image/jpeg.
        (elements / f"{stem}.png", "image/png"),
    )
    for path, media_type in candidates:
        if path.exists():
            return FileResponse(path, media_type=media_type)
    raise HTTPException(404, "cutout not found")
|
||
|
||
|
||
# ---------- 删除:关键帧 / 单张生成图 ----------
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}", response_model=Job)
def delete_frame(job_id: str, idx: int) -> Job:
    """Delete a whole key frame and the files that belong to it.

    Removes the original and cleaned frame JPEGs, every element cutout and
    every generated image whose file name starts with ``{idx:03d}_``, then
    drops the frame from ``job.frames``. File removal is best effort.

    Raises:
        HTTPException: 404 if the job or the frame is not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    root = job_dir(job_id)
    prefix = f"{idx:03d}"

    # Collect everything to remove first, then unlink in one pass.
    stale = [
        candidate
        for candidate in (
            root / "frames" / f"{prefix}.jpg",
            root / "cleaned" / f"{prefix}.jpg",
        )
        if candidate.exists()
    ]
    # Element cutouts of this frame; both extensions are matched.
    elements_dir = root / "elements"
    if elements_dir.exists():
        for ext in ("png", "jpg"):
            stale.extend(elements_dir.glob(f"{prefix}_*.{ext}"))
    # Generated images of this frame.
    gen_dir = root / "gen"
    if gen_dir.exists():
        stale.extend(gen_dir.glob(f"{prefix}_*.jpg"))

    for path in stale:
        try:
            path.unlink()
        except OSError:
            # Ignored on purpose: the file may already be gone.
            pass

    survivors = [f for f in job.frames if f.index != idx]
    update(job, frames=survivors, message=f"删除分镜 {idx + 1}")
    return job
|
||
|
||
|
||
@app.delete("/jobs/{job_id}/frames/{idx}/gen/{gen_id}", response_model=Job)
def delete_generated(job_id: str, idx: int, gen_id: str) -> Job:
    """Delete one generated image of a frame: the file and the list entry.

    Raises:
        HTTPException: 404 if the job, the frame, or the generated image is
            not found.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    # The file is removed (best effort) before the list entry is looked up.
    image_path = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_path.exists():
        try:
            image_path.unlink()
        except OSError:
            pass

    found = False
    for current in job.frames:
        if current.index != idx:
            continue
        remaining = [g for g in current.generated_images if g.id != gen_id]
        found = len(remaining) < len(current.generated_images)
        current.generated_images = remaining
    if not found:
        raise HTTPException(404, "generated image not found")

    # Pass a new list holding the same frame objects, in their original order.
    update(job, frames=list(job.frames), message=f"删除生成图 · 分镜 {idx + 1}")
    return job
|