Files
20260512-skg-tk/api/main.py
2026-05-13 11:45:33 +08:00

1370 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import asyncio
import json
import os
import shutil
import subprocess
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
load_dotenv()
# ---- Runtime configuration, read once from the environment at import time ----

# Root folder that holds one sub-directory per job; created eagerly.
JOBS_DIR = Path(os.environ.get("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)

# Allowed browser origins: comma-separated, whitespace trimmed, blank entries dropped.
CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ORIGINS", "http://localhost:4290").split(",")
    if origin.strip()
]

# Endpoint and credential of the OpenAI-compatible LLM gateway (empty string when unset).
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.environ.get("LLM_API_KEY", "").strip()

# Model name used by each pipeline stage.
ASR_MODEL = os.environ.get("ASR_MODEL", "whisper-1")
TRANSLATE_MODEL = os.environ.get("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.environ.get("REWRITE_MODEL", "gemini-2.5-pro")
VISION_MODEL = os.environ.get("VISION_MODEL", "gemini-2.5-flash")
IMAGE_MODEL = os.environ.get("IMAGE_MODEL", "gemini-3-pro-image-preview")
# OpenAI 客户端OpenAI 兼容网关,含 SKG ezlink
from openai import OpenAI
# Process-wide client instance; stays None until llm() is first called.
_llm_client: OpenAI | None = None


def llm() -> OpenAI:
    """Return the shared OpenAI-compatible client, building it on first use.

    Raises:
        RuntimeError: if LLM_API_KEY is empty.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    # An empty base URL is passed as None so the SDK falls back to its own default.
    _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY)
    return _llm_client
# Pipeline states:
#   created -> downloading -> downloaded (waits for the user to trigger analysis)
#   -> splitting -> frames_extracted -> transcribing -> transcribed | failed
JobStatus = Literal[
    "created",
    "downloading",
    "downloaded",
    "splitting",
    "frames_extracted",
    "transcribing",
    "transcribed",
    "failed",
]

# Default number of key frames to extract per video (env override: KEYFRAME_COUNT).
KEYFRAME_COUNT = int(os.environ.get("KEYFRAME_COUNT", "5"))
class GeneratedImage(BaseModel):
    """One image produced by the image model for a key frame."""

    # 12-character uuid hex
    id: str
    prompt: str
    model: str
    # "edit" (reference image supplied) | "text" (prompt only)
    mode: str = "edit"
    # /jobs/{job_id}/frames/{idx}/gen/{id}.jpg
    url: str
    selected: bool = False
    created_at: float = 0.0
class KeyElement(BaseModel):
    """An element recognised in (or manually picked from) a key frame.

    Each element can be cut out on its own as a material layer for downstream
    re-creation.
    """

    # 8-character uuid hex
    id: str
    name_zh: str
    name_en: str = ""
    position: str = ""
    source: Literal["auto", "manual", "region"] = "manual"
    region: dict | None = None
    # Cut-out: the model emits JPEG, which has no real transparency, so the
    # element is rendered on a pure white or pure black background instead.
    # Set once a cut-out exists -> /jobs/{id}/frames/{idx}/elements/{element_id}/cutout.jpg
    cutout_id: str | None = None
    cutout_background: Literal["white", "black"] = "white"
    created_at: float = 0.0
class KeyFrame(BaseModel):
    """A key frame extracted from the source video plus everything derived from it."""

    index: int
    timestamp: float
    url: str
    # Vision-model result: {scene, objects, style, suggested_prompt}
    description: dict | None = None
    # Cleaned version waiting to be applied -> /jobs/{id}/frames/{idx}/cleaned.jpg
    cleaned_url: str | None = None
    # True once the cleaned version has replaced the original (cleaned_url is then None)
    cleaned_applied: bool = False
    # Extracted elements (persisted). default_factory matches how Job declares
    # its list fields instead of relying on a shared literal default.
    elements: list[KeyElement] = Field(default_factory=list)
    generated_images: list[GeneratedImage] = Field(default_factory=list)
class TranscriptSegment(BaseModel):
    """One timed transcript segment: the English text and its Chinese translation."""

    index: int
    # Segment boundaries as returned by the ASR call.
    start: float
    end: float
    en: str
    # Stays empty until the translation step has filled it in.
    zh: str = ""
class Job(BaseModel):
    """Full state of one processing job; serialised to state.json by save_state()."""

    id: str
    # Source URL, or "upload://<filename>" for uploaded files.
    url: str

    # Pipeline bookkeeping shown to the client.
    status: JobStatus = "created"
    progress: int = 0
    message: str = ""
    error: str = ""

    # Video metadata, filled in by pipeline_download.
    video_url: str = ""
    duration: float = 0.0
    width: int = 0
    height: int = 0

    # Analysis results.
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
# In-memory registry of jobs keyed by job id; repopulated from state.json files in lifespan().
JOBS: dict[str, Job] = {}
def job_dir(job_id: str) -> Path:
    """Return the working directory of *job_id* under JOBS_DIR, creating it if missing."""
    path = JOBS_DIR.joinpath(job_id)
    path.mkdir(parents=True, exist_ok=True)
    return path
def save_state(job: Job) -> None:
    """Write *job* as indented JSON to ``state.json`` inside its job directory."""
    state_file = job_dir(job.id) / "state.json"
    payload = job.model_dump_json(indent=2)
    state_file.write_text(payload)
def update(job: Job, **kw) -> None:
    """Assign every keyword argument as an attribute on *job*, then persist it."""
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    save_state(job)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Restore jobs from disk at startup (simplified: one directory per job).

    Every sub-directory of JOBS_DIR containing a ``state.json`` is loaded into
    JOBS. Unreadable state files are reported and skipped instead of being
    ignored silently. Jobs saved in an in-flight status cannot make progress
    after a restart because their background task is gone, so they are marked
    ``failed``; trigger_analyze accepts that status, which makes them
    retryable.
    """
    in_flight = {"created", "downloading", "splitting", "transcribing"}
    for p in JOBS_DIR.iterdir():
        state_file = p / "state.json"
        if not (p.is_dir() and state_file.exists()):
            continue
        try:
            job = Job.model_validate_json(state_file.read_text())
        except Exception as e:
            # Best effort: one corrupt state file must not prevent startup.
            print(f"[restore] skipped {p.name}: {type(e).__name__}: {e}", flush=True)
            continue
        JOBS[p.name] = job
        if job.status in in_flight:
            try:
                update(
                    job,
                    status="failed",
                    error="interrupted by server restart",
                    message="服务重启 · 任务中断",
                )
            except OSError as e:
                print(f"[restore] could not persist {p.name}: {e}", flush=True)
    yield
# The FastAPI application; `lifespan` is registered as its lifespan handler.
app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan)
# CORS: only the origins in CORS_ORIGINS are allowed, with credentials;
# every HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------- Pipeline 实现 ----------
def run(cmd: list[str], cwd: Path | None = None) -> str:
    """Run *cmd* and return its stdout as text.

    Raises:
        RuntimeError: on a non-zero exit code. The message carries the first
            three argv items and the tail of stderr.
    """
    proc = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    if proc.returncode == 0:
        return proc.stdout
    # ffmpeg prints its banner to stderr; the real error is usually in the
    # last few lines, so keep only those.
    stderr_lines = proc.stderr.splitlines()
    tail = "\n".join(stderr_lines[-12:]) or proc.stderr[-800:]
    raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}")
# ---- 启发式选帧工具 ----
import imagehash
import numpy as np
from PIL import Image
def _sharpness(img_path: Path) -> float:
    """Return the variance of a discrete Laplacian of the image.

    The image is converted to greyscale and resized to 320x180 first. Larger
    values mean a sharper picture; blurry or transition frames score low.
    """
    # Context manager so the file handle is released as soon as the pixels
    # have been copied into the array.
    with Image.open(img_path) as img:
        g = np.asarray(img.convert("L").resize((320, 180)), dtype=np.float32)
    # 4-neighbour Laplacian evaluated on the interior pixels.
    lap = (
        -4 * g[1:-1, 1:-1]
        + g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:]
    )
    return float(lap.var())
def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) -> list[Path]:
"""
candidates: 按时间排序的候选帧路径
n: 目标帧数
dup_threshold: pHash 汉明距离 < 此值视为相似(默认 864bit hash 大致 ~12.5% 像素差)
"""
if len(candidates) <= n:
return candidates
# 算 pHash + sharpness
items = []
for i, p in enumerate(candidates):
try:
img = Image.open(p)
h = imagehash.phash(img)
s = _sharpness(p)
items.append({"path": p, "idx": i, "hash": h, "sharp": s})
except Exception:
continue
# 去重:相似帧保留 sharpness 高的
deduped: list[dict] = []
for it in items:
dup = None
for kept in deduped:
if (it["hash"] - kept["hash"]) < dup_threshold:
dup = kept
break
if dup is None:
deduped.append(it)
elif it["sharp"] > dup["sharp"]:
deduped[deduped.index(dup)] = it
# 时序分桶:把候选时间轴等分 n 段,每段取去重后 sharpness 最高的
total = len(candidates)
buckets: list[list[dict]] = [[] for _ in range(n)]
for it in deduped:
b = min(int(it["idx"] * n / total), n - 1)
buckets[b].append(it)
selected: list[dict] = []
for b in buckets:
if b:
selected.append(max(b, key=lambda x: x["sharp"]))
# 空桶补足:从未选的 deduped 里按 sharpness 排序补
chosen_paths = {it["path"] for it in selected}
remaining = sorted([it for it in deduped if it["path"] not in chosen_paths],
key=lambda x: -x["sharp"])
while len(selected) < n and remaining:
selected.append(remaining.pop(0))
# 按时间排序输出
selected.sort(key=lambda x: x["idx"])
return [it["path"] for it in selected]
def ffprobe_meta(mp4: Path) -> dict:
    """Return ffprobe's JSON report (streams and format sections) for *mp4*."""
    command = ["ffprobe", "-v", "error", "-print_format", "json"]
    command += ["-show_streams", "-show_format", str(mp4)]
    return json.loads(run(command))
async def pipeline_download(job_id: str) -> None:
    """Stage 1: make sure ``source.mp4`` exists, then probe it.

    The download is skipped when the file is already there (upload case).
    On success the job stops at ``downloaded`` until the user triggers
    analysis. Any error marks the job ``failed``; nothing is raised.

    The external tools are blocking, so they run in worker threads. Calling
    them directly from this coroutine would stall the event loop, and with it
    every other request, for the whole download.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if mp4.exists():
            update(job, status="downloading", message="本地上传 · 跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            await asyncio.to_thread(run, [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                # "--" ends option parsing, so a user-supplied URL that starts
                # with "-" cannot be read as a yt-dlp option.
                "--", job.url,
            ])
        if not mp4.exists():
            raise RuntimeError("下载完成但找不到 source.mp4")
        meta = await asyncio.to_thread(ffprobe_meta, mp4)
        v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None)
        duration = float(meta["format"]["duration"])
        update(
            job,
            status="downloaded",
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            width=int(v_stream["width"]) if v_stream else 0,
            height=int(v_stream["height"]) if v_stream else 0,
            progress=25,
            message=f"视频就绪 · {duration:.1f}s · 等待解析",
        )
    except Exception as e:
        update(job, status="failed", error=str(e), message="下载失败")
def _sample_candidates(mp4: Path, cand_dir: Path, duration: float, count: int) -> list[tuple[Path, float]]:
    """Grab *count* evenly spaced frames from *mp4* into *cand_dir*; return (path, timestamp) pairs."""
    step = duration / (count + 1)
    found: list[tuple[Path, float]] = []
    for i in range(count):
        t = step * (i + 1)
        out = cand_dir / f"c_{i:03d}.jpg"
        # "-ss" before "-i" makes ffmpeg seek on the input side.
        run([
            "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
            "-frames:v", "1",
            "-pix_fmt", "yuvj420p", "-q:v", "3",
            str(out),
        ])
        if out.exists():
            found.append((out, t))
    return found


async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None:
    """Stage 2: split the audio track, extract key frames, then run ASR + translation.

    Requires ``source.mp4`` to exist. *frame_count* is clamped to 1..20. Any
    error marks the job ``failed``; nothing is raised.

    All ffmpeg invocations and the image-based frame selection are blocking,
    so they run in worker threads to keep the event loop responsive.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if not mp4.exists():
            raise RuntimeError("source.mp4 不存在,先完成下载")

        update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35)
        wav = d / "audio.wav"
        # Mono, 16 kHz, 16-bit PCM.
        await asyncio.to_thread(run, [
            "ffmpeg", "-y", "-i", str(mp4),
            "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
            str(wav),
        ])

        n = max(1, min(int(frame_count), 20))
        # Candidate count: 6x the target, at least 24, capped at 60.
        candidate_count = max(24, min(60, n * 6))
        update(job, message=f"抽取候选 {candidate_count} 张…", progress=45)

        # Start from empty output directories on every run.
        frames_dir = d / "frames"
        cand_dir = d / "candidates"
        for folder in (frames_dir, cand_dir):
            if folder.exists():
                shutil.rmtree(folder)
            folder.mkdir(parents=True)

        # 1) Evenly sample a large batch of candidates.
        duration = max(float(job.duration or 1.0), 0.1)
        candidate_meta = await asyncio.to_thread(
            _sample_candidates, mp4, cand_dir, duration, candidate_count
        )

        # 2) Heuristic selection: pHash de-dup + Laplacian sharpness + time buckets.
        update(job, message=f"启发式筛选 {n} / {len(candidate_meta)} 张…", progress=60)
        cand_paths = [m[0] for m in candidate_meta]
        ts_by_path = {m[0]: m[1] for m in candidate_meta}
        chosen = await asyncio.to_thread(_select_keyframes, cand_paths, n)

        # 3) Copy the winners to frames/<idx>.jpg in timestamp order.
        renamed: list[KeyFrame] = []
        chosen_sorted = sorted(chosen, key=lambda p: ts_by_path[p])
        for i, src in enumerate(chosen_sorted):
            shutil.copyfile(src, frames_dir / f"{i:03d}.jpg")
            renamed.append(KeyFrame(
                index=i,
                timestamp=round(ts_by_path[src], 2),
                url=f"/jobs/{job_id}/frames/{i}.jpg",
            ))

        # 4) Drop the candidate directory.
        shutil.rmtree(cand_dir, ignore_errors=True)
        update(
            job,
            status="frames_extracted",
            frames=renamed,
            progress=70,
            message=f"已抽取 {len(renamed)} 张关键帧",
        )
        # Continue straight into ASR + translation.
        await pipeline_transcribe(job_id)
    except Exception as e:
        update(job, status="failed", error=str(e), message="解析失败")
# ---------- Gemini ASR + 翻译 ----------
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* with ASR_MODEL (verbose_json) and return its segments.

    Returns a list of dicts carrying ``start``, ``end`` and ``text``. If the
    gateway sends no segments but does send a full text, that text is wrapped
    in one segment spanning the reported duration.
    """
    with wav.open("rb") as audio:
        resp = llm().audio.transcriptions.create(
            file=(wav.name, audio, "audio/wav"),
            model=ASR_MODEL,
            response_format="verbose_json",
            timestamp_granularities=["segment"],
        )
    raw = resp.model_dump() if hasattr(resp, "model_dump") else resp
    segments = raw.get("segments") or []
    if segments or not raw.get("text"):
        return segments
    # Fallback: the gateway returned no segments, so use the whole text as one.
    whole = {
        "start": 0.0,
        "end": float(raw.get("duration", 0) or 0),
        "text": raw["text"],
    }
    return [whole]
def _translate_sync(segments: list[dict]) -> list[str]:
"""gemini-2.5-flash 批量翻译为中文,按段返回"""
payload = [{"i": i, "en": s.get("text", "").strip()} for i, s in enumerate(segments)]
prompt = (
"你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
"严格返回 JSON 数组,不要任何 markdown 或多余文字schema: "
'[{"i": 0, "zh": "..."}, ...]\n\n输入:\n'
+ json.dumps(payload, ensure_ascii=False)
)
resp = llm().chat.completions.create(
model=TRANSLATE_MODEL,
messages=[{"role": "user", "content": prompt}],
response_format={"type": "json_object"},
temperature=0.2,
)
content = resp.choices[0].message.content or "[]"
try:
data = json.loads(content)
if isinstance(data, dict):
for k in ("data", "items", "result", "translations"):
if k in data and isinstance(data[k], list):
data = data[k]
break
if not isinstance(data, list):
data = []
except json.JSONDecodeError:
data = []
zh_by_idx: dict[int, str] = {}
for it in data:
if isinstance(it, dict) and "i" in it:
zh_by_idx[int(it["i"])] = str(it.get("zh", ""))
return [zh_by_idx.get(i, "") for i in range(len(segments))]
async def pipeline_transcribe(job_id: str) -> None:
    """Stage 3: transcribe ``audio.wav`` and translate the segments.

    Without an LLM_API_KEY two canned segments are stored instead. The English
    transcript is saved before translation starts so clients can show it
    early. Any error marks the job ``failed``; nothing is raised.
    """
    job = JOBS[job_id]
    wav = job_dir(job_id) / "audio.wav"
    try:
        if not wav.exists():
            raise RuntimeError("audio.wav 不存在")

        if not LLM_API_KEY:
            # Key-less mode: mock data.
            update(job, status="transcribing", message="ASR (mock) …", progress=75)
            await asyncio.sleep(1.0)
            mock = [
                TranscriptSegment(
                    index=0, start=0.0, end=3.5,
                    en="Welcome back, today we're testing something new.",
                    zh="欢迎回来,今天我们要测试一些新东西。",
                ),
                TranscriptSegment(
                    index=1, start=3.5, end=7.2,
                    en="This device looks really sleek and minimal.",
                    zh="这个设备看起来非常时尚和简约。",
                ),
            ]
            update(job, transcript=mock, status="transcribed", progress=100,
                   message="转录完成MOCK · 未设 LLM_API_KEY")
            return

        # 1) ASR
        update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78)
        segments = await asyncio.to_thread(_transcribe_sync, wav)
        if not segments:
            raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)")

        # Publish the English text first; zh is filled in after translation.
        en_only = []
        for i, s in enumerate(segments):
            en_only.append(TranscriptSegment(
                index=i,
                start=float(s.get("start", 0)),
                end=float(s.get("end", 0)),
                en=str(s.get("text", "")).strip(),
                zh="",
            ))
        update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88)

        # 2) Translation
        zh_list = await asyncio.to_thread(_translate_sync, segments)
        full = []
        for i, seg in enumerate(en_only):
            zh_text = zh_list[i] if i < len(zh_list) else ""
            full.append(TranscriptSegment(
                index=seg.index, start=seg.start, end=seg.end, en=seg.en,
                zh=zh_text,
            ))
        update(job, transcript=full, status="transcribed", progress=100,
               message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL}")
    except Exception as e:
        update(job, status="failed", error=str(e), message="转录失败")
def _image_edit_call(
    image_path: Path,
    prompt: str,
    model: str | None = None,
    fallback_text: bool = False,
    max_attempts: int = 3,
    max_side: int = 1024,
) -> tuple[bytes, str]:
    """Generic image-edit call with retries and an optional text-only fallback.

    The input image is shrunk so its longer side is at most *max_side* and
    re-encoded as JPEG before being base64-embedded, which keeps the request
    small. If that preprocessing fails, the raw file bytes are sent instead.

    Args:
        image_path: reference image.
        prompt: edit instruction.
        model: image model; defaults to IMAGE_MODEL.
        fallback_text: when True, one prompt-only attempt follows the edit attempts.
        max_attempts: number of edit attempts.
        max_side: maximum side length in pixels of the uploaded image.

    Returns:
        ``(image_bytes, effective_mode)`` with effective_mode "edit" or "text".

    Raises:
        RuntimeError: missing API key, a non-transient HTTP error, all
            attempts failing, or a response without ``b64_json``.
    """
    import base64 as b64lib
    import io as _io
    import time as _time
    import httpx
    from PIL import Image as _PILImage

    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    model = model or IMAGE_MODEL
    # llm() hands an empty base URL to the SDK as "use the default"; mirror
    # that here so the raw HTTP call never posts to "/images/generations".
    base_url = (LLM_BASE_URL or "https://api.openai.com/v1").rstrip("/")

    # Shrink to max_side.
    try:
        with _PILImage.open(image_path) as im:
            if max(im.size) > max_side:
                im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
            buf = _io.BytesIO()
            im.convert("RGB").save(buf, format="JPEG", quality=88)
        img_bytes_in = buf.getvalue()
    except Exception:
        # Preprocessing failed: fall back to the original file.
        img_bytes_in = image_path.read_bytes()
    img_b64 = b64lib.b64encode(img_bytes_in).decode("ascii")
    data_uri = f"data:image/jpeg;base64,{img_b64}"

    plan: list[str] = ["edit"] * max_attempts
    if fallback_text:
        plan.append("text")

    last_err = ""
    resp_data: dict = {}
    effective_mode = "edit"
    for attempt, current_mode in enumerate(plan):
        resp_data = {}  # never carry a previous attempt's payload forward
        try:
            if current_mode == "edit":
                # The SDK call used for text mode takes no image, so post directly.
                with httpx.Client(timeout=120) as client:
                    r = client.post(
                        f"{base_url}/images/generations",
                        headers={
                            "Authorization": f"Bearer {LLM_API_KEY}",
                            "Content-Type": "application/json",
                        },
                        json={"model": model, "prompt": prompt, "image": data_uri, "n": 1},
                    )
                    r.raise_for_status()
                    payload = r.json()
            else:
                resp = llm().images.generate(model=model, prompt=prompt, n=1)
                payload = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}
            if not isinstance(payload, dict):
                raise ValueError(f"unexpected response type {type(payload).__name__}")
            resp_data = payload
            if resp_data.get("data"):
                effective_mode = current_mode
                break
            err_obj = resp_data.get("error") or {}
            last_err = f"empty data · {err_obj.get('code', '')} · {str(err_obj.get('message', ''))[:200]}"
        except httpx.HTTPStatusError as e:
            body = e.response.text
            transient = (
                e.response.status_code >= 500
                or "incomplete_generation" in body
                or "rate_limit" in body
                or "timeout" in body.lower()
            )
            last_err = f"HTTP {e.response.status_code}: {body[:200]}"
            if not transient:
                # Retrying will not help: surface the error immediately.
                raise RuntimeError(f"image edit HTTP {e.response.status_code}: {body[:300]}") from e
        except Exception as e:
            last_err = f"{type(e).__name__}: {e}"
        if attempt < len(plan) - 1:
            next_mode = plan[attempt + 1]
            tag = f"fallback → {next_mode}" if next_mode != current_mode else f"retry {attempt + 1}/{len(plan)}"
            print(f"[image edit {tag}] {last_err}", flush=True)
            # Linear back-off between attempts.
            _time.sleep(1.5 * (attempt + 1))

    data_arr = resp_data.get("data", [])
    if not data_arr:
        raise RuntimeError(f"image edit failed after {len(plan)} attempts: {last_err}")
    b64 = data_arr[0].get("b64_json")
    if not b64:
        raise RuntimeError("image edit returned no b64_json")
    return b64lib.b64decode(b64), effective_mode
# ---------- API 路由 ----------
class CreateJobReq(BaseModel):
    """Request body of ``POST /jobs``."""

    # Source video URL; create_job rejects a blank value.
    url: str
class TranslateReq(BaseModel):
    """Request body of ``POST /translate``."""

    text: str
    # Target language: "en" -> English, "zh" -> Simplified Chinese.
    target: Literal["en", "zh"] = "en"
@app.post("/translate")
def translate_text(req: TranslateReq) -> dict:
    """Translate one short text (used for zh->en labels of custom elements).

    Returns ``{"text": <translation>}``; blank input returns ``{"text": ""}``
    without calling the model.

    Raises:
        HTTPException: 503 when no API key is configured, 500 when the model
            call fails.
    """
    import re as _re

    text = req.text.strip()
    if not text:
        return {"text": ""}
    if not LLM_API_KEY:
        raise HTTPException(503, "LLM_API_KEY 未配置")
    target_label = "English" if req.target == "en" else "Simplified Chinese"
    prompt = (
        f"Translate the following text into concise {target_label}, suitable as an element "
        "label in an image-generation prompt. Output only the translation itself — no quotes, "
        "no punctuation, no explanation, no markdown.\n\n"
        f"Input: {text}"
    )
    try:
        resp = llm().chat.completions.create(
            model=TRANSLATE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=200,
        )
        out = (resp.choices[0].message.content or "").strip()
        if not out:
            # Some models leave content empty; use the last non-blank line of
            # the reasoning text instead. Whitespace-only reasoning yields ""
            # rather than an IndexError.
            rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
            rc_lines = rc.strip().splitlines()
            if rc_lines:
                out = rc_lines[-1].strip()
        # Strip surrounding quotes (ASCII and CJK corner brackets).
        out = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", out).strip()
        return {"text": out}
    except Exception as e:
        raise HTTPException(500, f"translate failed: {e}") from e
@app.get("/health")
def health() -> dict:
    """Liveness probe that also reports the LLM configuration (never the key itself)."""
    models = {
        "asr": ASR_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
    }
    report = {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "base_url": LLM_BASE_URL or "openai-default",
        "models": models,
    }
    return report
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Register a job for *req.url* and schedule the download as a background task.

    Raises:
        HTTPException: 400 when the URL is blank.
    """
    url = req.url.strip()
    if not url:
        raise HTTPException(400, "url required")
    job = Job(id=uuid.uuid4().hex[:12], url=url)
    JOBS[job.id] = job
    save_state(job)
    bg.add_task(pipeline_download, job.id)
    return job
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Create a job from an uploaded video file.

    The upload is streamed in 1 MiB chunks to ``source.mp4`` in a fresh job
    directory; pipeline_download then finds the file and skips the download.
    If the upload fails or is empty, the job directory is removed so no
    orphan folder is left behind.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            500 when nothing was written.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")
    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    try:
        with mp4.open("wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
        if not mp4.exists() or mp4.stat().st_size == 0:
            raise HTTPException(500, "upload failed")
    except BaseException:
        # Covers I/O errors, client disconnects (cancellation) and the empty
        # upload above: no job is registered, so drop its directory.
        shutil.rmtree(d, ignore_errors=True)
        raise
    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT) -> Job:
    """Schedule (or re-run) analysis for a job whose video is available.

    Args:
        frames: requested number of key frames (clamped inside pipeline_analyze).

    Raises:
        HTTPException: 404 for an unknown job, 409 when the job is in a
            status from which analysis cannot start.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    allowed = ("downloaded", "frames_extracted", "transcribed", "failed")
    if job.status not in allowed:
        # Name every accepted status; the old text listed only two of the four.
        raise HTTPException(409, f"status must be one of {'/'.join(allowed)}, got {job.status}")
    bg.add_task(pipeline_analyze, job_id, frames)
    return job
@app.post("/jobs/{job_id}/frames", response_model=Job)
def add_manual_frame(job_id: str, t: float) -> Job:
    """Extract one frame at timestamp *t* (seconds) and append it to ``job.frames``.

    Raises:
        HTTPException: 404 for an unknown job; 400 when the video is not
            ready or *t* is negative, non-finite or beyond the known
            duration; 500 when ffmpeg fails or writes no frame.
    """
    import math

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if not job.video_url:
        raise HTTPException(400, "video not ready")
    if not math.isfinite(t) or t < 0:
        raise HTTPException(400, "t must be a finite, non-negative timestamp")
    if job.duration > 0 and t > job.duration:
        raise HTTPException(400, f"t exceeds video duration ({job.duration:.2f}s)")
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    if not mp4.exists():
        raise HTTPException(400, "source.mp4 missing")
    frames_dir = d / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    # New index = max(existing) + 1. The list is sorted by timestamp, but the
    # file name is derived from the index, which therefore stays stable.
    next_idx = max((f.index for f in job.frames), default=-1) + 1
    out = frames_dir / f"{next_idx:03d}.jpg"
    try:
        run([
            "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
            "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3",
            str(out),
        ])
    except RuntimeError as e:
        raise HTTPException(500, f"ffmpeg failed: {e}")
    if not out.exists():
        # Do not register a frame whose image file was never written.
        raise HTTPException(500, "ffmpeg produced no frame at that timestamp")
    new_frame = KeyFrame(
        index=next_idx,
        timestamp=round(float(t), 2),
        url=f"/jobs/{job_id}/frames/{next_idx}.jpg",
    )
    merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp)
    # The original message had an unbalanced parenthesis and ran the timestamp
    # straight into the frame count.
    update(job, frames=merged, message=f"已手动加帧({t:.1f}s)· 共 {len(merged)} 张")
    return job
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    """Return the current state of a job; 404 when the id is unknown."""
    found = JOBS.get(job_id)
    if found:
        return found
    raise HTTPException(404, "job not found")
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    """Schedule ASR + translation for a job that is in ``frames_extracted``.

    Raises:
        HTTPException: 404 for an unknown job, 409 for any other status.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if job.status == "frames_extracted":
        bg.add_task(pipeline_transcribe, job_id)
        return job
    raise HTTPException(409, f"status must be frames_extracted, got {job.status}")
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the job's source video; 404 when the job or the file does not exist."""
    # Reject unknown ids before touching the filesystem: job_dir() creates the
    # directory it is asked for, and an id such as ".." would resolve outside
    # JOBS_DIR.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "source.mp4"
    if not p.exists():
        raise HTTPException(404, "video not found")
    return FileResponse(p, media_type="video/mp4")
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve key frame *idx* (stored as a zero-padded ``NNN.jpg``); 404 when missing."""
    # Reject unknown ids before touching the filesystem: job_dir() creates the
    # directory it is asked for, and an id such as ".." would resolve outside
    # JOBS_DIR.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame not found")
    return FileResponse(p, media_type="image/jpeg")
class GenerateReq(BaseModel):
    """Request body of ``POST /jobs/{job_id}/frames/{idx}/generate``."""

    prompt: str
    # Elements that should appear (positive).
    extra_prompt: str = ""
    # Elements that should not appear (negative).
    negative_prompt: str = ""
    # Empty means: use the IMAGE_MODEL default.
    model: str = ""
    # "edit" sends a reference image, "text" is prompt only.
    mode: str = "edit"
    # True: iterate on the frame's selected generated image when one exists;
    # otherwise the original key frame is the reference.
    from_selected: bool = False
@app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job)
def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
    """Generate a new image for a key frame (image-to-image or text-to-image).

    In "edit" mode the call is delegated to ``_image_edit_call`` (3 edit
    attempts, then one text-only fallback) instead of repeating that retry
    loop here; this also applies its input downscaling. Any other mode makes
    a single prompt-only request.

    The result is stored as ``jobs/<id>/gen/<idx>_<gen_id>.jpg`` and appended
    to the frame's ``generated_images``.

    Raises:
        HTTPException: 404 for an unknown job, frame or frame file; 400 for
            an empty prompt; 500 when generation fails.
    """
    import base64 as b64lib
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Reference image: the selected generated image when from_selected is set
    # and its file exists (iteration); otherwise the original key frame.
    reference_path = frame_path
    if req.from_selected:
        sel = next((g for g in frame.generated_images if g.selected), None)
        if sel:
            sel_path = job_dir(job_id) / "gen" / f"{idx:03d}_{sel.id}.jpg"
            if sel_path.exists():
                reference_path = sel_path

    full_prompt = req.prompt.strip()
    if req.extra_prompt.strip():
        full_prompt = f"{full_prompt}. Include: {req.extra_prompt.strip()}"
    if req.negative_prompt.strip():
        full_prompt = f"{full_prompt}. Avoid: {req.negative_prompt.strip()}"
    if not full_prompt:
        raise HTTPException(400, "prompt required")

    model = req.model or IMAGE_MODEL
    gen_id = uuid.uuid4().hex[:12]

    if req.mode == "edit":
        try:
            img_bytes, effective_mode = _image_edit_call(
                reference_path, full_prompt, model=model,
                fallback_text=True, max_attempts=3,
            )
        except RuntimeError as e:
            raise HTTPException(500, f"image gen failed: {e}") from e
    else:
        # Prompt-only generation: a single attempt.
        effective_mode = req.mode
        try:
            resp = llm().images.generate(model=model, prompt=full_prompt, n=1)
            resp_data = resp.model_dump() if hasattr(resp, "model_dump") else {"data": [{"b64_json": resp.data[0].b64_json}]}
        except Exception as e:
            raise HTTPException(500, f"image gen failed after 1 attempts: {type(e).__name__}: {e}") from e
        data_arr = resp_data.get("data") or []
        if not data_arr:
            raise HTTPException(500, "image gen failed after 1 attempts: empty data")
        b64 = data_arr[0].get("b64_json")
        if not b64:
            raise HTTPException(500, "image gen returned no b64_json")
        img_bytes = b64lib.b64decode(b64)

    # Save to jobs/<id>/gen/<idx>_<gen_id>.jpg
    gen_dir = job_dir(job_id) / "gen"
    gen_dir.mkdir(parents=True, exist_ok=True)
    (gen_dir / f"{idx:03d}_{gen_id}.jpg").write_bytes(img_bytes)

    new_gen = GeneratedImage(
        id=gen_id,
        prompt=full_prompt,
        model=model,
        mode=effective_mode,
        url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg",
        selected=False,
        created_at=_time.time(),
    )
    # Write back onto the frame and persist.
    frame.generated_images = [*frame.generated_images, new_gen]
    update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}")
    return job
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve a generated image (``gen/<idx>_<gen_id>.jpg``); 404 when missing."""
    # Reject unknown ids before touching the filesystem: job_dir() creates the
    # directory it is asked for, and an id such as ".." would resolve outside
    # JOBS_DIR.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if not p.exists():
        raise HTTPException(404, "generated image not found")
    return FileResponse(p, media_type="image/jpeg")
class SelectGenReq(BaseModel):
    """Request body of the generated-image ``select`` endpoint."""

    # Desired selection state of the addressed image.
    selected: bool
@app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job)
def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job:
    """Select or deselect one generated image of a frame (single selection per frame).

    Raises:
        HTTPException: 404 for an unknown job, frame or generated image. An
            unknown id used to return 200, and an unknown image id also
            cleared the frame's existing selection.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if frame is None:
        raise HTTPException(404, "frame not found")
    if all(g.id != gen_id for g in frame.generated_images):
        raise HTTPException(404, "generated image not found")
    for g in frame.generated_images:
        # Single selection: every other image of the frame is deselected.
        g.selected = req.selected if g.id == gen_id else False
    update(job, frames=job.frames)
    return job
@app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job)
def describe_frame(job_id: str, idx: int) -> Job:
    """Ask the vision model for a structured description of a key frame.

    Up to three attempts are made; the first reply that parses to a JSON
    object is stored in ``frame.description``.

    Raises:
        HTTPException: 404 for an unknown job, frame or frame file; 503 when
            no API key is configured; 500 when every attempt fails.
    """
    import base64 as b64lib
    import re as _re

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame file not found")
    if not LLM_API_KEY:
        # Same response as translate_text; without this the missing key was
        # retried three times and reported as a generic 500.
        raise HTTPException(503, "LLM_API_KEY 未配置")

    img_b64 = b64lib.b64encode(p.read_bytes()).decode("ascii")
    prompt = (
        "请识别这张图,输出严格 JSON不要 markdown 不要解释,不要思考):\n"
        '{\n'
        ' "scene": "一句话描述场景",\n'
        ' "objects": [{"name": "物体名(中文)", "position": "在画面哪里", "color": "颜色", "extract_prompt": "用于提取该元素的英文 prompt"}],\n'
        ' "style": "整体风格 / 打光 / 色调(一句话)",\n'
        ' "suggested_prompt": "适合用作下游生图的完整英文 prompt"\n'
        '}\n'
        "要求objects 列出 3-8 个画面里**可独立提取**的主要元素extract_prompt 用于后续 image edit 模型。"
    )
    last_err = ""
    data = None
    for attempt in range(3):
        content = ""
        try:
            resp = llm().chat.completions.create(
                model=VISION_MODEL,
                messages=[{"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                ]}],
                response_format={"type": "json_object"},
                temperature=0.3,
                max_tokens=3000,
            )
            content = (resp.choices[0].message.content or "").strip()
            if not content:
                # Thinking models may leave content empty; look for a JSON
                # object inside reasoning_content instead.
                rc = getattr(resp.choices[0].message, "reasoning_content", "") or ""
                m = _re.search(r"\{[\s\S]*\}", rc)
                content = m.group(0) if m else ""
            # Strip a ```json ... ``` wrapper.
            content = _re.sub(r"^```(?:json)?\s*|\s*```$", "", content).strip()
            if not content:
                last_err = f"empty content (attempt {attempt + 1})"
                continue
            parsed = json.loads(content)
            if not isinstance(parsed, dict):
                # description is a dict field: reject arrays and scalars.
                last_err = f"non-object JSON (attempt {attempt + 1}): {type(parsed).__name__}"
                print(f"[vision retry] {last_err}", flush=True)
                continue
            data = parsed
            break
        except json.JSONDecodeError as e:
            last_err = f"json decode (attempt {attempt + 1}): {e} · raw[:200]={content[:200]}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
        except Exception as e:
            last_err = f"vision call (attempt {attempt + 1}): {e}"
            print(f"[vision retry] {last_err}", flush=True)
            continue
    if data is None:
        raise HTTPException(500, last_err or "vision failed after 3 retries")

    # Write back onto the job.
    new_frames = []
    for f in job.frames:
        if f.index == idx:
            f.description = data
        new_frames.append(f)
    update(job, frames=new_frames, message=f"识别完成 · 分镜 {idx + 1}")
    return job
# ---------- 清洗水印 / 元素提取(关键帧二阶段加工) ----------
class CleanupReq(BaseModel):
    """Request body of the frame ``cleanup`` endpoint."""

    # Rectangles in relative 0-1 coordinates, [{"x", "y", "w", "h"}, ...], that
    # limit where cleaning happens; empty or None means the whole image.
    regions: list[dict] | None = None
def _region_to_phrase(r: dict) -> str:
"""把相对坐标矩形转成简短方位描述给 prompt 用(避免百分号 / 括号触发模型异常)"""
x = max(0.0, min(1.0, float(r.get("x", 0))))
y = max(0.0, min(1.0, float(r.get("y", 0))))
w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
if w <= 0 or h <= 0:
return ""
cx, cy = x + w / 2, y + h / 2
hpos = "left" if cx < 0.4 else "right" if cx > 0.6 else "middle"
vpos = "top" if cy < 0.4 else "bottom" if cy > 0.6 else "middle"
if hpos == "middle" and vpos == "middle":
return "center"
if hpos == "middle":
return vpos
if vpos == "middle":
return hpos
return f"{vpos} {hpos}"
@app.post("/jobs/{job_id}/frames/{idx}/cleanup", response_model=Job)
def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
    """Clean a key frame with the image-edit model (watermarks, handles, captions, logos).

    The clean version is written to ``jobs/<id>/cleaned/<idx>.jpg`` and
    referenced from ``frame.cleaned_url``. Optional regions restrict the
    cleaning to the named parts of the picture.

    Raises:
        HTTPException: 404 for an unknown job, frame or frame file; 500 when
            the edit call fails.
    """
    import time as _time

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Region rectangles -> position phrases, de-duplicated in first-seen order.
    requested = req.regions if (req and req.regions) else []
    phrases = [phrase for phrase in map(_region_to_phrase, requested) if phrase]
    region_phrases: list[str] = list(dict.fromkeys(phrases))

    if not region_phrases:
        prompt = "Erase all watermarks and text overlays. Keep the scene natural."
    else:
        if len(region_phrases) == 1:
            zones = f"the {region_phrases[0]} part"
        else:
            zones = "these parts: " + ", ".join(region_phrases)
        prompt = (
            f"Erase the text and graphics in {zones} of the image. "
            "Keep all other parts unchanged."
        )

    try:
        img_bytes, _mode = _image_edit_call(frame_path, prompt, fallback_text=False, max_attempts=3)
    except RuntimeError as e:
        raise HTTPException(500, f"cleanup failed: {e}")

    out_dir = job_dir(job_id) / "cleaned"
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / f"{idx:03d}.jpg").write_bytes(img_bytes)

    for f in job.frames:
        if f.index != idx:
            continue
        # The timestamp query string makes each new cleaned version a new URL.
        f.cleaned_url = f"/jobs/{job_id}/frames/{idx}/cleaned.jpg?t={int(_time.time())}"
        # Re-cleaning resets the "already applied" flag.
        f.cleaned_applied = False
    update(job, frames=list(job.frames), message=f"清洗完成 · 分镜 {idx + 1}")
    return job
@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve the pending cleaned version of a frame (``cleaned/NNN.jpg``); 404 when missing."""
    # Reject unknown ids before touching the filesystem: job_dir() creates the
    # directory it is asked for, and an id such as ".." would resolve outside
    # JOBS_DIR.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "cleaned frame not found")
    return FileResponse(p, media_type="image/jpeg")
@app.post("/jobs/{job_id}/frames/{idx}/cleanup/apply", response_model=Job)
def apply_cleaned(job_id: str, idx: int) -> Job:
    """Replace a key frame with its cleaned version.

    ``frames/<idx>.jpg`` is overwritten with ``cleaned/<idx>.jpg``. The
    original is backed up to ``orig/<idx>.jpg`` on the first replacement only.
    Afterwards the cleaned file is deleted and ``frame.cleaned_url`` is
    cleared, because no pending version remains.

    Raises:
        HTTPException: 404 for an unknown job or frame, or when there is no
            cleaned version to apply.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    file_name = f"{idx:03d}.jpg"
    cleaned_path = job_dir(job_id) / "cleaned" / file_name
    if not cleaned_path.exists():
        raise HTTPException(404, "no cleaned version to apply")
    frame_path = job_dir(job_id) / "frames" / file_name

    # First replacement: keep a copy of the original under orig/.
    orig_dir = job_dir(job_id) / "orig"
    orig_dir.mkdir(parents=True, exist_ok=True)
    orig_backup = orig_dir / file_name
    if frame_path.exists() and not orig_backup.exists():
        shutil.copy2(frame_path, orig_backup)

    # Overwrite frames/ with the cleaned image, then drop the cleaned file.
    shutil.copy2(cleaned_path, frame_path)
    try:
        cleaned_path.unlink()
    except OSError:
        pass

    for f in job.frames:
        if f.index != idx:
            continue
        f.cleaned_url = None
        f.cleaned_applied = True
    update(job, frames=list(job.frames), message=f"已替换分镜 {idx + 1} 为清洗版")
    return job
class AddElementReq(BaseModel):
    """Request body for ``POST /jobs/{job_id}/frames/{idx}/elements``."""

    # Chinese name of the element; add_element rejects the request with 400
    # when it is blank after stripping.
    name_zh: str
    # English name; when empty, add_element tries to fill it with an LLM
    # translation of name_zh (only if LLM_API_KEY is configured).
    name_en: str = ""
    # Free-text position hint; stored on the element after stripping.
    position: str = ""
    # Origin tag stored on the element; restricted to these three literals.
    source: Literal["auto", "manual", "region"] = "manual"
    # Optional region descriptor stored on the element as-is; its schema is
    # not defined here (cutout_element passes it to _region_to_phrase).
    region: dict | None = None
@app.post("/jobs/{job_id}/frames/{idx}/elements", response_model=Job)
def add_element(job_id: str, idx: int, req: AddElementReq) -> Job:
    """Append a key element to frame ``idx``.

    When the request carries no English name and ``LLM_API_KEY`` is set, the
    Chinese name is translated with ``TRANSLATE_MODEL``. A failed translation
    is logged to stdout and leaves ``name_en`` empty; it never fails the
    request.

    Raises:
        HTTPException: 404 when the job or frame is missing, 400 when
            ``name_zh`` is blank after stripping.
    """
    import re as _re
    import time as _time

    job = JOBS.get(job_id)
    if job is None:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    name_zh = req.name_zh.strip()
    if not name_zh:
        raise HTTPException(400, "name_zh required")

    name_en = req.name_en.strip()
    if LLM_API_KEY and not name_en:
        try:
            instruction = (
                "Translate the following text into concise English, suitable as an element label "
                "in an image-generation prompt. Output only the translation — no quotes, no punctuation, "
                f"no explanation.\n\nInput: {name_zh}"
            )
            resp = llm().chat.completions.create(
                model=TRANSLATE_MODEL,
                messages=[{"role": "user", "content": instruction}],
                temperature=0.2,
                max_tokens=200,
            )
            message = resp.choices[0].message
            translated = (message.content or "").strip()
            if not translated:
                # Empty content: fall back to the last line of the
                # reasoning_content attribute when the response has one.
                reasoning = getattr(message, "reasoning_content", "") or ""
                if reasoning:
                    translated = reasoning.strip().splitlines()[-1].strip()
            # Strip wrapping quote characters the model may have added.
            name_en = _re.sub(r'^[\'"「『]+|[\'"」』]+$', "", translated).strip()
        except Exception as e:
            print(f"[add_element translate failed] {e}", flush=True)
            name_en = ""

    el = KeyElement(
        id=uuid.uuid4().hex[:8],
        name_zh=name_zh,
        name_en=name_en,
        position=req.position.strip(),
        source=req.source,
        region=req.region,
        created_at=_time.time(),
    )

    for f in job.frames:
        if f.index == idx:
            f.elements = [*f.elements, el]
    update(job, frames=list(job.frames), message=f"加入元素 · 分镜 {idx + 1} · {name_zh}")
    return job
@app.delete("/jobs/{job_id}/frames/{idx}/elements/{element_id}", response_model=Job)
def delete_element(job_id: str, idx: int, element_id: str) -> Job:
    """Remove one element from frame ``idx`` together with its cutout files.

    The frame is validated before the element, so a missing frame is reported
    as "frame not found", matching the other frame-scoped handlers in this
    file. Nothing is modified unless the element actually exists.

    Raises:
        HTTPException: 404 when the job, the frame, or the element is
            missing.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    remaining = [e for e in frame.elements if e.id != element_id]
    if len(remaining) == len(frame.elements):
        raise HTTPException(404, "element not found")
    frame.elements = remaining

    # Remove the cutout image too, if one was generated. Both the current
    # .jpg name and the older .png name are tried; failures are ignored
    # because the file may simply not exist.
    elements_dir = job_dir(job_id) / "elements"
    for ext in ("jpg", "png"):
        cutout = elements_dir / f"{idx:03d}_{element_id}.{ext}"
        if cutout.exists():
            try:
                cutout.unlink()
            except OSError:
                pass

    update(job, frames=list(job.frames), message=f"删除元素 · 分镜 {idx + 1}")
    return job
class CutoutReq(BaseModel):
    """Optional request body for the element cutout endpoint."""

    # Solid background colour for the cutout; cutout_element turns it into
    # the "pure {background}" phrase of the image-edit prompt.
    background: Literal["white", "black"] = "white"
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
def cutout_element(job_id: str, idx: int, element_id: str, req: CutoutReq | None = None) -> Job:
    """Cut a single element out of a frame onto a solid background.

    Sends the frame image and an extraction prompt to ``_image_edit_call``
    and stores the returned bytes as ``elements/{idx:03d}_{element_id}.jpg``.
    The cleaned version of the frame is used as the source when it exists,
    otherwise the keyframe itself. The output uses a solid white or black
    background (default white) rather than transparency.

    The prompt targets the element's region when ``_region_to_phrase`` yields
    a phrase for it; otherwise it names the element (English name preferred)
    with an optional position hint.

    Raises:
        HTTPException: 404 when the job, frame, element, or source image is
            missing; 500 when the image-edit call raises ``RuntimeError``.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    el = next((e for e in frame.elements if e.id == element_id), None)
    if not el:
        raise HTTPException(404, "element not found")

    # Prefer the cleaned image as reference (logos / watermarks already
    # removed); fall back to the keyframe.
    base = job_dir(job_id)
    frame_name = f"{idx:03d}.jpg"
    cleaned_path = base / "cleaned" / frame_name
    src = cleaned_path if cleaned_path.exists() else base / "frames" / frame_name
    if not src.exists():
        raise HTTPException(404, "source frame file missing")

    background = (req.background if req else "white") or "white"
    bg_phrase = f"pure {background}"
    region_phrase = _region_to_phrase(el.region) if el.region else ""
    if region_phrase:
        prompt = (
            f"Extract whatever is in the {region_phrase} part of the image as a standalone asset. "
            f"Place it on a {bg_phrase} background, isolated, no other objects."
        )
    else:
        # The element name is only needed when no region phrase is available.
        target = (el.name_en or el.name_zh).strip()
        position_hint = f" Located in the {el.position} area." if el.position else ""
        prompt = (
            f"Extract the {target} from this image as a standalone asset.{position_hint} "
            f"Place it on a {bg_phrase} background, isolated, no other objects."
        )

    try:
        img_bytes, _mode = _image_edit_call(src, prompt, fallback_text=False, max_attempts=3)
    except RuntimeError as e:
        # Chain the cause so the original failure stays in the traceback.
        raise HTTPException(500, f"cutout failed: {e}") from e

    out_dir = base / "elements"
    out_dir.mkdir(parents=True, exist_ok=True)
    # The bytes are written under a .jpg name.
    out_path = out_dir / f"{idx:03d}_{element_id}.jpg"
    out_path.write_bytes(img_bytes)
    # Remove a leftover file from the older .png naming, if any.
    old_png = out_dir / f"{idx:03d}_{element_id}.png"
    if old_png.exists():
        try:
            old_png.unlink()
        except OSError:
            pass

    for f in job.frames:
        if f.index != idx:
            continue
        for e in f.elements:
            if e.id == element_id:
                e.cutout_id = element_id
                e.cutout_background = background
    # Balanced parentheses keep the name and the background readable.
    update(job, frames=list(job.frames), message=f"抠图完成 · {el.name_zh}({background} 底)")
    return job
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
    """Serve the cutout image of one element.

    Looks for ``elements/{idx:03d}_{element_id}.jpg`` first, then for the
    legacy ``.png`` name kept for old data. Both are served as
    ``image/jpeg``.

    Raises:
        HTTPException: 404 when neither file exists.
    """
    elements_dir = job_dir(job_id) / "elements"
    stem = f"{idx:03d}_{element_id}"
    # Current name first, legacy name second.
    for ext in ("jpg", "png"):
        candidate = elements_dir / f"{stem}.{ext}"
        if candidate.exists():
            return FileResponse(candidate, media_type="image/jpeg")
    raise HTTPException(404, "cutout not found")
# ---------- 删除:关键帧 / 单张生成图 ----------
@app.delete("/jobs/{job_id}/frames/{idx}", response_model=Job)
def delete_frame(job_id: str, idx: int) -> Job:
    """Delete a keyframe and every file derived from it.

    Removes the keyframe, its cleaned version, the ``orig/`` backup written
    by ``apply_cleaned``, all element cutouts and all generated images whose
    file name starts with ``{idx:03d}_``, then drops the frame from the job.
    File removal is best effort: missing files and OS errors are ignored.

    Raises:
        HTTPException: 404 when the job or the frame is missing.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    target = next((f for f in job.frames if f.index == idx), None)
    if not target:
        raise HTTPException(404, "frame not found")

    d = job_dir(job_id)
    stem = f"{idx:03d}"
    doomed = [
        d / "frames" / f"{stem}.jpg",
        d / "cleaned" / f"{stem}.jpg",
        # Backup taken by apply_cleaned. Left behind, it would be orphaned
        # and would stop apply_cleaned from backing up a later frame stored
        # under the same index (it only backs up when no backup exists).
        d / "orig" / f"{stem}.jpg",
    ]

    # All element cutouts of this frame (file name prefix "{idx:03d}_").
    elements_dir = d / "elements"
    if elements_dir.exists():
        for ext in ("png", "jpg"):
            doomed.extend(elements_dir.glob(f"{stem}_*.{ext}"))

    # All generated images of this frame.
    gen_dir = d / "gen"
    if gen_dir.exists():
        doomed.extend(gen_dir.glob(f"{stem}_*.jpg"))

    for p in doomed:
        if p.exists():
            try:
                p.unlink()
            except OSError:
                pass

    new_frames = [f for f in job.frames if f.index != idx]
    update(job, frames=new_frames, message=f"删除分镜 {idx + 1}")
    return job
@app.delete("/jobs/{job_id}/frames/{idx}/gen/{gen_id}", response_model=Job)
def delete_generated(job_id: str, idx: int, gen_id: str) -> Job:
    """Delete one generated image of a frame: the file and its list entry.

    The file ``gen/{idx:03d}_{gen_id}.jpg`` is removed first (best effort),
    then the entry is dropped from the frame's ``generated_images``.

    Raises:
        HTTPException: 404 when the job or the frame is missing, or when no
            generated image with ``gen_id`` is listed on the frame.
    """
    job = JOBS.get(job_id)
    if job is None:
        raise HTTPException(404, "job not found")
    if not any(f.index == idx for f in job.frames):
        raise HTTPException(404, "frame not found")

    gen_file = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if gen_file.exists():
        try:
            gen_file.unlink()
        except OSError:
            pass

    found = False
    for f in job.frames:
        if f.index != idx:
            continue
        kept = [g for g in f.generated_images if g.id != gen_id]
        # Compare against the old list before replacing it.
        found = len(kept) < len(f.generated_images)
        f.generated_images = kept
    if not found:
        raise HTTPException(404, "generated image not found")

    update(job, frames=list(job.frames), message=f"删除生成图 · 分镜 {idx + 1}")
    return job