"""
FastAPI WebSocket server for real-time ASR.

Provides a WebSocket endpoint for streaming audio and receiving transcriptions.
"""

import os
import sys
import asyncio
import json
import time
import traceback
from typing import Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn

# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from .models import (
    SessionConfig,
    TranscriptionResult,
    VADEvent,
    StatusMessage,
    ErrorMessage,
    MessageType,
)
from .asr_worker import ASRWorker
from .session_manager import SessionManager

# Global instances; populated by ``lifespan`` at application startup and
# left as None until then.
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    On startup, builds the global ``ASRWorker`` and ``SessionManager`` from
    environment variables (``VIBEVOICE_MODEL_PATH``, ``VIBEVOICE_DEVICE``,
    ``VIBEVOICE_ATTN_IMPL``, ``VIBEVOICE_PRELOAD_MODEL``,
    ``VIBEVOICE_MAX_SESSIONS``), optionally pre-loads the model, and starts
    the session manager. On shutdown, stops the session manager.
    """
    global asr_worker, session_manager

    # Startup
    print("Starting VibeVoice Realtime ASR Server...")

    # Get model path from environment or use default
    model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = os.environ.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")

    # Initialize ASR worker
    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )

    # Pre-load model (optional, can be lazy-loaded on first request)
    preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true"
    if preload:
        print("Pre-loading ASR model...")
        asr_worker.load_model()

    # Initialize session manager
    max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10"))
    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=max_sessions,
    )
    await session_manager.start()

    print("Server ready!")
    try:
        yield
    finally:
        # Shutdown: run even if an exception is thrown into the lifespan
        # context, so the session manager is always stopped.
        print("Shutting down...")
        await session_manager.stop()


# Create FastAPI app
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)

# Mount static files
static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")


@app.get("/")
async def root():
    """Root endpoint with API info."""
    return {
        "service": "VibeVoice Realtime ASR",
        "version": "0.1.0",
        "endpoints": {
            "websocket": "/ws/asr/{session_id}",
            "health": "/health",
            "stats": "/stats",
            "client": "/static/realtime_client.html",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": asr_worker.is_loaded if asr_worker else False,
        # NOTE(review): reads SessionManager's private ``_sessions``; a public
        # accessor on SessionManager would be preferable.
        "active_sessions": len(session_manager._sessions) if session_manager else 0,
    }


@app.get("/stats")
async def get_stats():
    """Get server statistics."""
    if session_manager is None:
        raise HTTPException(status_code=503, detail="Server not initialized")
    return session_manager.get_stats()


async def _handle_control_message(websocket: WebSocket, session, raw_text: str) -> bool:
    """Handle one text (JSON) control message from the client.

    Args:
        websocket: The accepted client connection, used for replies.
        session: The session object returned by
            ``SessionManager.create_session`` (type defined elsewhere).
        raw_text: The raw text payload of the WebSocket frame.

    Returns:
        True if the client sent ``{"type": "stop"}`` and the receive loop
        should end; False otherwise.
    """
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError:
        await websocket.send_json(
            ErrorMessage(
                error="Invalid JSON",
                code="INVALID_JSON",
            ).to_dict()
        )
        return False

    # Valid JSON that is not an object (e.g. "[1]" or "5") has no .get();
    # reject it instead of letting AttributeError tear down the session.
    if not isinstance(data, dict):
        await websocket.send_json(
            ErrorMessage(
                error="Control message must be a JSON object",
                code="INVALID_MESSAGE",
            ).to_dict()
        )
        return False

    msg_type = data.get("type")

    if msg_type == "config":
        # Update session config
        config = SessionConfig.from_dict(data.get("config", {}))
        session.update_config(config)
        await websocket.send_json(
            StatusMessage(
                status="config_updated",
                message="Configuration updated",
            ).to_dict()
        )
    elif msg_type == "stop":
        # Flush pending audio, acknowledge, and tell the caller to stop
        await session.flush()
        await websocket.send_json(
            StatusMessage(
                status="stopped",
                message="Session stopped",
            ).to_dict()
        )
        return True
    elif msg_type == "ping":
        await websocket.send_json({"type": "pong", "timestamp": time.time()})

    # Unknown message types are ignored.
    return False


@app.websocket("/ws/asr/{session_id}")
async def websocket_asr(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for real-time ASR.

    Protocol:
    1. Client connects and optionally sends config message
    2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
    3. Server sends JSON messages with transcription results

    Control messages (client -> server, JSON text frames):
    - {"type": "config", "config": {...}}: Update session configuration
    - {"type": "stop"}: Flush buffered audio and end the session
    - {"type": "ping"}: Server replies with {"type": "pong", "timestamp": ...}

    Message types (server -> client):
    - partial_result: Intermediate transcription
    - final_result: Complete transcription for a segment
    - vad_event: Speech start/end events
    - error: Error messages
    - status: Status updates
    """
    await websocket.accept()

    # Send connection confirmation
    await websocket.send_json(
        StatusMessage(
            status="connected",
            message=f"Session {session_id} connected",
        ).to_dict()
    )

    # Lifespan startup has not run (or failed): mirror /stats and refuse
    # rather than crash on a None session manager.
    if session_manager is None:
        await websocket.send_json(
            ErrorMessage(
                error="Server not initialized",
                code="NOT_INITIALIZED",
            ).to_dict()
        )
        await websocket.close()
        return

    # Result callback
    async def on_result(result: TranscriptionResult):
        try:
            await websocket.send_json(result.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send result: {e}")

    # VAD event callback
    async def on_vad_event(event: VADEvent):
        try:
            await websocket.send_json(event.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send VAD event: {e}")

    # Create session
    session = await session_manager.create_session(
        session_id=session_id,
        on_result=on_result,
        on_vad_event=on_vad_event,
    )

    if session is None:
        await websocket.send_json(
            ErrorMessage(
                error="Maximum sessions reached",
                code="MAX_SESSIONS",
            ).to_dict()
        )
        await websocket.close()
        return

    await websocket.send_json(
        StatusMessage(
            status="ready",
            message="Session ready for audio",
        ).to_dict()
    )

    try:
        while True:
            # Receive raw ASGI message
            message = await websocket.receive()

            if message["type"] == "websocket.disconnect":
                break

            # The ASGI spec lets "bytes" and "text" both be present with the
            # unused one set to None, so test the values, not the keys.
            audio_data = message.get("bytes")
            text_data = message.get("text")

            if audio_data is not None:
                # Handle binary audio data
                try:
                    await session.process_audio_chunk(audio_data)
                except Exception as e:
                    # Best effort: one bad chunk must not end the session.
                    print(f"[{session_id}] Error processing audio: {e}")
                    traceback.print_exc()
            elif text_data is not None:
                # Handle JSON control messages
                if await _handle_control_message(websocket, session, text_data):
                    break

    except WebSocketDisconnect:
        print(f"[{session_id}] Client disconnected")
    except Exception as e:
        print(f"[{session_id}] Error: {e}")
        traceback.print_exc()
    finally:
        # Clean up session
        await session_manager.close_session(session_id)
        print(f"[{session_id}] Session closed")


def main():
    """Main entry point: parse CLI arguments, export them for ``lifespan``, run uvicorn."""
    import argparse

    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--model-path", type=str, default="microsoft/VibeVoice-ASR",
                        help="Path to VibeVoice ASR model")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument(
        "--attn-impl",
        type=str,
        default=None,
        help="Attention implementation (default: $VIBEVOICE_ATTN_IMPL, "
        "else flash_attention_2)",
    )
    parser.add_argument("--max-sessions", type=int, default=10, help="Max concurrent sessions")
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")

    args = parser.parse_args()

    # Set environment variables for lifespan
    os.environ["VIBEVOICE_MODEL_PATH"] = args.model_path
    os.environ["VIBEVOICE_DEVICE"] = args.device
    os.environ["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions)
    os.environ["VIBEVOICE_PRELOAD_MODEL"] = "false" if args.no_preload else "true"
    # Only override when given, so an existing VIBEVOICE_ATTN_IMPL still applies.
    if args.attn_impl is not None:
        os.environ["VIBEVOICE_ATTN_IMPL"] = args.attn_impl

    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )


if __name__ == "__main__":
    main()