""" VibeVoice 体验平台 — Liquid Glass 风格 FastAPI 后端 + 纯 HTML 前端 """ import os import sys import json import torch import numpy as np import tempfile import time import soundfile as sf from pathlib import Path from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import FileResponse, HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles import uvicorn SOURCE_DIR = Path(__file__).parent / "source" STATIC_DIR = Path(__file__).parent / "static" sys.path.insert(0, str(SOURCE_DIR)) app = FastAPI() # ========== 全局状态 ========== asr_model_cache = {} tts_model_cache = {} DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" DTYPE = torch.float32 def load_asr(): if asr_model_cache: return asr_model_cache from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor print(f"Loading ASR model to {DEVICE}...") processor = VibeVoiceASRProcessor.from_pretrained("microsoft/VibeVoice-ASR") model = VibeVoiceASRForConditionalGeneration.from_pretrained( "microsoft/VibeVoice-ASR", torch_dtype=DTYPE, attn_implementation="sdpa", trust_remote_code=True ) model = model.to(DEVICE) model.eval() asr_model_cache["model"] = model asr_model_cache["processor"] = processor print("ASR model loaded") return asr_model_cache def load_tts(): if tts_model_cache: return tts_model_cache from vibevoice.modular.modeling_vibevoice_streaming_inference import ( VibeVoiceStreamingForConditionalGenerationInference, ) from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor print(f"Loading TTS model to {DEVICE}...") processor = VibeVoiceStreamingProcessor.from_pretrained("microsoft/VibeVoice-Realtime-0.5B") model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( "microsoft/VibeVoice-Realtime-0.5B", torch_dtype=DTYPE, attn_implementation="sdpa", ) model = model.to(DEVICE) model.eval() tts_model_cache["model"] = model tts_model_cache["processor"] = processor print("TTS model loaded") return tts_model_cache @app.post("/api/asr") async def api_asr(audio: UploadFile = File(...), hotwords: str = Form("")): tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) content = await audio.read() tmp.write(content) tmp.close() try: asr = load_asr() model = asr["model"] processor = asr["processor"] context_info = hotwords.strip() if hotwords.strip() else None inputs = processor( audio=tmp.name, sampling_rate=None, return_tensors="pt", add_generation_prompt=True, context_info=context_info ) inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} start_time = time.time() with torch.no_grad(): output_ids = model.generate( **inputs, max_new_tokens=32768, do_sample=False, pad_token_id=processor.pad_id, eos_token_id=processor.tokenizer.eos_token_id, ) elapsed = time.time() - start_time input_length = inputs['input_ids'].shape[1] generated_ids = output_ids[0, input_length:] text = processor.decode(generated_ids, skip_special_tokens=True) try: segments = processor.post_process_transcription(text) except Exception: segments = [{"text": text}] return JSONResponse({"segments": segments, "raw": text, "time": round(elapsed, 1)}) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) finally: os.unlink(tmp.name) @app.post("/api/tts") async def api_tts(text: str = Form(...)): if not text.strip(): return JSONResponse({"error": "empty text"}, status_code=400) try: tts = load_tts() model = tts["model"] processor = tts["processor"] voices_dir = SOURCE_DIR / "demo" / "voices" / "streaming_model" voice_files = list(voices_dir.rglob("*.pt")) if voices_dir.exists() else [] if not voice_files: return JSONResponse({"error": "no voice presets found"}, status_code=500) prefilled = torch.load(voice_files[0], map_location=DEVICE, weights_only=False) processed = processor.process_input_with_cached_prompt( text=text.strip(), cached_prompt=prefilled, padding=True, return_tensors="pt", return_attention_mask=True, ) inputs = {k: v.to(DEVICE) if hasattr(v, "to") else v for k, v in processed.items()} from vibevoice.modular.streamer import AudioStreamer import copy, threading audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None) errors = [] model.model.noise_scheduler = model.model.noise_scheduler.from_config( model.model.noise_scheduler.config, algorithm_type="sde-dpmsolver++", beta_schedule="squaredcos_cap_v2", ) model.set_ddpm_inference_steps(num_steps=5) stop_event = threading.Event() def run_gen(): try: model.generate( **inputs, max_new_tokens=None, cfg_scale=1.5, tokenizer=processor.tokenizer, generation_config={"do_sample": False, "temperature": 1.0, "top_p": 1.0}, audio_streamer=audio_streamer, stop_check_fn=stop_event.is_set, verbose=False, refresh_negative=True, all_prefilled_outputs=copy.deepcopy(prefilled), ) except Exception as e: errors.append(e) audio_streamer.end() thread = threading.Thread(target=run_gen, daemon=True) thread.start() audio_chunks = [] for chunk in audio_streamer.get_stream(0): if torch.is_tensor(chunk): chunk = chunk.detach().cpu().to(torch.float32).numpy() else: chunk = np.asarray(chunk, dtype=np.float32) if chunk.ndim > 1: chunk = chunk.reshape(-1) audio_chunks.append(chunk) thread.join() if errors: return JSONResponse({"error": str(errors[0])}, status_code=500) audio = np.clip(np.concatenate(audio_chunks), -1.0, 1.0) tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir="/tmp") sf.write(tmp.name, audio, 24000) return FileResponse(tmp.name, media_type="audio/wav", filename="vibevoice_tts.wav") except Exception as e: import traceback traceback.print_exc() return JSONResponse({"error": str(e)}, status_code=500) @app.get("/") def index(): return HTMLResponse(HTML_PAGE) HTML_PAGE = r"""
Microsoft 开源语音 AI — 语音识别 & 语音合成
示例:今天我们来讲民法典中关于不当得利的规定。根据民法典第九百八十五条,得利人没有法律根据取得不当利益的,受损失的人可以请求得利人返还取得的利益。