467 lines
16 KiB
Python
467 lines
16 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Literal
|
||
|
||
from dotenv import load_dotenv
|
||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Load settings from a local .env file into the process environment.
load_dotenv()

# Root directory that holds one sub-directory per job; created at import time.
JOBS_DIR = Path(os.getenv("JOBS_DIR", "./jobs")).resolve()
JOBS_DIR.mkdir(parents=True, exist_ok=True)

# Comma-separated list of allowed browser origins; blank entries are dropped.
CORS_ORIGINS = [
    origin
    for origin in map(str.strip, os.getenv("CORS_ORIGINS", "http://localhost:4290").split(","))
    if origin
]

# Connection settings for the OpenAI-compatible gateway (empty means "not set").
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "").strip()
LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()

# Model names used for speech recognition, translation and rewriting.
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
TRANSLATE_MODEL = os.getenv("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
|
||
|
||
# OpenAI 客户端(OpenAI 兼容网关,含 SKG ezlink)
|
||
from openai import OpenAI
|
||
_llm_client: OpenAI | None = None


def llm() -> OpenAI:
    """Return the shared OpenAI-compatible client, building it on first use.

    Raises:
        RuntimeError: if ``LLM_API_KEY`` is empty when the client is first needed.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY 未配置")
    # An empty base URL is passed as None so the client picks its own default.
    _llm_client = OpenAI(base_url=LLM_BASE_URL or None, api_key=LLM_API_KEY)
    return _llm_client
|
||
|
||
# Pipeline states:
#   created -> downloading -> downloaded (pauses until the user requests analysis)
#   -> splitting -> frames_extracted -> transcribing -> transcribed
# A stage that raises ends in "failed" instead.
JobStatus = Literal[
    "created",
    "downloading",
    "downloaded",
    "splitting",
    "frames_extracted",
    "transcribing",
    "transcribed",
    "failed",
]

# Default number of key frames extracted per job; overridable via the environment.
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5"))
|
||
|
||
|
||
# One key frame extracted from a job's video.
class KeyFrame(BaseModel):
    # Zero-based position of the frame in the job's frame list.
    index: int
    # Position in the video, in seconds.
    timestamp: float
    # API path the frame image is served from.
    url: str
|
||
|
||
|
||
# One timed segment of the transcript, with its optional Chinese translation.
class TranscriptSegment(BaseModel):
    # Zero-based position of the segment in the transcript.
    index: int
    # Segment boundaries as reported by the ASR model.
    start: float
    end: float
    # Recognised source text.
    en: str
    # Translation; stays empty until the translation step fills it in.
    zh: str = ""
|
||
|
||
|
||
# Full state of one processing job; this is what gets persisted to state.json
# and returned by the job endpoints.
class Job(BaseModel):
    id: str
    # Source URL, or "upload://<filename>" for uploaded files.
    url: str
    status: JobStatus = "created"
    # Progress indicator; the pipeline stages set it between 0 and 100.
    progress: int = 0
    # Human-readable description of the current step.
    message: str = ""
    # API path of the source video, set once the download stage has finished.
    video_url: str = ""
    # Video metadata read from ffprobe (0 until known).
    duration: float = 0.0
    width: int = 0
    height: int = 0
    frames: list[KeyFrame] = Field(default_factory=list)
    transcript: list[TranscriptSegment] = Field(default_factory=list)
    # Text of the exception that moved the job to "failed".
    error: str = ""
|
||
|
||
|
||
# In-memory registry of every known job, keyed by job id.
JOBS: dict[str, Job] = {}


def job_dir(job_id: str) -> Path:
    """Return the working directory of *job_id*, creating it when missing."""
    path = JOBS_DIR / job_id
    path.mkdir(parents=True, exist_ok=True)
    return path
|
||
|
||
|
||
def save_state(job: Job) -> None:
    """Write *job* as indented JSON to ``state.json`` in its job directory."""
    state_file = job_dir(job.id) / "state.json"
    serialized = job.model_dump_json(indent=2)
    state_file.write_text(serialized)
|
||
|
||
|
||
def update(job: Job, **kw) -> None:
    """Assign every keyword argument onto *job*, then persist the job."""
    for field_name in kw:
        setattr(job, field_name, kw[field_name])
    save_state(job)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Reload persisted jobs into ``JOBS`` when the application starts.

    Every sub-directory of ``JOBS_DIR`` that contains a ``state.json`` is parsed
    back into a ``Job``. Directories whose state file cannot be read or
    validated are skipped so that one bad file does not block startup.

    Jobs that were persisted in a status which only exists while a background
    task is running are restored as ``failed``: the task died with the previous
    process, so the job would otherwise stay in that status forever and the
    analyze endpoint would keep rejecting it.
    """
    in_flight = {"created", "downloading", "splitting", "transcribing"}
    for entry in JOBS_DIR.iterdir():
        state_file = entry / "state.json"
        if not (entry.is_dir() and state_file.exists()):
            continue
        try:
            job = Job.model_validate_json(state_file.read_text())
        except Exception:
            # Best effort: an unreadable or corrupt state file is ignored.
            continue
        if job.status in in_flight:
            # Only the in-memory copy is changed; the next update() persists it.
            job.status = "failed"
            job.message = "任务中断"
            job.error = "服务重启,任务已中断"
        JOBS[entry.name] = job
    yield
|
||
|
||
|
||
app = FastAPI(title="SKG TK 二创 API", lifespan=lifespan)

# Browser access is limited to the configured origins; any method and header
# is accepted from them, and credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
||
|
||
|
||
# ---------- Pipeline 实现 ----------
|
||
|
||
def run(cmd: list[str], cwd: Path | None = None) -> str:
|
||
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
|
||
if res.returncode != 0:
|
||
# ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾)
|
||
tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:]
|
||
raise RuntimeError(f"cmd failed: {' '.join(cmd[:3])}... · {tail}")
|
||
return res.stdout
|
||
|
||
|
||
def ffprobe_meta(mp4: Path) -> dict:
    """Return ffprobe's JSON report (streams and format sections) for *mp4*."""
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-print_format", "json",
        "-show_streams",
        "-show_format",
        str(mp4),
    ]
    return json.loads(run(probe_cmd))
|
||
|
||
|
||
async def pipeline_download(job_id: str) -> None:
    """Stage 1: obtain ``source.mp4`` and stop at ``downloaded``.

    If ``source.mp4`` already exists in the job directory (the upload path) the
    download is skipped; otherwise the job's URL is fetched with yt-dlp. The
    file is then probed with ffprobe and the job is updated with its duration
    and dimensions. The pipeline pauses here until analysis is requested.

    Any exception is caught and recorded on the job, which moves to ``failed``.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if mp4.exists():
            update(job, status="downloading", message="本地上传 · 跳过下载", progress=15)
        else:
            update(job, status="downloading", message="yt-dlp 下载中…", progress=5)
            # The subprocess call blocks, so it runs in a worker thread; running
            # it inline would freeze the event loop (and every other request)
            # for the whole download.
            await asyncio.to_thread(run, [
                "yt-dlp", "-f", "best[ext=mp4]/best",
                "-o", str(mp4),
                "--no-warnings", "--no-playlist",
                "--retries", "3",
                # End of options: the URL comes from the client and must never
                # be parsed as a yt-dlp flag.
                "--",
                job.url,
            ])
            if not mp4.exists():
                raise RuntimeError("下载完成但找不到 source.mp4")

        meta = await asyncio.to_thread(ffprobe_meta, mp4)
        v_stream = next((s for s in meta["streams"] if s["codec_type"] == "video"), None)
        duration = float(meta["format"]["duration"])
        update(
            job,
            status="downloaded",
            video_url=f"/jobs/{job_id}/video.mp4",
            duration=duration,
            # Width/height stay 0 when the file has no video stream.
            width=int(v_stream["width"]) if v_stream else 0,
            height=int(v_stream["height"]) if v_stream else 0,
            progress=25,
            message=f"视频就绪 · {duration:.1f}s · 等待解析",
        )
    except Exception as e:
        update(job, status="failed", error=str(e), message="下载失败")
|
||
|
||
|
||
async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None:
    """Stage 2: split the audio track, extract key frames, then transcribe.

    Requires ``source.mp4`` in the job directory. The steps are:

    1. Extract the audio to ``audio.wav`` (mono, 16 kHz, 16-bit PCM).
    2. Extract key frames into ``frames/``: scene-change detection first, then
       evenly spaced samples to make up any shortfall.
    3. Rename the selected frames to ``000.jpg``, ``001.jpg``, ... and record
       them on the job.
    4. Continue with ``pipeline_transcribe``.

    Args:
        job_id: Id of a job present in ``JOBS``.
        frame_count: Requested number of key frames; clamped to 1..20.

    Any exception is caught and recorded on the job, which moves to ``failed``.
    """
    job = JOBS[job_id]
    d = job_dir(job_id)
    try:
        mp4 = d / "source.mp4"
        if not mp4.exists():
            raise RuntimeError("source.mp4 不存在,先完成下载")

        update(job, status="splitting", message="ffmpeg 拆分音轨…", progress=35)
        wav = d / "audio.wav"
        # Every ffmpeg call blocks, so each one runs in a worker thread to keep
        # the event loop (and therefore status polling) responsive.
        await asyncio.to_thread(run, [
            "ffmpeg", "-y", "-i", str(mp4),
            "-vn", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
            str(wav),
        ])

        n = max(1, min(int(frame_count), 20))
        update(job, message=f"抽取 {n} 张关键帧…", progress=50)
        frames_dir = d / "frames"
        # Start from an empty directory so frames of a previous run cannot leak in.
        if frames_dir.exists():
            shutil.rmtree(frames_dir)
        frames_dir.mkdir(parents=True)

        try:
            await asyncio.to_thread(run, [
                "ffmpeg", "-y", "-i", str(mp4),
                "-vf", "select='gt(scene,0.4)'",
                "-fps_mode", "vfr",
                "-frames:v", str(n * 3),
                "-pix_fmt", "yuvj420p",
                "-q:v", "3",
                str(frames_dir / "scene_%03d.jpg"),
            ])
        except (RuntimeError, OSError):
            # Scene detection is best effort: if ffmpeg fails here, the evenly
            # spaced sampling below supplies the frames instead.
            pass
        scene_frames = sorted(frames_dir.glob("scene_*.jpg"))

        if len(scene_frames) < n:
            sample_count = n - len(scene_frames)
            # Fall back to 1 second when the duration is unknown (0).
            duration = job.duration or 1.0
            step = duration / (sample_count + 1)
            for i in range(sample_count):
                t = step * (i + 1)
                out = frames_dir / f"sample_{i:03d}.jpg"
                await asyncio.to_thread(run, [
                    "ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
                    "-frames:v", "1",
                    "-pix_fmt", "yuvj420p",
                    "-q:v", "3", str(out),
                ])

        all_frames = sorted(frames_dir.glob("*.jpg"))[:n]
        renamed: list[KeyFrame] = []
        for i, src in enumerate(all_frames):
            dst = frames_dir / f"{i:03d}.jpg"
            if src != dst:
                src.rename(dst)
            # Timestamps are spread evenly over the duration by index; they are
            # estimates, not the positions the frames were actually taken from.
            ts = (job.duration or 0) * (i + 0.5) / max(len(all_frames), 1)
            renamed.append(KeyFrame(index=i, timestamp=round(ts, 2), url=f"/jobs/{job_id}/frames/{i}.jpg"))

        update(
            job,
            status="frames_extracted",
            frames=renamed,
            progress=70,
            message=f"已抽取 {len(renamed)} 张关键帧",
        )

        # Continue straight into ASR + translation.
        await pipeline_transcribe(job_id)

    except Exception as e:
        update(job, status="failed", error=str(e), message="解析失败")
|
||
|
||
|
||
# ---------- Gemini ASR + 翻译 ----------
|
||
|
||
def _transcribe_sync(wav: Path) -> list[dict]:
    """Transcribe *wav* with the ASR model and return its segments.

    The request asks for ``verbose_json`` with segment-level timestamps. When
    the response carries no segments but does carry text, the whole text is
    returned as a single segment starting at 0.
    """
    with wav.open("rb") as audio:
        resp = llm().audio.transcriptions.create(
            file=(wav.name, audio, "audio/wav"),
            model=ASR_MODEL,
            response_format="verbose_json",
            timestamp_granularities=["segment"],
        )
    raw = resp.model_dump() if hasattr(resp, "model_dump") else resp
    segments = raw.get("segments") or []
    if segments or not raw.get("text"):
        return segments
    # Fallback for gateways that return only the full text.
    whole = {"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}
    return [whole]
|
||
|
||
|
||
def _parse_translations(content: str, count: int) -> list[str]:
    """Pull *count* translations, ordered by segment index, out of a JSON reply."""
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        return [""] * count
    if isinstance(data, dict):
        # The array may be wrapped under a key the model chose: try the known
        # names first, then accept any list value.
        candidates = [data.get(k) for k in ("data", "items", "result", "translations")]
        candidates.extend(data.values())
        data = next((v for v in candidates if isinstance(v, list)), [])
    if not isinstance(data, list):
        return [""] * count

    zh_by_idx: dict[int, str] = {}
    for item in data:
        if not isinstance(item, dict):
            continue
        try:
            idx = int(item["i"])
        except (KeyError, TypeError, ValueError):
            # One malformed entry must not fail the whole transcription job.
            continue
        # A null translation becomes "" rather than the string "None".
        zh_by_idx[idx] = str(item.get("zh") or "")
    return [zh_by_idx.get(i, "") for i in range(count)]


def _translate_sync(segments: list[dict]) -> list[str]:
    """Translate all segments to Simplified Chinese in one model call.

    Args:
        segments: ASR segments; only each segment's ``text`` is used.

    Returns:
        One string per input segment, in input order. A segment the model did
        not translate (or whose entry could not be parsed) yields ``""``.
    """
    if not segments:
        return []
    payload = [{"i": i, "en": str(s.get("text") or "").strip()} for i, s in enumerate(segments)]
    # The request uses JSON-object mode, so the prompt asks for an object that
    # wraps the array instead of a bare array.
    prompt = (
        "你是字幕翻译。把下列英文字幕段翻译为简体中文,保持原意、口语化、自然流畅。"
        "严格返回 JSON 对象,不要任何 markdown 或多余文字,schema: "
        '{"translations": [{"i": 0, "zh": "..."}, ...]}\n\n输入:\n'
        + json.dumps(payload, ensure_ascii=False)
    )
    resp = llm().chat.completions.create(
        model=TRANSLATE_MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        temperature=0.2,
    )
    content = resp.choices[0].message.content or "[]"
    return _parse_translations(content, len(segments))
|
||
|
||
|
||
async def pipeline_transcribe(job_id: str) -> None:
    """Stage 3: transcribe ``audio.wav`` and translate the segments.

    Without an ``LLM_API_KEY`` a fixed two-segment mock transcript is stored
    instead. Otherwise the source-language segments are saved first, so they are
    visible while the translation is still running, and then replaced by the
    translated ones. Any exception marks the job as ``failed``.
    """
    job = JOBS[job_id]
    wav = job_dir(job_id) / "audio.wav"
    try:
        if not wav.exists():
            raise RuntimeError("audio.wav 不存在")

        if not LLM_API_KEY:
            # No key configured: produce mock data.
            update(job, status="transcribing", message="ASR (mock) …", progress=75)
            await asyncio.sleep(1.0)
            first = TranscriptSegment(
                index=0,
                start=0.0,
                end=3.5,
                en="Welcome back, today we're testing something new.",
                zh="欢迎回来,今天我们要测试一些新东西。",
            )
            second = TranscriptSegment(
                index=1,
                start=3.5,
                end=7.2,
                en="This device looks really sleek and minimal.",
                zh="这个设备看起来非常时尚和简约。",
            )
            update(
                job,
                transcript=[first, second],
                status="transcribed",
                progress=100,
                message="转录完成(MOCK · 未设 LLM_API_KEY)",
            )
            return

        # 1) Speech recognition, off the event loop.
        update(job, status="transcribing", message=f"{ASR_MODEL} 转录中…", progress=78)
        segments = await asyncio.to_thread(_transcribe_sync, wav)
        if not segments:
            raise RuntimeError("ASR 返回 0 段(可能无人声 / 格式问题)")

        # Store the untranslated segments first; zh is filled in afterwards.
        en_only = []
        for position, raw_seg in enumerate(segments):
            en_only.append(
                TranscriptSegment(
                    index=position,
                    start=float(raw_seg.get("start", 0)),
                    end=float(raw_seg.get("end", 0)),
                    en=str(raw_seg.get("text", "")).strip(),
                    zh="",
                )
            )
        update(job, transcript=en_only, message=f"ASR 完成 · {len(en_only)} 段,开始翻译…", progress=88)

        # 2) Translation, off the event loop.
        zh_list = await asyncio.to_thread(_translate_sync, segments)
        full = []
        for position, seg in enumerate(en_only):
            if position < len(zh_list):
                translated = zh_list[position]
            else:
                translated = ""
            full.append(
                TranscriptSegment(index=seg.index, start=seg.start, end=seg.end, en=seg.en, zh=translated)
            )
        update(
            job,
            transcript=full,
            status="transcribed",
            progress=100,
            message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})",
        )

    except Exception as e:
        update(job, status="failed", error=str(e), message="转录失败")
|
||
|
||
|
||
# ---------- API 路由 ----------
|
||
|
||
# Request body of POST /jobs.
class CreateJobReq(BaseModel):
    # Address of the video to download.
    url: str
|
||
|
||
|
||
# Liveness probe that also reports the active LLM configuration; it exposes
# whether a key is set, never the key itself.
@app.get("/health")
def health() -> dict:
    models = {
        "asr": ASR_MODEL,
        "translate": TRANSLATE_MODEL,
        "rewrite": REWRITE_MODEL,
    }
    return {
        "ok": True,
        "llm_configured": bool(LLM_API_KEY),
        "base_url": LLM_BASE_URL or "openai-default",
        "models": models,
    }
|
||
|
||
|
||
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks) -> Job:
    """Create a job for a remote video and queue its download.

    Raises:
        HTTPException: 400 when the URL is blank, or when it starts with ``-``.
            The URL is later passed to the downloader's command line, where a
            leading dash would be read as an option instead of an address.
    """
    url = req.url.strip()
    if not url:
        raise HTTPException(400, "url required")
    if url.startswith("-"):
        raise HTTPException(400, "invalid url")
    job_id = uuid.uuid4().hex[:12]
    job = Job(id=job_id, url=url)
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/upload", response_model=Job)
async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(...)) -> Job:
    """Create a job from an uploaded video file.

    The upload is streamed to ``source.mp4`` in a new job directory, then the
    download stage is queued; it finds the file already present and skips the
    actual download.

    Raises:
        HTTPException: 400 when no filename is given, the extension is not an
            accepted video format, or the uploaded file is empty.
    """
    if not file.filename:
        raise HTTPException(400, "file required")
    ext = Path(file.filename).suffix.lower()
    if ext not in {".mp4", ".mov", ".webm", ".mkv", ".m4v"}:
        raise HTTPException(400, f"unsupported video format: {ext}")

    job_id = uuid.uuid4().hex[:12]
    d = job_dir(job_id)
    mp4 = d / "source.mp4"
    try:
        with mp4.open("wb") as f:
            # Copy in 1 MiB chunks so large uploads are never held in memory.
            while chunk := await file.read(1024 * 1024):
                f.write(chunk)
    except Exception:
        # Do not leave a half-written job directory behind.
        shutil.rmtree(d, ignore_errors=True)
        raise
    if mp4.stat().st_size == 0:
        # An empty body is a client error, and no job is registered for it,
        # so its directory is removed as well.
        shutil.rmtree(d, ignore_errors=True)
        raise HTTPException(400, "uploaded file is empty")

    job = Job(id=job_id, url=f"upload://{file.filename}")
    JOBS[job_id] = job
    save_state(job)
    bg.add_task(pipeline_download, job_id)
    return job
|
||
|
||
|
||
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT) -> Job:
    """Queue the analysis stage (audio split, key frames, transcription) for a job.

    Analysis can be started, or re-run, from any resting status. The job is
    moved to ``splitting`` before the task is queued, so a second request that
    arrives before the task starts gets a 409 instead of queueing a duplicate.

    Raises:
        HTTPException: 404 for an unknown job; 409 when the job is not in one
            of the accepted statuses.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    allowed = ("downloaded", "frames_extracted", "transcribed", "failed")
    if job.status not in allowed:
        raise HTTPException(409, f"status must be one of {'/'.join(allowed)}, got {job.status}")
    # Claim the job now and clear the error left by an earlier failed run.
    update(job, status="splitting", message="等待解析…", progress=30, error="")
    bg.add_task(pipeline_analyze, job_id, frames)
    return job
|
||
|
||
|
||
# Return the current in-memory state of one job, or 404 when the id is unknown.
@app.get("/jobs/{job_id}", response_model=Job)
def get_job(job_id: str) -> Job:
    found = JOBS.get(job_id)
    if found:
        return found
    raise HTTPException(404, "job not found")
|
||
|
||
|
||
# Queue the transcription stage for a job whose key frames are already
# extracted; 404 for an unknown job, 409 for any other status.
@app.post("/jobs/{job_id}/transcribe", response_model=Job)
async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
    found = JOBS.get(job_id)
    if not found:
        raise HTTPException(404, "job not found")
    ready = found.status == "frames_extracted"
    if not ready:
        raise HTTPException(409, f"status must be frames_extracted, got {found.status}")
    bg.add_task(pipeline_transcribe, job_id)
    return found
|
||
|
||
|
||
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve a job's source video.

    Raises:
        HTTPException: 404 when the job is unknown or has no ``source.mp4``.
    """
    # Check the registry first: job_dir() creates the directory it returns, so
    # calling it with an arbitrary id from the URL would write to disk.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "source.mp4"
    if not p.exists():
        raise HTTPException(404, "video not found")
    return FileResponse(p, media_type="video/mp4")
|
||
|
||
|
||
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve one extracted key frame of a job.

    Frames are stored with zero-padded names (``000.jpg``), while the URL
    carries the plain index.

    Raises:
        HTTPException: 404 when the job is unknown or the frame does not exist.
    """
    # Check the registry first: job_dir() creates the directory it returns, so
    # calling it with an arbitrary id from the URL would write to disk.
    if job_id not in JOBS:
        raise HTTPException(404, "job not found")
    p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not p.exists():
        raise HTTPException(404, "frame not found")
    return FileResponse(p, media_type="image/jpeg")
|