from __future__ import annotations import asyncio import json import os import shutil import subprocess import uuid from contextlib import asynccontextmanager from pathlib import Path from typing import Literal from dotenv import load_dotenv from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel, Field load_dotenv() JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve() JOBS_DIR.mkdir(parents=True, exist_ok=True) CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290").split(",") if o.strip()] LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip() LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip() ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1") TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash") REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro") # OpenAI 客户端(OpenAI 兼容网关,含 SKG ezlink) from openai import OpenAI _llm_client: OpenAI | None = None def llm() -> OpenAI: global _llm_client if _llm_client is None: if not LLM_API_KEY: raise RuntimeError("LLM_API_KEY 未配置") _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY) return _llm_client # Pipeline 状态: # created → downloading → downloaded(停,等用户点解析)→ splitting → frames_extracted # → transcribing → transcribed | failed JobStatus = Literal[ "created", "downloading", "downloaded", "splitting", "frames_extracted", "transcribing", "transcribed", "failed", ] KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5")) class KeyFrame(BaseModel): index: int timestamp: float url: str class TranscriptSegment(BaseModel): index: int start: float end: float en: str zh: str = "" class Job(BaseModel): id: str url: str status: JobStatus = "created" progress: int = 0 message: str = "" video_url: str = "" duration: float = 0.0 width: int = 0 height: int = 0 frames: list[KeyFrame] = Field(default_factory=list) transcript: list[TranscriptSegment] = Field(default_factory=list) error: str = "" JOBS: dict[str, Job] = {} def job_dir(job_id: str) -> Path: d = JOBS_DIR / job_id d.mkdir(parents=True, exist_ok=True) return d def save_state(job: Job) -> None: (job_dir(job.id) / "state.json").write_text(job.model_dump_json(indent=2)) def update(job: Job, **kw) -> None: for k, v in kw.items(): setattr(job, k, v) save_state(job) @asynccontextmanager async def lifespan(_: FastAPI): # 启动时从磁盘恢复 jobs(简化版:只列目录) for p in JOBS_DIR.iterdir(): if p.is_dir() and (p / "state.json").exists(): try: JOBS[p.name] = Job.model_validate_json((p / "state.json").read_text()) except Exception: pass yield app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ---------- Pipeline 实现 ---------- def run(cmd: list[str], cwd: Path | None = None) -> str: res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) if res.returncode != 0: # ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾) tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:] raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}") return res.stdout # ---- 启发式选帧工具 ---- import imagehash import numpy as np from PIL import Image def _sharpness(img_path: Path) -> float: """Laplacian variance:值越大越清晰,模糊/转场帧值低。""" g = np.asarray(Image.open(img_path).convert("L").resize((320, 180)), dtype=np.float32) lap = (-4 * g[1:-1, 1:-1] + g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:]) return float(lap.var()) def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) -> list[Path]: """ candidates: 按时间排序的候选帧路径 n: 目标帧数 dup_threshold: pHash 汉明距离 < 此值视为相似(默认 8,64bit hash 大致 ~12.5% 像素差) """ if len(candidates) <= n: return candidates # 算 pHash + sharpness items = [] for i, p in enumerate(candidates): try: img = Image.open(p) h = imagehash.phash(img) s = _sharpness(p) items.append({"path": p, "idx": i, "hash": h, "sharp": s}) except Exception: continue # 去重:相似帧保留 sharpness 高的 deduped: list[dict] = [] for it in items: dup = None for kept in deduped: if (it["hash"] - kept["hash"]) < dup_threshold: dup = kept break if dup is None: deduped.append(it) elif it["sharp"] > dup["sharp"]: deduped[deduped.index(dup)] = it # 时序分桶:把候选时间轴等分 n 段,每段取去重后 sharpness 最高的 total = len(candidates) buckets: list[list[dict]] = [[] for _ in range(n)] for it in deduped: b = min(int(it["idx"] * n / total), n - 1) buckets[b].append(it) selected: list[dict] = [] for b in buckets: if b: selected.append(max(b, key=lambda x: x["sharp"])) # 空桶补足:从未选的 deduped 里按 sharpness 排序补 chosen_paths = {it["path"] for it in selected} remaining = sorted([it for it in deduped if it["path"] not in chosen_paths], key=lambda x: -x["sharp"]) while len(selected) < n and remaining: selected.append(remaining.pop(0)) # 按时间排序输出 selected.sort(key=lambda x: x["idx"]) return [it["path"] for it in selected] def ffprobe_meta(mp4: Path) -> dict: out = run([ "ffprobe", "-v", "error", "-print_format", "json", "-show_streams", "-show_format", str(mp4), ]) return json.loads(out) async def pipeline_download(job_id: str) -> None: """阶段 1:仅下载(或上传跳过),落 source.mp4,停在 downloaded 等用户点解析。""" job = JOBS[job_id] d = job_dir(job_id) try: mp4 = d / "source.mp4" if mp4.exists(): update(job, status="downloading", message="本地上传 · 跳过下载", progress=15) else: update(job, status="downloading", message="yt-dlp 下载中…", progress=5) run([ "yt-dlp", "-f", "best[ext=mp4]/best", "-o", str(mp4), "--no-warnings", "--no-playlist", "--retries", "3", job.url, ]) if not mp4.exists(): raise RuntimeError("下载完成但找不到 source.mp4") meta = ffprobe_meta(mp4) v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None) duration = float(meta["format"]["duration"]) update( job, status="downloaded", video_url=f"/jobs/{job_id}/video.mp4", duration=duration, width=int(v_stream["width"]) if v_stream else 0, height=int(v_stream["height"]) if v_stream else 0, progress=25, message=f"视频就绪 · {duration:.1f}s · 等待解析", ) except Exception as e: update(job, status="failed", error=str(e), message="下载失败") async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None: """阶段 2:拆音轨 + 抽关键帧 + ASR + 翻译。需要 source.mp4 已存在。""" job = JOBS[job_id] d = job_dir(job_id) try: mp4 = d / "source.mp4" if not mp4.exists(): raise RuntimeError("source.mp4 不存在,先完成下载") update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35) wav = d / "audio.wav" run([ "ffmpeg", "-y", "-i", str(mp4), "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le", str(wav), ]) n = max(1, min(int(frame_count), 20)) # 候选数:n 的 6 倍或至少 24,封顶 60 candidate_count = max(24, min(60, n * 6)) update(job, message=f"抽取候选 {candidate_count} 张…", progress=45) frames_dir = d / "frames" if frames_dir.exists(): shutil.rmtree(frames_dir) frames_dir.mkdir(parents=True) cand_dir = d / "candidates" if cand_dir.exists(): shutil.rmtree(cand_dir) cand_dir.mkdir(parents=True) # 1) 均匀采样大批候选(fast seek,每张 < 0.5s) duration = max(float(job.duration or 1.0), 0.1) step = duration / (candidate_count + 1) candidate_meta: list[tuple[Path, float]] = [] # (path, timestamp) for i in range(candidate_count): t = step * (i + 1) out = cand_dir / f"c_{i:03d}.jpg" run([ "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4), "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3", str(out), ]) if out.exists(): candidate_meta.append((out, t)) # 2) D 启发式选 n 张:pHash 去重 + Laplacian 清晰度 + 时序分桶 update(job, message=f"启发式筛选 {n} / {len(candidate_meta)} 张…", progress=60) cand_paths = [m[0] for m in candidate_meta] ts_by_path = {m[0]: m[1] for m in candidate_meta} chosen = _select_keyframes(cand_paths, n) # 3) 落盘到 frames/.jpg renamed: list[KeyFrame] = [] chosen_sorted = sorted(chosen, key=lambda p: ts_by_path[p]) for i, src in enumerate(chosen_sorted): dst = frames_dir / f"{i:03d}.jpg" shutil.copyfile(src, dst) renamed.append(KeyFrame( index=i, timestamp=round(ts_by_path[src], 2), url=f"/jobs/{job_id}/frames/{i}.jpg", )) # 4) 清理候选目录 shutil.rmtree(cand_dir, ignore_errors=True) update( job, status="frames_extracted", frames=renamed, progress=70, message=f"已抽取 {len(renamed)} 张关键帧", ) # 自动接 ASR + 翻译 await pipeline_transcribe(job_id) except Exception as e: update(job, status="failed", error=str(e), message="解析失败") # ---------- Gemini ASR + 翻译 ---------- def _transcribe_sync(wav: Path) -> list[dict]: """whisper-1 verbose_json → segments[{start, end, text}]""" with wav.open("rb") as f: resp = llm().audio.transcriptions.create( file=(wav.name, f, "audio/wav"), model=ASR_MODEL, response_format="verbose_json", timestamp_granularities=["segment"], ) raw = resp.model_dump() if hasattr(resp, "model_dump") else resp segments = raw.get("segments") or [] # 兜底:网关如果不返回 segments,把全文当一段 if not segments and raw.get("text"): segments = [{"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}] return segments def _translate_sync(segments: list[dict]) -> list[str]: """gemini-2.5-flash 批量翻译为中文,按段返回""" payload = [{"i": i, "en": s.get("text", "").strip()} for i, s in enumerate(segments)] prompt = ( "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。" "严格返回 JSON 数组,不要任何 markdown 或多余文字,schema: " '[{"i": 0, "zh": "..."}, ...]\n\n输入:\n' + json.dumps(payload, ensure_ascii=False) ) resp = llm().chat.completions.create( model=TRANSLATE_MODEL, messages=[{"role": "user", "content": prompt}], response_format={"type": "json_object"}, temperature=0.2, ) content = resp.choices[0].message.content or "[]" try: data = json.loads(content) if isinstance(data, dict): for k in ("data", "items", "result", "translations"): if k in data and isinstance(data[k], list): data = data[k] break if not isinstance(data, list): data = [] except json.JSONDecodeError: data = [] zh_by_idx: dict[int, str] = {} for it in data: if isinstance(it, dict) and "i" in it: zh_by_idx[int(it["i"])] = str(it.get("zh", "")) return [zh_by_idx.get(i, "") for i in range(len(segments))] async def pipeline_transcribe(job_id: str) -> None: job = JOBS[job_id] d = job_dir(job_id) wav = d / "audio.wav" try: if not wav.exists(): raise RuntimeError("audio.wav 不存在") if not LLM_API_KEY: # 无 key 模式:mock 数据 update(job, status="transcribing", message="ASR (mock) …", progress=75) await asyncio.sleep(1.0) mock = [ TranscriptSegment(index=0, start=0.0, end=3.5, en="Welcome back, today we're testing something new.", zh="欢迎回来,今天我们要测试一些新东西。"), TranscriptSegment(index=1, start=3.5, end=7.2, en="This device looks really sleek and minimal.", zh="这个设备看起来非常时尚和简约。"), ] update(job, transcript=mock, status="transcribed", progress=100, message="转录完成(MOCK · 未设 LLM_API_KEY)") return # 1) whisper ASR update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78) segments = await asyncio.to_thread(_transcribe_sync, wav) if not segments: raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)") # 先把英文段落落到 job 上(让 UI 提前看到,翻译再补 zh) en_only = [ TranscriptSegment( index=i, start=float(s.get("start", 0)), end=float(s.get("end", 0)), en=str(s.get("text", "")).strip(), zh="", ) for i, s in enumerate(segments) ] update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88) # 2) Gemini 翻译 zh_list = await asyncio.to_thread(_translate_sync, segments) full = [ TranscriptSegment( index=seg.index, start=seg.start, end=seg.end, en=seg.en, zh=zh_list[i] if i < len(zh_list) else "", ) for i, seg in enumerate(en_only) ] update(job, transcript=full, status="transcribed", progress=100, message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})") except Exception as e: update(job, status="failed", error=str(e), message="转录失败") # ---------- API 路由 ---------- class CreateJobReq(BaseModel): url: str @app.get("/health") def health() -> dict: return { "ok": True, "llm_configured": bool(LLM_API_KEY), "base_url": LLM_BASE_URL or "openai-default", "models": { "asr": ASR_MODEL, "translate": TRANSLATE_MODEL, "rewrite": REWRITE_MODEL, }, } @app.post("/jobs", response_model=Job) async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job: if not req.url.strip(): raise HTTPException(400, "url required") job_id = uuid.uuid4().hex[:12] job = Job(id=job_id, url=req.url.strip()) JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download, job_id) return job @app.post("/jobs/upload", response_model=Job) async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job: if not file.filename: raise HTTPException(400, "file required") ext = Path(file.filename).suffix.lower() if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}: raise HTTPException(400, f"unsupported video format: {ext}") job_id = uuid.uuid4().hex[:12] d = job_dir(job_id) mp4 = d / "source.mp4" with mp4.open("wb") as f: while chunk := await file.read(1024 * 1024): f.write(chunk) if not mp4.exists() or mp4.stat().st_size == 0: raise HTTPException(500, "upload failed") job = Job(id=job_id, url=f"upload://{file.filename}") JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download, job_id) return job @app.post("/jobs/{job_id}/analyze", response_model=Job) async def trigger_analyze(job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if job.status not in {"downloaded", "frames_extracted", "transcribed", "failed"}: raise HTTPException(409, f"status must be downloaded/failed, got {job.status}") bg.add_task(pipeline_analyze, job_id, frames) return job @app.post("/jobs/{job_id}/frames", response_model=Job) def add_manual_frame(job_id: str, t: float) -> Job: """从指定时间戳手动抽 1 帧追加到 job.frames""" job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if not job.video_url: raise HTTPException(400, "video not ready") d = job_dir(job_id) mp4 = d / "source.mp4" if not mp4.exists(): raise HTTPException(400, "source.mp4 missing") frames_dir = d / "frames" frames_dir.mkdir(parents=True, exist_ok=True) # 新 index:max(existing)+1(即使列表已按 ts 排序,文件名用 index 保持稳定) next_idx = max((f.index for f in job.frames), default=-1) + 1 out = frames_dir / f"{next_idx:03d}.jpg" try: run([ "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4), "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3", str(out), ]) except RuntimeError as e: raise HTTPException(500, f"ffmpeg failed: {e}") new_frame = KeyFrame( index=next_idx, timestamp=round(float(t), 2), url=f"/jobs/{job_id}/frames/{next_idx}.jpg", ) merged = sorted(list(job.frames) + [new_frame], key=lambda f: f.timestamp) update(job, frames=merged, message=f"已手动加帧({t:.1f}s),共 {len(merged)} 张") return job @app.get("/jobs/{job_id}", response_model=Job) def get_job(job_id: str) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") return job @app.post("/jobs/{job_id}/transcribe", response_model=Job) async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if job.status != "frames_extracted": raise HTTPException(409, f"status must be frames_extracted, got {job.status}") bg.add_task(pipeline_transcribe, job_id) return job @app.get("/jobs/{job_id}/video.mp4") def get_video(job_id: str): p = job_dir(job_id) / "source.mp4" if not p.exists(): raise HTTPException(404, "video not found") return FileResponse(p, media_type="video/mp4") @app.get("/jobs/{job_id}/frames/{idx}.jpg") def get_frame(job_id: str, idx: int): p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg" if not p.exists(): raise HTTPException(404, "frame not found") return FileResponse(p, media_type="image/jpeg")