All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
301 lines
9.0 KiB
Python
301 lines
9.0 KiB
Python
"""
|
|
FastAPI WebSocket server for real-time ASR.
|
|
|
|
Provides WebSocket endpoint for streaming audio and receiving transcriptions.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Optional
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
import uvicorn
|
|
|
|
# Add parent directory to path
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from .models import (
|
|
SessionConfig,
|
|
TranscriptionResult,
|
|
VADEvent,
|
|
StatusMessage,
|
|
ErrorMessage,
|
|
MessageType,
|
|
)
|
|
from .asr_worker import ASRWorker
|
|
from .session_manager import SessionManager
|
|
|
|
|
|
# Module-level singletons. Both remain None until the ``lifespan`` handler
# assigns them during application startup.
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the ASR worker and session manager on startup; stop on shutdown.

    Configuration is read from environment variables:

    - ``VIBEVOICE_MODEL_PATH`` (default ``microsoft/VibeVoice-ASR``)
    - ``VIBEVOICE_DEVICE`` (default ``cuda``)
    - ``VIBEVOICE_ATTN_IMPL`` (default ``flash_attention_2``)
    - ``VIBEVOICE_PRELOAD_MODEL`` (default ``true``; any other value skips
      the eager ``load_model()`` call)
    - ``VIBEVOICE_MAX_SESSIONS`` (default ``10``)

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, after the session manager has been started.

    Raises:
        ValueError: If ``VIBEVOICE_MAX_SESSIONS`` is not a valid integer.
    """
    global asr_worker, session_manager

    print("Starting VibeVoice Realtime ASR Server...")

    # Read all settings up front so a malformed value fails before the
    # (potentially slow) model load instead of after it.
    model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = os.environ.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")
    preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true"
    max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10"))

    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )

    # Pre-loading is optional; when disabled the model is left unloaded here.
    if preload:
        print("Pre-loading ASR model...")
        asr_worker.load_model()

    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=max_sessions,
    )
    await session_manager.start()

    print("Server ready!")

    try:
        yield
    finally:
        # Runs even when an exception (including cancellation) is thrown
        # into the generator at the ``yield``, so the manager is always stopped.
        print("Shutting down...")
        await session_manager.stop()
|
|
|
|
|
|
# Application object; startup and shutdown are delegated to ``lifespan``.
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)

# Expose the "static" directory that sits one level above this module's
# directory under /static, but only when that directory actually exists.
static_dir = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "static",
)
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint with API info."""
|
|
return {
|
|
"service": "VibeVoice Realtime ASR",
|
|
"version": "0.1.0",
|
|
"endpoints": {
|
|
"websocket": "/ws/asr/{session_id}",
|
|
"health": "/health",
|
|
"stats": "/stats",
|
|
"client": "/static/realtime_client.html",
|
|
},
|
|
}
|
|
|
|
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""Health check endpoint."""
|
|
return {
|
|
"status": "healthy",
|
|
"model_loaded": asr_worker.is_loaded if asr_worker else False,
|
|
"active_sessions": len(session_manager._sessions) if session_manager else 0,
|
|
}
|
|
|
|
|
|
@app.get("/stats")
|
|
async def get_stats():
|
|
"""Get server statistics."""
|
|
if session_manager is None:
|
|
raise HTTPException(status_code=503, detail="Server not initialized")
|
|
|
|
return session_manager.get_stats()
|
|
|
|
|
|
@app.websocket("/ws/asr/{session_id}")
|
|
async def websocket_asr(websocket: WebSocket, session_id: str):
|
|
"""
|
|
WebSocket endpoint for real-time ASR.
|
|
|
|
Protocol:
|
|
1. Client connects and optionally sends config message
|
|
2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
|
|
3. Server sends JSON messages with transcription results
|
|
|
|
Message types (server -> client):
|
|
- partial_result: Intermediate transcription
|
|
- final_result: Complete transcription for a segment
|
|
- vad_event: Speech start/end events
|
|
- error: Error messages
|
|
- status: Status updates
|
|
"""
|
|
await websocket.accept()
|
|
|
|
# Send connection confirmation
|
|
await websocket.send_json(
|
|
StatusMessage(
|
|
status="connected",
|
|
message=f"Session {session_id} connected",
|
|
).to_dict()
|
|
)
|
|
|
|
# Result callback
|
|
async def on_result(result: TranscriptionResult):
|
|
try:
|
|
await websocket.send_json(result.to_dict())
|
|
except Exception as e:
|
|
print(f"[{session_id}] Failed to send result: {e}")
|
|
|
|
# VAD event callback
|
|
async def on_vad_event(event: VADEvent):
|
|
try:
|
|
await websocket.send_json(event.to_dict())
|
|
except Exception as e:
|
|
print(f"[{session_id}] Failed to send VAD event: {e}")
|
|
|
|
# Create session
|
|
session = await session_manager.create_session(
|
|
session_id=session_id,
|
|
on_result=on_result,
|
|
on_vad_event=on_vad_event,
|
|
)
|
|
|
|
if session is None:
|
|
await websocket.send_json(
|
|
ErrorMessage(
|
|
error="Maximum sessions reached",
|
|
code="MAX_SESSIONS",
|
|
).to_dict()
|
|
)
|
|
await websocket.close()
|
|
return
|
|
|
|
await websocket.send_json(
|
|
StatusMessage(
|
|
status="ready",
|
|
message="Session ready for audio",
|
|
).to_dict()
|
|
)
|
|
|
|
try:
|
|
while True:
|
|
# Receive message
|
|
message = await websocket.receive()
|
|
|
|
if message["type"] == "websocket.disconnect":
|
|
break
|
|
|
|
# Handle binary audio data
|
|
if "bytes" in message:
|
|
audio_data = message["bytes"]
|
|
try:
|
|
await session.process_audio_chunk(audio_data)
|
|
except Exception as e:
|
|
print(f"[{session_id}] Error processing audio: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
# Handle JSON control messages
|
|
elif "text" in message:
|
|
try:
|
|
data = json.loads(message["text"])
|
|
msg_type = data.get("type")
|
|
|
|
if msg_type == "config":
|
|
# Update session config
|
|
config = SessionConfig.from_dict(data.get("config", {}))
|
|
session.update_config(config)
|
|
await websocket.send_json(
|
|
StatusMessage(
|
|
status="config_updated",
|
|
message="Configuration updated",
|
|
).to_dict()
|
|
)
|
|
|
|
elif msg_type == "stop":
|
|
# Flush and close
|
|
await session.flush()
|
|
await websocket.send_json(
|
|
StatusMessage(
|
|
status="stopped",
|
|
message="Session stopped",
|
|
).to_dict()
|
|
)
|
|
break
|
|
|
|
elif msg_type == "ping":
|
|
await websocket.send_json({"type": "pong", "timestamp": time.time()})
|
|
|
|
except json.JSONDecodeError:
|
|
await websocket.send_json(
|
|
ErrorMessage(
|
|
error="Invalid JSON",
|
|
code="INVALID_JSON",
|
|
).to_dict()
|
|
)
|
|
|
|
except WebSocketDisconnect:
|
|
print(f"[{session_id}] Client disconnected")
|
|
except Exception as e:
|
|
print(f"[{session_id}] Error: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
finally:
|
|
# Clean up session
|
|
await session_manager.close_session(session_id)
|
|
print(f"[{session_id}] Session closed")
|
|
|
|
|
|
def main():
    """Parse command-line options and run the server with uvicorn.

    Settings are passed to ``lifespan`` through ``VIBEVOICE_*`` environment
    variables. Values already present in the environment are used as the
    defaults for the matching options, so they are only replaced when the
    corresponding flag is given explicitly; otherwise the built-in defaults
    apply.
    """
    import argparse

    env = os.environ

    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument(
        "--model-path",
        type=str,
        default=env.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR"),
        help="Path to VibeVoice ASR model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=env.get("VIBEVOICE_DEVICE", "cuda"),
        help="Device (cuda/cpu)",
    )
    # A string default is run through ``type`` by argparse, so a malformed
    # environment value is reported as a normal usage error.
    parser.add_argument(
        "--max-sessions",
        type=int,
        default=env.get("VIBEVOICE_MAX_SESSIONS", "10"),
        help="Max concurrent sessions",
    )
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")

    args = parser.parse_args()

    # Publish the effective settings for the lifespan handler.
    env["VIBEVOICE_MODEL_PATH"] = args.model_path
    env["VIBEVOICE_DEVICE"] = args.device
    env["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions)
    if args.no_preload:
        env["VIBEVOICE_PRELOAD_MODEL"] = "false"
    else:
        # Keep a preload setting that was already exported by the caller.
        env.setdefault("VIBEVOICE_PRELOAD_MODEL", "true")

    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )
|
|
|
|
|
|
# Start the server only when this module is executed as a script.
if __name__ == "__main__":
    main()
|