Files
20260512-skg-tk/api/main.py
2026-05-12 16:16:52 +08:00

444 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import asyncio
import json
import os
import shutil
import subprocess
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
# ---- Runtime configuration (read once at import time) ----
# Load variables from a .env file first so the os.getenv() calls below can see them.
load_dotenv()
# Root directory holding one sub-directory per job; created eagerly if missing.
JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)
# Comma-separated list of allowed CORS origins; blank entries are dropped.
CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "http://localhost:4290").split(",") if o.strip()]
# Gateway endpoint and key for the OpenAI-compatible client. An empty base URL
# makes llm() pass base_url=None; an empty key switches transcription to mock data.
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()
# Model names: speech recognition, subtitle translation, and rewrite
# (REWRITE_MODEL is only reported by /health in the code visible here).
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
# OpenAI client (works with any OpenAI-compatible gateway, including SKG ezlink).
from openai import OpenAI
# Lazily created singleton; populated on the first llm() call.
_llm_client: OpenAI | None = None
def llm() -> OpenAI:
    """Return the shared OpenAI client, building it on first use.

    Raises:
        RuntimeError: if LLM_API_KEY is empty when the client must be created.
    """
    global _llm_client
    # Fast path: the client was already built by an earlier call.
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    # An empty base URL is passed as None so the SDK uses its own default.
    endpoint = LLM_BASE_URL if LLM_BASE_URL else None
    _llm_client = OpenAI(base_url=endpoint, api_key=LLM_API_KEY)
    return _llm_client
# Pipeline status values:
# created → downloading → splitting → frames_extracted → transcribing → transcribed,
# with "failed" reachable from either pipeline when an exception is caught.
JobStatus = Literal[
    "created", "downloading", "splitting", "frames_extracted",
    "transcribing", "transcribed", "failed",
]
# One extracted key frame of a job's video.
class KeyFrame(BaseModel):
    # Zero-based position of the frame in the job's frame list.
    index: int
    # Timestamp in seconds; the extraction pipeline fills this with a
    # uniform-distribution estimate, not a measured position.
    timestamp: float
    # API path from which the frame image is served.
    url: str
# One ASR segment together with its translation.
class TranscriptSegment(BaseModel):
    # Zero-based position of the segment in the transcript.
    index: int
    # Segment start / end as reported by the ASR response.
    start: float
    end: float
    # Recognized source text.
    en: str
    # Translation; stays empty until the translation step fills it in.
    zh: str = ""
# Full state of one processing job; serialized to <job dir>/state.json by save_state().
class Job(BaseModel):
    # Short hex identifier; also the name of the job's directory under JOBS_DIR.
    id: str
    # Source URL handed to yt-dlp, or "upload://<filename>" for uploaded files.
    url: str
    status: JobStatus = "created"
    # Progress indicator written by the pipelines (values up to 100).
    progress: int = 0
    # Human-readable status line for the UI.
    message: str = ""
    # API path of the source video; set once metadata has been probed.
    video_url: str = ""
    # Duration and dimensions taken from ffprobe output; 0 until probed
    # (width/height also stay 0 when no video stream is found).
    duration: float = 0.0
    width: int = 0
    height: int = 0
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
    # Text of the exception that moved the job to "failed".
    error: str = ""
# In-memory job registry keyed by job id; repopulated from disk in lifespan().
JOBS: dict[str, Job] = {}
def job_dir(job_id: str) -> Path:
    """Return the working directory for *job_id*, creating it when absent."""
    path = JOBS_DIR.joinpath(job_id)
    path.mkdir(parents=True, exist_ok=True)
    return path
def save_state(job: Job) -> None:
    """Write the job as indented JSON to ``state.json`` in its job directory."""
    state_file = job_dir(job.id) / "state.json"
    serialized = job.model_dump_json(indent=2)
    state_file.write_text(serialized)
def update(job: Job, **kw) -> None:
    """Assign each keyword argument onto *job* as an attribute, then persist it."""
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    # Persist once, after all attributes have been applied.
    save_state(job)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Restore persisted jobs from JOBS_DIR into JOBS before serving requests.

    Every sub-directory containing a ``state.json`` is loaded. Unreadable state
    files are skipped with a logged warning instead of being dropped silently.
    Jobs persisted in an in-flight status are marked ``failed``: the background
    task that was driving them did not survive the restart, so they would
    otherwise stay "in progress" forever.
    """
    import logging  # local import: this module has no top-level logging import

    log = logging.getLogger(__name__)
    # Statuses that imply a background task was queued or running.
    in_flight = ("created", "downloading", "splitting", "transcribing")
    for p in JOBS_DIR.iterdir():
        state_file = p / "state.json"
        if not (p.is_dir() and state_file.exists()):
            continue
        try:
            job = Job.model_validate_json(state_file.read_text())
        except Exception:
            # Best-effort restore: one bad file must not block startup.
            log.warning("skipping unreadable job state: %s", state_file, exc_info=True)
            continue
        if job.status in in_flight:
            job.status = "failed"
            job.error = "interrupted by server restart"
            job.message = "服务重启,任务中断"
            try:
                # Write back to the file that was read, not via job.id, so the
                # directory name stays the source of truth as in the registry key.
                state_file.write_text(job.model_dump_json(indent=2))
            except OSError:
                log.warning("could not persist interrupted job: %s", state_file, exc_info=True)
        JOBS[p.name] = job
    yield
# Application instance; lifespan() restores persisted jobs at startup.
app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan)
# CORS: only origins listed in CORS_ORIGINS are allowed, with credentials,
# and with every method and header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------- Pipeline 实现 ----------
def run(cmd: list[str], cwd: Path | None = None) -> str:
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
if res.returncode != 0:
# ffmpeg 把 banner 写 stderr挑最后几行真错误一般在末尾
tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:]
raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}")
return res.stdout
def ffprobe_meta(mp4: Path) -> dict:
    """Return ffprobe's JSON report (streams and format sections) for *mp4*."""
    probe_cmd = [
        "ffprobe", "-v", "error",
        "-print_format", "json",
        "-show_streams", "-show_format",
        str(mp4),
    ]
    report = run(probe_cmd)
    return json.loads(report)
async def pipeline_download_split_frames(job_id: str) -> None:
    """Steps 1-3 of the pipeline: download, split the audio track, extract key frames.

    Progress and results are written onto the job via update(). Any exception
    is caught and recorded by moving the job to status ``failed``; nothing is
    raised to the caller.

    Every external command runs through asyncio.to_thread(): this coroutine is
    executed on the event loop as a background task, and a direct blocking
    subprocess call would stall all other requests until it finished.

    Args:
        job_id: Key of an existing entry in JOBS.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        # ---- 1. yt-dlp download (upload mode: source.mp4 already exists -> skip)
        if mp4.exists():
            update(job, status="downloading", message="本地上传,跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            await asyncio.to_thread(run, [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                # "--" ends option parsing so a user-supplied URL starting with
                # "-" cannot be interpreted as a yt-dlp option.
                "--", job.url,
            ])
        if not mp4.exists():
            raise RuntimeError("下载完成但找不到 source.mp4")
        # Metadata
        meta = await asyncio.to_thread(ffprobe_meta, mp4)
        v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None)
        duration = float(meta["format"]["duration"])
        update(
            job,
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            width=int(v_stream["width"]) if v_stream else 0,
            height=int(v_stream["height"]) if v_stream else 0,
            progress=20,
            message=f"下载完成 · {duration:.1f}s",
        )
        # ---- 2. Split the audio track (mono, 16 kHz, 16-bit PCM WAV)
        update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=30)
        wav = d / "audio.wav"
        await asyncio.to_thread(run, [
            "ffmpeg", "-y", "-i", str(mp4),
            "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
            str(wav),
        ])
        # ---- 3. Key frames (scene changes + uniform sampling fallback, at most 10)
        update(job, message="抽取关键帧…", progress=50)
        frames_dir = d / "frames"
        if frames_dir.exists():
            shutil.rmtree(frames_dir)
        frames_dir.mkdir(parents=True)
        # Try scene-change detection first; a failure must not block the job.
        try:
            await asyncio.to_thread(run, [
                "ffmpeg", "-y", "-i", str(mp4),
                "-vf", "select='gt(scene,0.4)'",
                "-fps_mode", "vfr",
                "-frames:v", "30",
                "-pix_fmt", "yuvj420p",  # the mjpeg encoder wants full-range JPEG
                "-q:v", "3",
                str(frames_dir / "scene_%03d.jpg"),
            ])
        except Exception:
            # Deliberate best-effort: scene detection can fail on synthetic or
            # static videos, in which case uniform sampling below takes over.
            pass
        scene_frames = sorted(frames_dir.glob("scene_*.jpg"))
        # Uniform sampling as fallback / top-up to 10 frames
        if len(scene_frames) < 10:
            sample_count = 10 - len(scene_frames)
            step = duration / (sample_count + 1)
            for i in range(sample_count):
                t = step * (i + 1)
                out = frames_dir / f"sample_{i:03d}.jpg"
                await asyncio.to_thread(run, [
                    "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
                    "-frames:v", "1",
                    "-pix_fmt", "yuvj420p",
                    "-q:v", "3", str(out),
                ])
        # Sort by file name, keep at most 10, rename to 000.jpg, 001.jpg, ...
        all_frames = sorted(frames_dir.glob("*.jpg"))[:10]
        renamed: list[KeyFrame] = []
        for i, src in enumerate(all_frames):
            dst = frames_dir / f"{i:03d}.jpg"
            if src != dst:
                src.rename(dst)
            # Simplification: timestamps are a uniform-distribution estimate.
            # Exact scene-change times would need parsing ffmpeg showinfo output.
            ts = duration * (i + 0.5) / max(len(all_frames), 1)
            renamed.append(KeyFrame(index=i, timestamp=round(ts, 2), url=f"/jobs/{job_id}/frames/{i}.jpg"))
        update(
            job,
            status="frames_extracted",
            frames=renamed,
            progress=70,
            message=f"已抽取 {len(renamed)} 张关键帧",
        )
    except Exception as e:
        # Top-level boundary of the background task: record the failure on the job.
        update(job, status="failed", error=str(e), message="管线失败")
# ---------- ASR (ASR_MODEL) + translation (TRANSLATE_MODEL) ----------
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* with ASR_MODEL and return the response's segment dicts.

    The request asks for ``verbose_json`` with segment-level timestamps. When
    the response has no segments but does have text, the whole text is
    returned as a single segment starting at 0.0.
    """
    with wav.open("rb") as audio:
        resp = llm().audio.transcriptions.create(
            file=(wav.name, audio, "audio/wav"),
            model=ASR_MODEL,
            response_format="verbose_json",
            timestamp_granularities=["segment"],
        )
    # SDK response objects expose model_dump(); anything else is used as-is.
    if hasattr(resp, "model_dump"):
        raw = resp.model_dump()
    else:
        raw = resp
    segments = raw.get("segments") or []
    if segments or not raw.get("text"):
        return segments
    # Fallback: the gateway returned no segments, so treat the full text as one.
    total = float(raw.get("duration", 0) or 0)
    return [{"start": 0.0, "end": total, "text": raw["text"]}]
def _translate_sync(segments: list[dict]) -> list[str]:
    """Translate all segments in one TRANSLATE_MODEL call.

    Args:
        segments: ASR segment dicts; only the ``text`` key is read.

    Returns:
        One translated string per input segment, in input order. A segment the
        model did not return (or returned malformed) gets an empty string.
    """
    if not segments:
        # Nothing to translate: skip the model call.
        return []
    # str(... or "") tolerates a missing or None "text" value.
    payload = [{"i": i, "en": str(s.get("text") or "").strip()} for i, s in enumerate(segments)]
    prompt = (
        "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
        "严格返回 JSON 数组,不要任何 markdown 或多余文字schema: "
        '[{"i": 0, "zh": "..."}, ...]\n\n输入:\n'
        + json.dumps(payload, ensure_ascii=False)
    )
    resp = llm().chat.completions.create(
        model=TRANSLATE_MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        temperature=0.2,
    )
    content = resp.choices[0].message.content or "[]"
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        data = []
    if isinstance(data, dict):
        # The request uses json_object mode while the prompt asks for an array,
        # so the list may arrive wrapped in an object. Prefer the well-known
        # wrapper keys, then fall back to the first list-valued entry.
        wrapped = next(
            (data[k] for k in ("data", "items", "result", "translations") if isinstance(data.get(k), list)),
            None,
        )
        if wrapped is None:
            wrapped = next((v for v in data.values() if isinstance(v, list)), None)
        data = wrapped
    if not isinstance(data, list):
        data = []
    zh_by_idx: dict[int, str] = {}
    for it in data:
        if not isinstance(it, dict) or "i" not in it:
            continue
        try:
            idx = int(it["i"])
        except (TypeError, ValueError):
            # One malformed index must not fail the whole translation.
            continue
        zh = it.get("zh")
        # A null translation becomes "", not the literal string "None".
        zh_by_idx[idx] = "" if zh is None else str(zh)
    return [zh_by_idx.get(i, "") for i in range(len(segments))]
async def pipeline_transcribe(job_id: str) -> None:
    """Transcribe the job's audio.wav and translate the segments.

    Without LLM_API_KEY a fixed two-segment mock transcript is stored instead.
    Otherwise the ASR result is stored first with empty translations, so the
    source text is visible before translation finishes, then replaced by the
    translated transcript. Any exception moves the job to status ``failed``;
    nothing is raised to the caller.

    Args:
        job_id: Key of an existing entry in JOBS.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    wav = d / "audio.wav"
    try:
        if not wav.exists():
            raise RuntimeError("audio.wav 不存在")
        if not LLM_API_KEY:
            # No-key mode: return mock data
            update(job, status="transcribing", message="ASR (mock) …", progress=75)
            await asyncio.sleep(1.0)
            mock = [
                TranscriptSegment(index=0, start=0.0, end=3.5,
                                  en="Welcome back, today we're testing something new.",
                                  zh="欢迎回来,今天我们要测试一些新东西。"),
                TranscriptSegment(index=1, start=3.5, end=7.2,
                                  en="This device looks really sleek and minimal.",
                                  zh="这个设备看起来非常时尚和简约。"),
            ]
            update(job, transcript=mock, status="transcribed", progress=100,
                   message="转录完成MOCK · 未设 LLM_API_KEY")
            return
        # 1) ASR, run in a worker thread because the SDK call is blocking
        update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78)
        segments = await asyncio.to_thread(_transcribe_sync, wav)
        if not segments:
            raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)")
        # Store the source-language segments first (the UI can show them early;
        # the translation is filled in afterwards).
        en_only = [
            TranscriptSegment(
                index=i,
                start=float(s.get("start", 0)),
                end=float(s.get("end", 0)),
                en=str(s.get("text", "")).strip(),
                zh="",
            )
            for i, s in enumerate(segments)
        ]
        update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88)
        # 2) Translation, also in a worker thread
        zh_list = await asyncio.to_thread(_translate_sync, segments)
        full = [
            TranscriptSegment(
                index=seg.index, start=seg.start, end=seg.end, en=seg.en,
                zh=zh_list[i] if i < len(zh_list) else "",
            )
            for i, seg in enumerate(en_only)
        ]
        # Fix: the message previously opened "(" without closing it.
        update(job, transcript=full, status="transcribed", progress=100,
               message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})")
    except Exception as e:
        # Top-level boundary of the background task: record the failure on the job.
        update(job, status="failed", error=str(e), message="转录失败")
# ---------- API 路由 ----------
# Request body for POST /jobs.
class CreateJobReq(BaseModel):
    # Source video URL; rejected with 400 by create_job when blank after stripping.
    url: str
@app.get("/health")
def health() -> dict:
    """Report liveness plus the effective LLM configuration (the key itself is not exposed)."""
    model_names = {
        "asr": ASR_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
    }
    # An empty base URL means the SDK default endpoint is in use.
    endpoint = LLM_BASE_URL if LLM_BASE_URL else "openai-default"
    return {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "base_url": endpoint,
        "models": model_names,
    }
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Register a job for a remote URL and queue the download/split/frames pipeline.

    Raises:
        HTTPException: 400 when the URL is blank after stripping whitespace.
    """
    source_url = req.url.strip()
    if not source_url:
        raise HTTPException(400, "url required")
    new_id = uuid.uuid4().hex[:12]
    job = Job(id=new_id, url=source_url)
    JOBS[new_id] = job
    save_state(job)
    # The pipeline runs after the response has been sent.
    bg.add_task(pipeline_download_split_frames, new_id)
    return job
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Register a job from an uploaded video and queue the split/frames pipeline.

    The upload is streamed to ``source.mp4`` in a fresh job directory; the
    pipeline then skips its download step because that file already exists.
    If the upload fails or is empty, the job directory is removed again so no
    orphaned directory or partial file stays on disk.

    Raises:
        HTTPException: 400 for a missing file name or an unsupported suffix,
            500 when the stored file is empty.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    # Simplification: only the suffix is checked; magic bytes are not sniffed.
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")
    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    stored = False
    try:
        # Stream to disk in 1 MiB chunks to avoid loading the whole file into memory.
        with mp4.open("wb") as f:
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
        if not mp4.exists() or mp4.stat().st_size == 0:
            raise HTTPException(500, "upload failed")
        stored = True
    finally:
        # try/finally rather than `except Exception` so the cleanup also runs
        # when the request is cancelled (e.g. the client disconnects mid-upload).
        if not stored:
            shutil.rmtree(d, ignore_errors=True)
    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download_split_frames, job_id)
    return job
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    """Return the current in-memory state of a job.

    Raises:
        HTTPException: 404 when *job_id* is not registered.
    """
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    return JOBS[job_id]
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    """Queue the transcription pipeline for a job whose frames have been extracted.

    The job is moved to ``transcribing`` here, before the task is queued.
    Background tasks only start after the response is sent, so leaving the
    status at ``frames_extracted`` would let a second request pass the check
    and start a duplicate transcription.

    Raises:
        HTTPException: 404 when the job is unknown, 409 when its status is not
            ``frames_extracted``.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    if job.status != "frames_extracted":
        raise HTTPException(409, f"status must be frames_extracted, got {job.status}")
    # No await between the check above and this update, so the claim is atomic
    # with respect to other requests on the event loop.
    update(job, status="transcribing", message="已排队,等待转录…")
    bg.add_task(pipeline_transcribe, job_id)
    return job
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the job's source video.

    Raises:
        HTTPException: 404 when the job is unknown or has no source.mp4.
    """
    # Check the registry first: job_dir() creates the directory as a side
    # effect, so an unknown id must not reach it (arbitrary GETs would
    # otherwise create directories under JOBS_DIR).
    if job_id not in JOBS:
        raise HTTPException(404, "video not found")
    p = job_dir(job_id) / "source.mp4"
    if not p.exists():
        raise HTTPException(404, "video not found")
    return FileResponse(p, media_type="video/mp4")
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve key frame *idx* of a job; frames are stored zero-padded (``000.jpg``).

    Raises:
        HTTPException: 404 when the job is unknown or the frame file is missing.
    """
    # Check the registry first: job_dir() creates the directory as a side
    # effect, so an unknown id must not reach it (arbitrary GETs would
    # otherwise create directories under JOBS_DIR).
    if job_id not in JOBS:
        raise HTTPException(404, "frame not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame not found")
    return FileResponse(p, media_type="image/jpeg")