diff --git a/app.py b/app.py new file mode 100644 index 0000000..2d9152b --- /dev/null +++ b/app.py @@ -0,0 +1,779 @@ +""" +VibeVoice 体验平台 — Liquid Glass 风格 +FastAPI 后端 + 纯 HTML 前端 +""" + +import os +import sys +import json +import torch +import numpy as np +import tempfile +import time +import soundfile as sf +from pathlib import Path +from fastapi import FastAPI, UploadFile, File, Form +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +import uvicorn + +SOURCE_DIR = Path(__file__).parent / "source" +STATIC_DIR = Path(__file__).parent / "static" +sys.path.insert(0, str(SOURCE_DIR)) + +app = FastAPI() + +# ========== 全局状态 ========== +asr_model_cache = {} +tts_model_cache = {} + +DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" +DTYPE = torch.float32 + + +def load_asr(): + if asr_model_cache: + return asr_model_cache + + from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration + from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor + + print(f"Loading ASR model to {DEVICE}...") + processor = VibeVoiceASRProcessor.from_pretrained("microsoft/VibeVoice-ASR") + model = VibeVoiceASRForConditionalGeneration.from_pretrained( + "microsoft/VibeVoice-ASR", + torch_dtype=DTYPE, + attn_implementation="sdpa", + trust_remote_code=True + ) + model = model.to(DEVICE) + model.eval() + asr_model_cache["model"] = model + asr_model_cache["processor"] = processor + print("ASR model loaded") + return asr_model_cache + + +def load_tts(): + if tts_model_cache: + return tts_model_cache + + from vibevoice.modular.modeling_vibevoice_streaming_inference import ( + VibeVoiceStreamingForConditionalGenerationInference, + ) + from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor + + print(f"Loading TTS model to {DEVICE}...") + processor = VibeVoiceStreamingProcessor.from_pretrained("microsoft/VibeVoice-Realtime-0.5B") + model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + "microsoft/VibeVoice-Realtime-0.5B", + torch_dtype=DTYPE, + attn_implementation="sdpa", + ) + model = model.to(DEVICE) + model.eval() + tts_model_cache["model"] = model + tts_model_cache["processor"] = processor + print("TTS model loaded") + return tts_model_cache + + +@app.post("/api/asr") +async def api_asr(audio: UploadFile = File(...), hotwords: str = Form("")): + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + content = await audio.read() + tmp.write(content) + tmp.close() + + try: + asr = load_asr() + model = asr["model"] + processor = asr["processor"] + + context_info = hotwords.strip() if hotwords.strip() else None + inputs = processor( + audio=tmp.name, + sampling_rate=None, + return_tensors="pt", + add_generation_prompt=True, + context_info=context_info + ) + inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + + start_time = time.time() + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=32768, + do_sample=False, + pad_token_id=processor.pad_id, + eos_token_id=processor.tokenizer.eos_token_id, + ) + + elapsed = time.time() - start_time + input_length = inputs['input_ids'].shape[1] + generated_ids = output_ids[0, input_length:] + text = processor.decode(generated_ids, skip_special_tokens=True) + + try: + segments = processor.post_process_transcription(text) + except Exception: + segments = [{"text": text}] + + return JSONResponse({"segments": segments, "raw": text, "time": round(elapsed, 1)}) + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + finally: + os.unlink(tmp.name) + + +@app.post("/api/tts") +async def api_tts(text: str = Form(...)): + if not text.strip(): + return JSONResponse({"error": "empty text"}, status_code=400) + + try: + tts = load_tts() + model = tts["model"] + processor = tts["processor"] + + voices_dir = SOURCE_DIR / "demo" / "voices" / "streaming_model" + voice_files = list(voices_dir.rglob("*.pt")) if voices_dir.exists() else [] + if not voice_files: + return JSONResponse({"error": "no voice presets found"}, status_code=500) + + prefilled = torch.load(voice_files[0], map_location=DEVICE, weights_only=False) + processed = processor.process_input_with_cached_prompt( + text=text.strip(), + cached_prompt=prefilled, + padding=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(DEVICE) if hasattr(v, "to") else v for k, v in processed.items()} + + from vibevoice.modular.streamer import AudioStreamer + import copy, threading + + audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None) + errors = [] + + model.model.noise_scheduler = model.model.noise_scheduler.from_config( + model.model.noise_scheduler.config, + algorithm_type="sde-dpmsolver++", + beta_schedule="squaredcos_cap_v2", + ) + model.set_ddpm_inference_steps(num_steps=5) + + stop_event = threading.Event() + + def run_gen(): + try: + model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=1.5, + tokenizer=processor.tokenizer, + generation_config={"do_sample": False, "temperature": 1.0, "top_p": 1.0}, + audio_streamer=audio_streamer, + stop_check_fn=stop_event.is_set, + verbose=False, + refresh_negative=True, + all_prefilled_outputs=copy.deepcopy(prefilled), + ) + except Exception as e: + errors.append(e) + audio_streamer.end() + + thread = threading.Thread(target=run_gen, daemon=True) + thread.start() + + audio_chunks = [] + for chunk in audio_streamer.get_stream(0): + if torch.is_tensor(chunk): + chunk = chunk.detach().cpu().to(torch.float32).numpy() + else: + chunk = np.asarray(chunk, dtype=np.float32) + if chunk.ndim > 1: + chunk = chunk.reshape(-1) + audio_chunks.append(chunk) + + thread.join() + if errors: + return JSONResponse({"error": str(errors[0])}, status_code=500) + + audio = np.clip(np.concatenate(audio_chunks), -1.0, 1.0) + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir="/tmp") + sf.write(tmp.name, audio, 24000) + return FileResponse(tmp.name, media_type="audio/wav", filename="vibevoice_tts.wav") + + except Exception as e: + import traceback + traceback.print_exc() + return JSONResponse({"error": str(e)}, status_code=500) + + +@app.get("/") +def index(): + return HTMLResponse(HTML_PAGE) + + +HTML_PAGE = r""" + +
+ + +Microsoft 开源语音 AI — 语音识别 & 语音合成
+示例:今天我们来讲民法典中关于不当得利的规定。根据民法典第九百八十五条,得利人没有法律根据取得不当利益的,受损失的人可以请求得利人返还取得的利益。
+微软开源语音全家桶,ASR+TTS+实时语音,可用于法考字幕提取
+微软开源 | ASR + TTS + 实时语音 | MIT 许可
+| 技术 | 说明 |
|---|---|
| 连续语音 Tokenizer | 声学 + 语义双 Tokenizer,7.5Hz 超低帧率 |
| 长音频处理 | 单次 60 分钟,无需分段 |
| 说话人分离 | 自动识别 Who + When + What |
| 流式推理 | 边输入文字边生成语音,300ms 首音 |
| 热词支持 | 自定义专业术语提升识别率 |
| 维度 | Whisper | ElevenLabs | VibeVoice |
|---|---|---|---|
| ASR | 有 | 无 | 有(更强) |
| TTS | 无 | 有 | 有 |
| 实时流式 | 无 | 有 | 有 |
| 说话人识别 | 无 | 无 | 内置 |
| 长音频 | 需分段 | N/A | 60分钟单次 |
| 开源 | 是 | 否 | 是(MIT) |
| 费用 | 免费 | 按量付费 | 免费 |
9,553 个法考视频需要提取字幕。VibeVoice-ASR 单次处理 60 分钟 + 自动时间戳 + 说话人识别,配合法律热词("不当得利""善意取得"等)可显著提升识别率。
+ 高优先级 +用 Realtime-0.5B 为题目和解析生成语音朗读,支持边看题边听讲解,提升学习体验。
+ 中优先级 +用 VibeVoice-1.5B 为产品页面生成中英文语音介绍,50+ 语言支持覆盖海外客户。
+ 低优先级 +待补充研究内容...
+待补充...
+| 模型 | 显存需求 | M2 Max 可运行? |
|---|---|---|
| VibeVoice-ASR | ~8GB | 可以(MPS 加速) |
| VibeVoice-1.5B | ~6GB | 可以 |
| VibeVoice-Realtime-0.5B | ~2GB | 可以 |
+ 本机 M2 Max 64GB 完全满足所有模型运行要求 +
+ASR + TTS + 实时语音三合一开源方案,MIT 许可无商用限制。ASR 的 60 分钟长音频 + 说话人识别是真正的差异化优势。本机 M2 Max 可直接运行,不需要 GPU 服务器。对法考字幕提取项目有直接价值。
+