auto-save 2026-05-14 10:20 (~7)
This commit is contained in:
227
api/main.py
227
api/main.py
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -36,6 +37,18 @@ REWRITE_MODEL = os.getenv("REWRITE_MODEL", "gemini-2.5-pro")
|
||||
VISION_MODEL = os.getenv("VISION_MODEL", "gemini-2.5-flash")
|
||||
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "gemini-3-pro-image-preview")
|
||||
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"
|
||||
AUDIO_PRODUCT_BRIEF = os.getenv(
|
||||
"AUDIO_PRODUCT_BRIEF",
|
||||
"SKG 智能按摩产品,主打日常肩颈、腰背、眼部、膝盖或足部放松;广告表达要高级、干净、可信,不做医疗疗效承诺。",
|
||||
).strip()
|
||||
AUDIO_REWRITE_MODEL = os.getenv("AUDIO_REWRITE_MODEL", REWRITE_MODEL).strip() or REWRITE_MODEL
|
||||
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY", "").strip()
|
||||
MINIMAX_TTS_BASE_URL = os.getenv("MINIMAX_TTS_BASE_URL", "https://api.minimax.io").strip().rstrip("/")
|
||||
MINIMAX_TTS_MODEL = os.getenv("MINIMAX_TTS_MODEL", "speech-2.8-turbo").strip() or "speech-2.8-turbo"
|
||||
MINIMAX_TTS_VOICE_ID = os.getenv(
|
||||
"MINIMAX_TTS_VOICE_ID",
|
||||
"Chinese (Mandarin)_Reliable_Executive",
|
||||
).strip() or "Chinese (Mandarin)_Reliable_Executive"
|
||||
|
||||
POE_API_BASE_URL = os.getenv("POE_API_BASE_URL", "https://api.poe.com/v1").strip() or "https://api.poe.com/v1"
|
||||
POE_API_KEY = os.getenv("POE_API_KEY", "").strip()
|
||||
@@ -337,6 +350,21 @@ class TranscriptSegment(BaseModel):
|
||||
zh: str = ""
|
||||
|
||||
|
||||
class AudioScript(BaseModel):
|
||||
status: Literal["idle", "rewriting", "completed", "failed"] = "idle"
|
||||
source_text: str = ""
|
||||
source_zh: str = ""
|
||||
rewritten_text: str = ""
|
||||
product_brief: str = ""
|
||||
rewrite_model: str = ""
|
||||
voice_provider: str = ""
|
||||
voice_model: str = ""
|
||||
voice_id: str = ""
|
||||
voice_url: str = ""
|
||||
error: str = ""
|
||||
created_at: float = 0.0
|
||||
|
||||
|
||||
class Job(BaseModel):
|
||||
id: str
|
||||
url: str
|
||||
@@ -349,6 +377,7 @@ class Job(BaseModel):
|
||||
height: int = 0
|
||||
frames: list[KeyFrame] = Field(default_factory=list)
|
||||
transcript: list[TranscriptSegment] = Field(default_factory=list)
|
||||
audio_script: AudioScript = Field(default_factory=AudioScript)
|
||||
storyboard_images: list[StoryboardImage] = Field(default_factory=list)
|
||||
generated_videos: list[GeneratedVideo] = Field(default_factory=list)
|
||||
error: str = ""
|
||||
@@ -1351,6 +1380,148 @@ def _translate_sync(segments: list[dict]) -> list[str]:
|
||||
return [zh_by_idx.get(i, "") for i in range(len(segments))]
|
||||
|
||||
|
||||
def _transcript_join(segments: list[TranscriptSegment], field: Literal["en", "zh"]) -> str:
|
||||
lines: list[str] = []
|
||||
for s in segments:
|
||||
text = (s.zh if field == "zh" else s.en).strip()
|
||||
if text:
|
||||
lines.append(f"[{s.start:.1f}-{s.end:.1f}s] {text}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _fallback_audio_script(segments: list[TranscriptSegment]) -> str:
|
||||
joined = " ".join((s.zh or s.en).strip() for s in segments if (s.zh or s.en).strip())
|
||||
if not joined:
|
||||
return "日常疲惫不用硬扛。戴上 SKG,让肩颈慢慢放松,跟着呼吸找回轻松状态。"
|
||||
return (
|
||||
"把日常紧绷交给 SKG。贴合身体需要放松的位置,热敷与按摩节奏自然陪伴,"
|
||||
"让每一次短暂休息都更轻松、更有质感。"
|
||||
)
|
||||
|
||||
|
||||
def _rewrite_audio_script_sync(segments: list[TranscriptSegment]) -> tuple[str, str]:
|
||||
fallback = _fallback_audio_script(segments)
|
||||
if not LLM_API_KEY:
|
||||
return fallback, "LLM_API_KEY 未配置,使用本地 SKG 模板"
|
||||
source_text = _transcript_join(segments, "en")
|
||||
source_zh = _transcript_join(segments, "zh")
|
||||
prompt = (
|
||||
"你是 SKG 短视频口播编导。根据参考视频音频转写,抽取它的表达结构、情绪节奏和可复用卖点,"
|
||||
"改写成适合 SKG 按摩/放松产品二创视频的中文口播文案。\n"
|
||||
"要求:\n"
|
||||
"1. 输出 35-90 个中文字,适合 8-18 秒短视频配音。\n"
|
||||
"2. 口语化、干净、高级,能直接给 TTS 朗读。\n"
|
||||
"3. 不承诺治疗、治愈、医学疗效,不夸大。\n"
|
||||
"4. 不复刻原视频品牌/人物/价格/平台话术,只保留表达结构。\n"
|
||||
"5. 如果参考转写信息不足,按产品信息生成通用 SKG 放松口播。\n"
|
||||
'严格返回 JSON:{"rewritten_text":"..."}。\n\n'
|
||||
f"SKG 产品信息:{AUDIO_PRODUCT_BRIEF}\n\n"
|
||||
f"英文转写:\n{source_text or '无'}\n\n"
|
||||
f"中文翻译:\n{source_zh or '无'}"
|
||||
)
|
||||
try:
|
||||
resp = llm().chat.completions.create(
|
||||
model=AUDIO_REWRITE_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": "只输出合法 JSON,不要解释,不要 markdown。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.45,
|
||||
max_tokens=600,
|
||||
)
|
||||
raw = (resp.choices[0].message.content or "").strip()
|
||||
if raw.startswith("```"):
|
||||
import re as _re
|
||||
match = _re.search(r"\{[\s\S]*\}", raw)
|
||||
raw = match.group(0) if match else raw
|
||||
data = json.loads(raw)
|
||||
text = str(data.get("rewritten_text", "")).strip()
|
||||
return (text or fallback), ""
|
||||
except Exception as e:
|
||||
return fallback, f"改写失败,使用本地模板:{e}"
|
||||
|
||||
|
||||
def _minimax_tts_url() -> str:
|
||||
if MINIMAX_TTS_BASE_URL.endswith("/v1/t2a_v2"):
|
||||
return MINIMAX_TTS_BASE_URL
|
||||
return f"{MINIMAX_TTS_BASE_URL}/v1/t2a_v2"
|
||||
|
||||
|
||||
def _minimax_tts_sync(job_id: str, text: str) -> str:
|
||||
if not MINIMAX_API_KEY:
|
||||
raise RuntimeError("MINIMAX_API_KEY 未配置,未生成配音")
|
||||
if not text.strip():
|
||||
raise RuntimeError("改写文案为空,未生成配音")
|
||||
payload = {
|
||||
"model": MINIMAX_TTS_MODEL,
|
||||
"text": text.strip()[:9500],
|
||||
"stream": False,
|
||||
"language_boost": "Chinese",
|
||||
"output_format": "hex",
|
||||
"voice_setting": {
|
||||
"voice_id": MINIMAX_TTS_VOICE_ID,
|
||||
"speed": 1,
|
||||
"vol": 1,
|
||||
"pitch": 0,
|
||||
},
|
||||
"audio_setting": {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
"channel": 1,
|
||||
},
|
||||
}
|
||||
resp = httpx.post(
|
||||
_minimax_tts_url(),
|
||||
headers={"Authorization": f"Bearer {MINIMAX_API_KEY}", "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=90,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
base_resp = data.get("base_resp") or {}
|
||||
if int(base_resp.get("status_code", 0) or 0) != 0:
|
||||
raise RuntimeError(base_resp.get("status_msg") or "MiniMax TTS 返回失败")
|
||||
audio_hex = ((data.get("data") or {}).get("audio") or "").strip()
|
||||
if not audio_hex:
|
||||
raise RuntimeError("MiniMax TTS 未返回 audio hex")
|
||||
try:
|
||||
audio_bytes = bytes.fromhex(audio_hex)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"MiniMax TTS audio hex 无法解析:{e}") from e
|
||||
out = job_dir(job_id) / "audio_script.mp3"
|
||||
out.write_bytes(audio_bytes)
|
||||
return f"/jobs/{job_id}/audio-script.mp3"
|
||||
|
||||
|
||||
def _build_audio_script_sync(job_id: str, segments: list[TranscriptSegment]) -> AudioScript:
|
||||
source_text = _transcript_join(segments, "en")
|
||||
source_zh = _transcript_join(segments, "zh")
|
||||
rewritten, rewrite_error = _rewrite_audio_script_sync(segments)
|
||||
voice_url = ""
|
||||
voice_error = ""
|
||||
try:
|
||||
voice_url = _minimax_tts_sync(job_id, rewritten)
|
||||
except Exception as e:
|
||||
voice_error = str(e)
|
||||
errors = ";".join(x for x in [rewrite_error, voice_error] if x)
|
||||
return AudioScript(
|
||||
status="completed",
|
||||
source_text=source_text,
|
||||
source_zh=source_zh,
|
||||
rewritten_text=rewritten,
|
||||
product_brief=AUDIO_PRODUCT_BRIEF,
|
||||
rewrite_model=AUDIO_REWRITE_MODEL,
|
||||
voice_provider="minimax",
|
||||
voice_model=MINIMAX_TTS_MODEL,
|
||||
voice_id=MINIMAX_TTS_VOICE_ID,
|
||||
voice_url=voice_url,
|
||||
error=errors,
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
|
||||
async def pipeline_transcribe(job_id: str) -> None:
|
||||
job = JOBS[job_id]
|
||||
d = job_dir(job_id)
|
||||
@@ -1371,7 +1542,25 @@ async def pipeline_transcribe(job_id: str) -> None:
|
||||
en="This device looks really sleek and minimal.",
|
||||
zh="这个设备看起来非常时尚和简约。"),
|
||||
]
|
||||
update(
|
||||
job,
|
||||
transcript=mock,
|
||||
audio_script=AudioScript(
|
||||
status="rewriting",
|
||||
source_text=_transcript_join(mock, "en"),
|
||||
source_zh=_transcript_join(mock, "zh"),
|
||||
product_brief=AUDIO_PRODUCT_BRIEF,
|
||||
rewrite_model=AUDIO_REWRITE_MODEL,
|
||||
voice_provider="minimax",
|
||||
voice_model=MINIMAX_TTS_MODEL,
|
||||
voice_id=MINIMAX_TTS_VOICE_ID,
|
||||
),
|
||||
message="ASR mock 完成,生成 SKG 改写文案…",
|
||||
progress=92,
|
||||
)
|
||||
audio_script = await asyncio.to_thread(_build_audio_script_sync, job_id, mock)
|
||||
update(job, transcript=mock, status="transcribed", progress=100,
|
||||
audio_script=audio_script,
|
||||
message="转录完成(MOCK · 未设 LLM_API_KEY)")
|
||||
return
|
||||
|
||||
@@ -1403,11 +1592,35 @@ async def pipeline_transcribe(job_id: str) -> None:
|
||||
)
|
||||
for i, seg in enumerate(en_only)
|
||||
]
|
||||
update(
|
||||
job,
|
||||
transcript=full,
|
||||
audio_script=AudioScript(
|
||||
status="rewriting",
|
||||
source_text=_transcript_join(full, "en"),
|
||||
source_zh=_transcript_join(full, "zh"),
|
||||
product_brief=AUDIO_PRODUCT_BRIEF,
|
||||
rewrite_model=AUDIO_REWRITE_MODEL,
|
||||
voice_provider="minimax",
|
||||
voice_model=MINIMAX_TTS_MODEL,
|
||||
voice_id=MINIMAX_TTS_VOICE_ID,
|
||||
),
|
||||
message="翻译完成,生成 SKG 改写文案与 MiniMax 配音…",
|
||||
progress=94,
|
||||
)
|
||||
audio_script = await asyncio.to_thread(_build_audio_script_sync, job_id, full)
|
||||
update(job, transcript=full, status="transcribed", progress=100,
|
||||
audio_script=audio_script,
|
||||
message=f"转录完成 · {len(full)} 段({ASR_MODEL} + {TRANSLATE_MODEL})")
|
||||
|
||||
except Exception as e:
|
||||
update(job, status="failed", error=str(e), message="转录失败")
|
||||
update(
|
||||
job,
|
||||
status="failed",
|
||||
audio_script=AudioScript(status="failed", error=str(e), created_at=time.time()),
|
||||
error=str(e),
|
||||
message="转录失败",
|
||||
)
|
||||
|
||||
|
||||
def _image_edit_call(
|
||||
@@ -1566,6 +1779,10 @@ def health() -> dict:
|
||||
"asr": ASR_MODEL,
|
||||
"translate": TRANSLATE_MODEL,
|
||||
"rewrite": REWRITE_MODEL,
|
||||
"audio_rewrite": AUDIO_REWRITE_MODEL,
|
||||
"minimax_tts": MINIMAX_TTS_MODEL,
|
||||
"minimax_voice": MINIMAX_TTS_VOICE_ID,
|
||||
"minimax_configured": bool(MINIMAX_API_KEY),
|
||||
"video": VIDEO_MODEL,
|
||||
"video_aliases": VIDEO_MODEL_ALIASES,
|
||||
"video_provider": "poe" if video_uses_poe() else ("ark" if video_uses_ark() else "custom"),
|
||||
@@ -1765,6 +1982,14 @@ def get_video(job_id: str):
|
||||
return FileResponse(p, media_type="video/mp4")
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}/audio-script.mp3")
|
||||
def get_audio_script(job_id: str):
|
||||
p = job_dir(job_id) / "audio_script.mp3"
|
||||
if not p.exists():
|
||||
raise HTTPException(404, "audio script not found")
|
||||
return FileResponse(p, media_type="audio/mpeg")
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
|
||||
def get_frame(job_id: str, idx: int):
|
||||
p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
|
||||
|
||||
Reference in New Issue
Block a user