fix: add cloud local asr fallback
This commit is contained in:
67
api/main.py
67
api/main.py
@@ -61,8 +61,12 @@ LLM_API_KEY = os.getenv("LLM_API_KEY", "").strip()
|
||||
# --- ASR (speech-to-text) settings, read from the environment at import time ---

# Remote transcription endpoint; defaults to the LLM endpoint when ASR_BASE_URL is unset.
ASR_BASE_URL = os.getenv("ASR_BASE_URL", LLM_BASE_URL).strip()
# Key for the remote endpoint; an unset or empty ASR_API_KEY falls back to LLM_API_KEY.
ASR_API_KEY = (os.getenv("ASR_API_KEY") or LLM_API_KEY).strip()
# Model name passed to the remote transcription request. Stripped, with an
# empty-value fallback, so it behaves like the other model settings below.
ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1").strip() or "whisper-1"
# Remote ASR stays enabled unless the variable is one of 0/false/no/off (case-insensitive).
ASR_REMOTE_ENABLED = os.getenv("ASR_REMOTE_ENABLED", "true").strip().lower() not in {"0", "false", "no", "off"}
ASR_FALLBACK_MODEL = os.getenv("ASR_FALLBACK_MODEL", "gemini-2.5-flash").strip() or "gemini-2.5-flash"
# Timeout for the remote request, never below 15. A variable that is set but
# empty (e.g. `ASR_TIMEOUT_SECONDS=` in an env file) uses the default instead of
# crashing the import with ValueError; a malformed number still fails loudly.
ASR_TIMEOUT_SECONDS = max(15, int(os.getenv("ASR_TIMEOUT_SECONDS", "").strip() or "45"))

# faster-whisper settings consumed by _transcribe_faster_whisper_sync.
FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "tiny.en").strip() or "tiny.en"
FASTER_WHISPER_DEVICE = os.getenv("FASTER_WHISPER_DEVICE", "cpu").strip() or "cpu"
FASTER_WHISPER_COMPUTE_TYPE = os.getenv("FASTER_WHISPER_COMPUTE_TYPE", "int8").strip() or "int8"

# NOTE(review): presumably the executable and model for the other local ASR path;
# confirm against the code that reads them.
LOCAL_ASR_BIN = os.getenv("LOCAL_ASR_BIN", "").strip()
LOCAL_ASR_MODEL = os.getenv("LOCAL_ASR_MODEL", "mlx-community/whisper-tiny").strip() or "mlx-community/whisper-tiny"
# Never below 30; same empty-value handling as ASR_TIMEOUT_SECONDS.
LOCAL_ASR_TIMEOUT_SECONDS = max(30, int(os.getenv("LOCAL_ASR_TIMEOUT_SECONDS", "").strip() or "180"))
|
||||
@@ -2794,6 +2798,32 @@ def _transcribe_mlx_sync(wav: Path) -> list[dict]:
|
||||
return _validate_asr_segments(segments, duration, "mlx_whisper")
|
||||
|
||||
|
||||
def _transcribe_faster_whisper_sync(wav: Path) -> list[dict]:
    """Transcribe ``wav`` with a local faster-whisper model.

    The package is imported lazily inside the function; if that import fails for
    any reason the error is re-raised as ``TranscriptionUnavailable``.

    Args:
        wav: Path of the audio file to transcribe.

    Returns:
        Whatever ``_validate_asr_segments`` returns for the collected
        ``{"start", "end", "text"}`` segments, the media duration and the
        ``faster-whisper:<model>`` source label.

    Raises:
        TranscriptionUnavailable: when ``faster_whisper`` cannot be imported.
    """
    try:
        from faster_whisper import WhisperModel
    except Exception as e:
        raise TranscriptionUnavailable(f"faster-whisper 不可用:{e}") from e

    duration = media_duration(wav)
    whisper = WhisperModel(
        FASTER_WHISPER_MODEL,
        device=FASTER_WHISPER_DEVICE,
        compute_type=FASTER_WHISPER_COMPUTE_TYPE,
    )
    # Decoding is pinned to English with a beam size of 1, VAD filtering on,
    # and no conditioning on previously decoded text.
    decode_options = {
        "language": "en",
        "beam_size": 1,
        "vad_filter": True,
        "condition_on_previous_text": False,
    }
    pieces, _info = whisper.transcribe(str(wav.resolve()), **decode_options)

    segments = []
    for piece in pieces:
        text = str(piece.text or "").strip()
        if not text:
            # Skip segments whose text is empty after stripping.
            continue
        segments.append({"start": float(piece.start), "end": float(piece.end), "text": text})

    return _validate_asr_segments(segments, duration, f"faster-whisper:{FASTER_WHISPER_MODEL}")
|
||||
|
||||
|
||||
def _transcribe_gemini_sync(wav: Path) -> list[dict]:
|
||||
duration = media_duration(wav)
|
||||
audio_b64 = base64.b64encode(wav.read_bytes()).decode("ascii")
|
||||
@@ -2828,22 +2858,29 @@ def _transcribe_sync(wav: Path) -> list[dict]:
|
||||
"""Remote ASR first, local mlx_whisper second. Gemini fallback is guarded against fake timelines."""
|
||||
errors: list[str] = []
|
||||
duration = media_duration(wav)
|
||||
if ASR_REMOTE_ENABLED:
|
||||
try:
|
||||
with wav.open("rb") as f:
|
||||
resp = asr_llm().with_options(timeout=ASR_TIMEOUT_SECONDS).audio.transcriptions.create(
|
||||
file=(wav.name, f, "audio/wav"),
|
||||
model=ASR_MODEL,
|
||||
response_format="verbose_json",
|
||||
timestamp_granularities=["segment"],
|
||||
)
|
||||
raw = resp.model_dump() if hasattr(resp, "model_dump") else resp
|
||||
segments = raw.get("segments") or []
|
||||
# 兜底:网关如果不返回 segments,把全文当一段
|
||||
if not segments and raw.get("text"):
|
||||
segments = [{"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}]
|
||||
return _validate_asr_segments(segments, duration, ASR_MODEL)
|
||||
except Exception as e:
|
||||
errors.append(f"{ASR_MODEL}: {e}")
|
||||
else:
|
||||
errors.append(f"{ASR_MODEL}: remote disabled")
|
||||
try:
|
||||
with wav.open("rb") as f:
|
||||
resp = asr_llm().with_options(timeout=ASR_TIMEOUT_SECONDS).audio.transcriptions.create(
|
||||
file=(wav.name, f, "audio/wav"),
|
||||
model=ASR_MODEL,
|
||||
response_format="verbose_json",
|
||||
timestamp_granularities=["segment"],
|
||||
)
|
||||
raw = resp.model_dump() if hasattr(resp, "model_dump") else resp
|
||||
segments = raw.get("segments") or []
|
||||
# 兜底:网关如果不返回 segments,把全文当一段
|
||||
if not segments and raw.get("text"):
|
||||
segments = [{"start": 0.0, "end": float(raw.get("duration", 0) or 0), "text": raw["text"]}]
|
||||
return _validate_asr_segments(segments, duration, ASR_MODEL)
|
||||
return _transcribe_faster_whisper_sync(wav)
|
||||
except Exception as e:
|
||||
errors.append(f"{ASR_MODEL}: {e}")
|
||||
errors.append(f"faster-whisper: {e}")
|
||||
try:
|
||||
return _transcribe_mlx_sync(wav)
|
||||
except Exception as e:
|
||||
@@ -3956,6 +3993,8 @@ def health() -> dict:
|
||||
"models": {
|
||||
"asr": ASR_MODEL,
|
||||
"asr_base_url": ASR_BASE_URL or LLM_BASE_URL or "openai-default",
|
||||
"asr_remote_enabled": ASR_REMOTE_ENABLED,
|
||||
"faster_whisper": FASTER_WHISPER_MODEL,
|
||||
"local_asr": LOCAL_ASR_MODEL,
|
||||
"asr_fallback": ASR_FALLBACK_MODEL,
|
||||
"translate": TRANSLATE_MODEL,
|
||||
|
||||
@@ -9,3 +9,4 @@ httpx==0.27.2
|
||||
imagehash==4.3.1
|
||||
Pillow>=11.0
|
||||
numpy>=2.0
|
||||
faster-whisper==1.1.1
|
||||
|
||||
Reference in New Issue
Block a user