auto-save 2026-05-14 10:20 (~7)

This commit is contained in:
2026-05-14 10:20:16 +08:00
parent ee32d83b6c
commit be1ae80750
7 changed files with 347 additions and 57 deletions

View File

@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
import httpx
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
@@ -36,6 +37,18 @@ REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
VISION_MODEL = os.getenv("VISION_MODEL", "gemini-2.5-flash")
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "gemini-3-pro-image-preview")
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"
AUDIO_PRODUCT_BRIEF = os.getenv(
"AUDIO_PRODUCT_BRIEF",
"SKG 智能按摩产品,主打日常肩颈、腰背、眼部、膝盖或足部放松;广告表达要高级、干净、可信,不做医疗疗效承诺。",
).strip()
AUDIO_REWRITE_MODEL = os.getenv("AUDIO_REWRITE_MODEL", REWRITE_MODEL).strip() or REWRITE_MODEL
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY", "").strip()
MINIMAX_TTS_BASE_URL = os.getenv("MINIMAX_TTS_BASE_URL", "https://api.minimax.io").strip().rstrip("/")
MINIMAX_TTS_MODEL = os.getenv("MINIMAX_TTS_MODEL", "speech-2.8-turbo").strip() or "speech-2.8-turbo"
MINIMAX_TTS_VOICE_ID = os.getenv(
"MINIMAX_TTS_VOICE_ID",
"Chinese (Mandarin)_Reliable_Executive",
).strip() or "Chinese (Mandarin)_Reliable_Executive"
POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1"
POE_API_KEY = os.getenv("POE_API_KEY", "").strip()
@@ -337,6 +350,21 @@ class TranscriptSegment(BaseModel):
zh: str = ""
class AudioScript(BaseModel):
status: Literal["idle", "rewriting", "completed", "failed"] = "idle"
source_text: str = ""
source_zh: str = ""
rewritten_text: str = ""
product_brief: str = ""
rewrite_model: str = ""
voice_provider: str = ""
voice_model: str = ""
voice_id: str = ""
voice_url: str = ""
error: str = ""
created_at: float = 0.0
class Job(BaseModel):
id: str
url: str
@@ -349,6 +377,7 @@ class Job(BaseModel):
height: int = 0
frames: list[KeyFrame] = Field(default_factory=list)
transcript: list[TranscriptSegment] = Field(default_factory=list)
audio_script: AudioScript = Field(default_factory=AudioScript)
storyboard_images: list[StoryboardImage] = Field(default_factory=list)
generated_videos: list[GeneratedVideo] = Field(default_factory=list)
error: str = ""
@@ -1351,6 +1380,148 @@ def _translate_sync(segments: list[dict]) -> list[str]:
return [zh_by_idx.get(i, "") for i in range(len(segments))]
def _transcript_join(segments: list[TranscriptSegment], field: Literal["en", "zh"]) -> str:
lines: list[str] = []
for s in segments:
text = (s.zh if field == "zh" else s.en).strip()
if text:
lines.append(f"[{s.start:.1f}-{s.end:.1f}s] {text}")
return "\n".join(lines)
def _fallback_audio_script(segments: list[TranscriptSegment]) -> str:
joined = " ".join((s.zh or s.en).strip() for s in segments if (s.zh or s.en).strip())
if not joined:
return "日常疲惫不用硬扛。戴上 SKG让肩颈慢慢放松跟着呼吸找回轻松状态。"
return (
"把日常紧绷交给 SKG。贴合身体需要放松的位置热敷与按摩节奏自然陪伴"
"让每一次短暂休息都更轻松、更有质感。"
)
def _rewrite_audio_script_sync(segments: list[TranscriptSegment]) -> tuple[str, str]:
fallback = _fallback_audio_script(segments)
if not LLM_API_KEY:
return fallback, "LLM_API_KEY 未配置,使用本地 SKG 模板"
source_text = _transcript_join(segments, "en")
source_zh = _transcript_join(segments, "zh")
prompt = (
"你是 SKG 短视频口播编导。根据参考视频音频转写,抽取它的表达结构、情绪节奏和可复用卖点,"
"改写成适合 SKG 按摩/放松产品二创视频的中文口播文案。\n"
"要求:\n"
"1. 输出 35-90 个中文字,适合 8-18 秒短视频配音。\n"
"2. 口语化、干净、高级,能直接给 TTS 朗读。\n"
"3. 不承诺治疗、治愈、医学疗效,不夸大。\n"
"4. 不复刻原视频品牌/人物/价格/平台话术,只保留表达结构。\n"
"5. 如果参考转写信息不足,按产品信息生成通用 SKG 放松口播。\n"
'严格返回 JSON{"rewritten_text":"..."}。\n\n'
f"SKG 产品信息:{AUDIO_PRODUCT_BRIEF}\n\n"
f"英文转写:\n{source_text or ''}\n\n"
f"中文翻译:\n{source_zh or ''}"
)
try:
resp = llm().chat.completions.create(
model=AUDIO_REWRITE_MODEL,
messages=[
{"role": "system", "content": "只输出合法 JSON不要解释不要 markdown。"},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
temperature=0.45,
max_tokens=600,
)
raw = (resp.choices[0].message.content or "").strip()
if raw.startswith("```"):
import re as _re
match = _re.search(r"\{[\s\S]*\}", raw)
raw = match.group(0) if match else raw
data = json.loads(raw)
text = str(data.get("rewritten_text", "")).strip()
return (text or fallback), ""
except Exception as e:
return fallback, f"改写失败,使用本地模板:{e}"
def _minimax_tts_url() -> str:
if MINIMAX_TTS_BASE_URL.endswith("/v1/t2a_v2"):
return MINIMAX_TTS_BASE_URL
return f"{MINIMAX_TTS_BASE_URL}/v1/t2a_v2"
def _minimax_tts_sync(job_id: str, text: str) -> str:
if not MINIMAX_API_KEY:
raise RuntimeError("MINIMAX_API_KEY 未配置,未生成配音")
if not text.strip():
raise RuntimeError("改写文案为空,未生成配音")
payload = {
"model": MINIMAX_TTS_MODEL,
"text": text.strip()[:9500],
"stream": False,
"language_boost": "Chinese",
"output_format": "hex",
"voice_setting": {
"voice_id": MINIMAX_TTS_VOICE_ID,
"speed": 1,
"vol": 1,
"pitch": 0,
},
"audio_setting": {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"channel": 1,
},
}
resp = httpx.post(
_minimax_tts_url(),
headers={"Authorization": f"Bearer {MINIMAX_API_KEY}", "Content-Type": "application/json"},
json=payload,
timeout=90,
)
resp.raise_for_status()
data = resp.json()
base_resp = data.get("base_resp") or {}
if int(base_resp.get("status_code", 0) or 0) != 0:
raise RuntimeError(base_resp.get("status_msg") or "MiniMax TTS 返回失败")
audio_hex = ((data.get("data") or {}).get("audio") or "").strip()
if not audio_hex:
raise RuntimeError("MiniMax TTS 未返回 audio hex")
try:
audio_bytes = bytes.fromhex(audio_hex)
except ValueError as e:
raise RuntimeError(f"MiniMax TTS audio hex 无法解析:{e}") from e
out = job_dir(job_id) / "audio_script.mp3"
out.write_bytes(audio_bytes)
return f"/jobs/{job_id}/audio-script.mp3"
def _build_audio_script_sync(job_id: str, segments: list[TranscriptSegment]) -> AudioScript:
source_text = _transcript_join(segments, "en")
source_zh = _transcript_join(segments, "zh")
rewritten, rewrite_error = _rewrite_audio_script_sync(segments)
voice_url = ""
voice_error = ""
try:
voice_url = _minimax_tts_sync(job_id, rewritten)
except Exception as e:
voice_error = str(e)
errors = "".join(x for x in [rewrite_error, voice_error] if x)
return AudioScript(
status="completed",
source_text=source_text,
source_zh=source_zh,
rewritten_text=rewritten,
product_brief=AUDIO_PRODUCT_BRIEF,
rewrite_model=AUDIO_REWRITE_MODEL,
voice_provider="minimax",
voice_model=MINIMAX_TTS_MODEL,
voice_id=MINIMAX_TTS_VOICE_ID,
voice_url=voice_url,
error=errors,
created_at=time.time(),
)
async def pipeline_transcribe(job_id: str) -> None:
job = JOBS[job_id]
d = job_dir(job_id)
@@ -1371,7 +1542,25 @@ async def pipeline_transcribe(job_id: str) -> None:
en="This device looks really sleek and minimal.",
zh="这个设备看起来非常时尚和简约。"),
]
update(
job,
transcript=mock,
audio_script=AudioScript(
status="rewriting",
source_text=_transcript_join(mock, "en"),
source_zh=_transcript_join(mock, "zh"),
product_brief=AUDIO_PRODUCT_BRIEF,
rewrite_model=AUDIO_REWRITE_MODEL,
voice_provider="minimax",
voice_model=MINIMAX_TTS_MODEL,
voice_id=MINIMAX_TTS_VOICE_ID,
),
message="ASR mock 完成,生成 SKG 改写文案…",
progress=92,
)
audio_script = await asyncio.to_thread(_build_audio_script_sync, job_id, mock)
update(job, transcript=mock, status="transcribed", progress=100,
audio_script=audio_script,
message="转录完成MOCK · 未设 LLM_API_KEY")
return
@@ -1403,11 +1592,35 @@ async def pipeline_transcribe(job_id: str) -> None:
)
for i, seg in enumerate(en_only)
]
update(
job,
transcript=full,
audio_script=AudioScript(
status="rewriting",
source_text=_transcript_join(full, "en"),
source_zh=_transcript_join(full, "zh"),
product_brief=AUDIO_PRODUCT_BRIEF,
rewrite_model=AUDIO_REWRITE_MODEL,
voice_provider="minimax",
voice_model=MINIMAX_TTS_MODEL,
voice_id=MINIMAX_TTS_VOICE_ID,
),
message="翻译完成,生成 SKG 改写文案与 MiniMax 配音…",
progress=94,
)
audio_script = await asyncio.to_thread(_build_audio_script_sync, job_id, full)
update(job, transcript=full, status="transcribed", progress=100,
audio_script=audio_script,
message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL}")
except Exception as e:
update(job, status="failed", error=str(e), message="转录失败")
update(
job,
status="failed",
audio_script=AudioScript(status="failed", error=str(e), created_at=time.time()),
error=str(e),
message="转录失败",
)
def _image_edit_call(
@@ -1566,6 +1779,10 @@ def health() -> dict:
"asr": ASR_MODEL,
"translate": TRANSLATE_MODEL,
"rewrite": REWRITE_MODEL,
"audio_rewrite": AUDIO_REWRITE_MODEL,
"minimax_tts": MINIMAX_TTS_MODEL,
"minimax_voice": MINIMAX_TTS_VOICE_ID,
"minimax_configured": bool(MINIMAX_API_KEY),
"video": VIDEO_MODEL,
"video_aliases": VIDEO_MODEL_ALIASES,
"video_provider": "poe" if video_uses_poe() else ("ark" if video_uses_ark() else "custom"),
@@ -1765,6 +1982,14 @@ def get_video(job_id: str):
return FileResponse(p, media_type="video/mp4")
@app.get("/jobs/{job_id}/audio-script.mp3")
def get_audio_script(job_id: str):
p = job_dir(job_id) / "audio_script.mp3"
if not p.exists():
raise HTTPException(404, "audio script not found")
return FileResponse(p, media_type="audio/mpeg")
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"