from __future__ import annotations import asyncio import json import os import shutil import subprocess import uuid from contextlib import asynccontextmanager from pathlib import Path from typing import Literal from dotenv import load_dotenv from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel, Field load_dotenv() JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve() JOBS_DIR.mkdir(parents=True, exist_ok=True) CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290").split(",") if o.strip()] GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "").strip() GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") # Pipeline 状态:created → downloading → splitting → frames_extracted → transcribing → transcribed | failed JobStatus = Literal[ "created", "downloading", "splitting", "frames_extracted", "transcribing", "transcribed", "failed", ] class KeyFrame(BaseModel): index: int timestamp: float url: str class TranscriptSegment(BaseModel): index: int start: float end: float en: str zh: str = "" class Job(BaseModel): id: str url: str status: JobStatus = "created" progress: int = 0 message: str = "" video_url: str = "" duration: float = 0.0 width: int = 0 height: int = 0 frames: list[KeyFrame] = Field(default_factory=list) transcript: list[TranscriptSegment] = Field(default_factory=list) error: str = "" JOBS: dict[str, Job] = {} def job_dir(job_id: str) -> Path: d = JOBS_DIR / job_id d.mkdir(parents=True, exist_ok=True) return d def save_state(job: Job) -> None: (job_dir(job.id) / "state.json").write_text(job.model_dump_json(indent=2)) def update(job: Job, **kw) -> None: for k, v in kw.items(): setattr(job, k, v) save_state(job) @asynccontextmanager async def lifespan(_: FastAPI): # 启动时从磁盘恢复 jobs(简化版:只列目录) for p in JOBS_DIR.iterdir(): if p.is_dir() and (p / "state.json").exists(): try: JOBS[p.name] = Job.model_validate_json((p / "state.json").read_text()) except Exception: pass yield app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ---------- Pipeline 实现 ---------- def run(cmd: list[str], cwd: Path | None = None) -> str: res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) if res.returncode != 0: # ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾) tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:] raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}") return res.stdout def ffprobe_meta(mp4: Path) -> dict: out = run([ "ffprobe", "-v", "error", "-print_format", "json", "-show_streams", "-show_format", str(mp4), ]) return json.loads(out) async def pipeline_download_split_frames(job_id: str) -> None: """步骤 1+2+3:下载 + 拆音轨 + 抽取关键帧""" job = JOBS[job_id] d = job_dir(job_id) try: mp4 = d / "source.mp4" # ---- 1. yt-dlp 下载(上传模式 mp4 已存在 → 跳过) if mp4.exists(): update(job, status="downloading", message="本地上传,跳过下载", progress=15) else: update(job, status="downloading", message="yt-dlp 下载中…", progress=5) run([ "yt-dlp", "-f", "best[ext=mp4]/best", "-o", str(mp4), "--no-warnings", "--no-playlist", "--retries", "3", job.url, ]) if not mp4.exists(): raise RuntimeError("下载完成但找不到 source.mp4") # 元数据 meta = ffprobe_meta(mp4) v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None) duration = float(meta["format"]["duration"]) update( job, video_url=f"/jobs/{job_id}/video.mp4", duration=duration, width=int(v_stream["width"]) if v_stream else 0, height=int(v_stream["height"]) if v_stream else 0, progress=20, message=f"下载完成 · {duration:.1f}s", ) # ---- 2. 拆音轨 update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=30) wav = d / "audio.wav" run([ "ffmpeg", "-y", "-i", str(mp4), "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le", str(wav), ]) # ---- 3. 关键帧抽取(场景切换 + 均匀采样兜底,最多 10 张) update(job, message="抽取关键帧…", progress=50) frames_dir = d / "frames" if frames_dir.exists(): shutil.rmtree(frames_dir) frames_dir.mkdir(parents=True) # 先用场景切换检测(失败时不阻塞,走均匀采样兜底) try: run([ "ffmpeg", "-y", "-i", str(mp4), "-vf", "select='gt(scene,0.4)'", "-fps_mode", "vfr", "-frames:v", "30", "-pix_fmt", "yuvj420p", # mjpeg encoder 要 JPEG full-range "-q:v", "3", str(frames_dir / "scene_%03d.jpg"), ]) except Exception: # 场景切换检测在某些纯合成 / 静态视频上会失败,让它静默走兜底 pass scene_frames = sorted(frames_dir.glob("scene_*.jpg")) # 均匀采样兜底 / 补足 if len(scene_frames) < 10: sample_count = 10 - len(scene_frames) step = duration / (sample_count + 1) for i in range(sample_count): t = step * (i + 1) out = frames_dir / f"sample_{i:03d}.jpg" run([ "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4), "-frames:v", "1", "-pix_fmt", "yuvj420p", "-q:v", "3", str(out), ]) # 统一排序、按时间戳读取、限制 10 张 all_frames = sorted(frames_dir.glob("*.jpg"))[:10] renamed: list[KeyFrame] = [] for i, src in enumerate(all_frames): dst = frames_dir / f"{i:03d}.jpg" if src != dst: src.rename(dst) # 简化:用均匀分布估算时间戳(场景切换的精确时间需要解析 showinfo 输出,先省) ts = duration * (i + 0.5) / max(len(all_frames), 1) renamed.append(KeyFrame(index=i, timestamp=round(ts, 2), url=f"/jobs/{job_id}/frames/{i}.jpg")) update( job, status="frames_extracted", frames=renamed, progress=70, message=f"已抽取 {len(renamed)} 张关键帧", ) except Exception as e: update(job, status="failed", error=str(e), message="管线失败") # ---------- Gemini ASR + 翻译 ---------- async def pipeline_transcribe(job_id: str) -> None: job = JOBS[job_id] d = job_dir(job_id) wav = d / "audio.wav" try: if not wav.exists(): raise RuntimeError("audio.wav 不存在") update(job, status="transcribing", message="Gemini ASR 处理中…", progress=75) if not GEMINI_API_KEY: # 无 key 模式:mock 数据,方便 UI 联调 await asyncio.sleep(1.2) mock_segments = [ TranscriptSegment(index=0, start=0.0, end=3.5, en="Welcome back to my channel, today we're testing something new.", zh="欢迎回来我的频道,今天我们要测试一些新东西。"), TranscriptSegment(index=1, start=3.5, end=7.2, en="This device looks really sleek and the design is quite minimal.", zh="这个设备看起来非常时尚,设计也相当简约。"), TranscriptSegment(index=2, start=7.2, end=11.0, en="Let me show you how it works in real life situations.", zh="让我向你展示它在实际场景中如何工作。"), ] update(job, transcript=mock_segments, status="transcribed", progress=100, message="转录完成(MOCK 模式 · 未设 GEMINI_API_KEY)") return # 真模式:调 Gemini import google.generativeai as genai genai.configure(api_key=GEMINI_API_KEY) model = genai.GenerativeModel(GEMINI_MODEL) audio_file = genai.upload_file(str(wav), mime_type="audio/wav") prompt = ( "Transcribe the English audio with sentence-level timestamps. " "Then provide a Chinese translation for each segment. " "Return strictly as JSON array, no prose, schema: " '[{"start": float_seconds, "end": float_seconds, "en": "...", "zh": "..."}]' ) resp = await asyncio.to_thread( model.generate_content, [audio_file, prompt], generation_config={"response_mime_type": "application/json"}, ) raw = resp.text or "[]" data = json.loads(raw) segs = [ TranscriptSegment( index=i, start=float(s.get("start", 0)), end=float(s.get("end", 0)), en=str(s.get("en", "")), zh=str(s.get("zh", "")), ) for i, s in enumerate(data) ] update(job, transcript=segs, status="transcribed", progress=100, message=f"转录完成 · {len(segs)} 段") except Exception as e: update(job, status="failed", error=str(e), message="转录失败") # ---------- API 路由 ---------- class CreateJobReq(BaseModel): url: str @app.get("/health") def health() -> dict: return {"ok": True, "gemini_configured": bool(GEMINI_API_KEY), "model": GEMINI_MODEL} @app.post("/jobs", response_model=Job) async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job: if not req.url.strip(): raise HTTPException(400, "url required") job_id = uuid.uuid4().hex[:12] job = Job(id=job_id, url=req.url.strip()) JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download_split_frames, job_id) return job @app.post("/jobs/upload", response_model=Job) async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job: if not file.filename: raise HTTPException(400, "file required") # 简化:只验后缀,不嗅探 magic bytes ext = Path(file.filename).suffix.lower() if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}: raise HTTPException(400, f"unsupported video format: {ext}") job_id = uuid.uuid4().hex[:12] d = job_dir(job_id) mp4 = d / "source.mp4" # 直接落盘(流式写入,避免全量进内存) with mp4.open("wb") as f: while chunk := await file.read(1024 * 1024): f.write(chunk) if not mp4.exists() or mp4.stat().st_size == 0: raise HTTPException(500, "upload failed") job = Job(id=job_id, url=f"upload://{file.filename}") JOBS[job_id] = job save_state(job) bg.add_task(pipeline_download_split_frames, job_id) return job @app.get("/jobs/{job_id}", response_model=Job) def get_job(job_id: str) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") return job @app.post("/jobs/{job_id}/transcribe", response_model=Job) async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job: job = JOBS.get(job_id) if not job: raise HTTPException(404, "job not found") if job.status != "frames_extracted": raise HTTPException(409, f"status must be frames_extracted, got {job.status}") bg.add_task(pipeline_transcribe, job_id) return job @app.get("/jobs/{job_id}/video.mp4") def get_video(job_id: str): p = job_dir(job_id) / "source.mp4" if not p.exists(): raise HTTPException(404, "video not found") return FileResponse(p, media_type="video/mp4") @app.get("/jobs/{job_id}/frames/{idx}.jpg") def get_frame(job_id: str, idx: int): p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg" if not p.exists(): raise HTTPException(404, "frame not found") return FileResponse(p, media_type="image/jpeg")