koide 1fb76254e9
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
Add: VibeVoice ASR セットアップスクリプト一式
2026-02-24 01:21:33 +00:00

301 lines
9.0 KiB
Python

"""
FastAPI WebSocket server for real-time ASR.
Provides WebSocket endpoint for streaming audio and receiving transcriptions.
"""
import os
import sys
import asyncio
import json
import time
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from .models import (
SessionConfig,
TranscriptionResult,
VADEvent,
StatusMessage,
ErrorMessage,
MessageType,
)
from .asr_worker import ASRWorker
from .session_manager import SessionManager
# Process-wide singletons. Both stay None until the lifespan handler assigns
# them at startup; the HTTP and WebSocket handlers below read them as module
# globals.
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup: builds the global ASR worker and session manager from the
    ``VIBEVOICE_*`` environment variables, optionally pre-loads the model,
    and starts the session manager.

    Shutdown: stops the session manager. The stop call sits in a ``finally``
    clause so it still runs when an exception (or cancellation) is thrown
    into this context manager at the ``yield`` point.

    Environment variables read:
        VIBEVOICE_MODEL_PATH    (default "microsoft/VibeVoice-ASR")
        VIBEVOICE_DEVICE        (default "cuda")
        VIBEVOICE_ATTN_IMPL     (default "flash_attention_2")
        VIBEVOICE_PRELOAD_MODEL (default "true"; only "true", case-insensitive,
                                 enables pre-loading)
        VIBEVOICE_MAX_SESSIONS  (default "10"; must parse as an int)
    """
    global asr_worker, session_manager

    # Startup
    print("Starting VibeVoice Realtime ASR Server...")

    # Get model settings from the environment or fall back to defaults
    model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = os.environ.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")

    # Initialize ASR worker
    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )

    # Pre-load model (optional, can be lazy-loaded on first request)
    preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true"
    if preload:
        print("Pre-loading ASR model...")
        asr_worker.load_model()

    # Initialize session manager
    max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10"))
    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=max_sessions,
    )
    await session_manager.start()
    print("Server ready!")

    try:
        yield
    finally:
        # Shutdown: always stop the manager once it has been started, even
        # if the serving phase ended with an error.
        print("Shutting down...")
        await session_manager.stop()
# Create the FastAPI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)

# Mount static files under /static when the directory exists.
# The directory is located relative to the absolute path of this file (the
# same way the sys.path entry at the top of the module is computed), so the
# result does not depend on the current working directory when __file__ is
# relative.
static_dir = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "static",
)
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
@app.get("/")
async def root():
    """Describe the service: its name, version, and the available endpoints."""
    endpoint_map = {
        "websocket": "/ws/asr/{session_id}",
        "health": "/health",
        "stats": "/stats",
        "client": "/static/realtime_client.html",
    }
    info = {
        "service": "VibeVoice Realtime ASR",
        "version": "0.1.0",
        "endpoints": endpoint_map,
    }
    return info
@app.get("/health")
async def health_check():
    """Report liveness, whether the model is loaded, and the session count.

    Both values fall back to False / 0 while the corresponding global has
    not been set yet.
    """
    model_loaded = False
    if asr_worker:
        model_loaded = asr_worker.is_loaded

    active_sessions = 0
    if session_manager:
        active_sessions = len(session_manager._sessions)

    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "active_sessions": active_sessions,
    }
@app.get("/stats")
async def get_stats():
    """Return the session manager's statistics.

    Raises:
        HTTPException: 503 when the session manager has not been created yet.
    """
    manager = session_manager
    if manager is not None:
        return manager.get_stats()
    raise HTTPException(status_code=503, detail="Server not initialized")
async def _handle_control_message(websocket: WebSocket, session, raw_text: str) -> bool:
    """Handle one JSON control message; return True when the session should end.

    Recognised ``type`` values are ``config`` (apply a new session config),
    ``stop`` (flush the session and end it) and ``ping`` (reply with a pong
    carrying the server time). Any other type is ignored. Malformed input is
    answered with an error message and does not end the session.
    """
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError:
        await websocket.send_json(
            ErrorMessage(
                error="Invalid JSON",
                code="INVALID_JSON",
            ).to_dict()
        )
        return False

    # Valid JSON that is not an object (e.g. a list or a bare string) has no
    # "type" field to dispatch on; report it instead of failing on .get().
    if not isinstance(data, dict):
        await websocket.send_json(
            ErrorMessage(
                error="Control message must be a JSON object",
                code="INVALID_MESSAGE",
            ).to_dict()
        )
        return False

    msg_type = data.get("type")
    if msg_type == "config":
        # Update session config
        config = SessionConfig.from_dict(data.get("config", {}))
        session.update_config(config)
        await websocket.send_json(
            StatusMessage(
                status="config_updated",
                message="Configuration updated",
            ).to_dict()
        )
    elif msg_type == "stop":
        # Flush pending audio, confirm, then tell the caller to stop
        await session.flush()
        await websocket.send_json(
            StatusMessage(
                status="stopped",
                message="Session stopped",
            ).to_dict()
        )
        return True
    elif msg_type == "ping":
        await websocket.send_json({"type": "pong", "timestamp": time.time()})
    return False


@app.websocket("/ws/asr/{session_id}")
async def websocket_asr(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for real-time ASR.

    Protocol:
    1. Client connects and optionally sends config message
    2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
    3. Server sends JSON messages with transcription results

    Message types (client -> server, JSON text frames):
    - config: Replace the session configuration
    - stop: Flush the session and end it
    - ping: Server answers with a pong message

    Message types (server -> client):
    - partial_result: Intermediate transcription
    - final_result: Complete transcription for a segment
    - vad_event: Speech start/end events
    - error: Error messages
    - status: Status updates

    The session is always closed through the session manager when the
    handler exits, whether by a stop message, a disconnect, or an error.
    """
    import traceback

    await websocket.accept()

    # Send connection confirmation
    await websocket.send_json(
        StatusMessage(
            status="connected",
            message=f"Session {session_id} connected",
        ).to_dict()
    )

    # Result callback: a failed send is logged, never raised into the session
    async def on_result(result: TranscriptionResult):
        try:
            await websocket.send_json(result.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send result: {e}")

    # VAD event callback: same best-effort delivery as results
    async def on_vad_event(event: VADEvent):
        try:
            await websocket.send_json(event.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send VAD event: {e}")

    # Create session
    session = await session_manager.create_session(
        session_id=session_id,
        on_result=on_result,
        on_vad_event=on_vad_event,
    )
    if session is None:
        await websocket.send_json(
            ErrorMessage(
                error="Maximum sessions reached",
                code="MAX_SESSIONS",
            ).to_dict()
        )
        await websocket.close()
        return

    try:
        # Sent inside the try block so that a client that vanished right
        # after connecting still gets its session closed in `finally`.
        await websocket.send_json(
            StatusMessage(
                status="ready",
                message="Session ready for audio",
            ).to_dict()
        )

        while True:
            # Receive message
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break

            # An ASGI server may include both keys with the unused one set
            # to None, so test the values rather than key presence.
            audio_data = message.get("bytes")
            text_data = message.get("text")

            if audio_data is not None:
                # Binary audio data; a failing chunk is logged and skipped
                try:
                    await session.process_audio_chunk(audio_data)
                except Exception as e:
                    print(f"[{session_id}] Error processing audio: {e}")
                    traceback.print_exc()
            elif text_data is not None:
                # JSON control message
                if await _handle_control_message(websocket, session, text_data):
                    break
    except WebSocketDisconnect:
        print(f"[{session_id}] Client disconnected")
    except Exception as e:
        print(f"[{session_id}] Error: {e}")
        traceback.print_exc()
    finally:
        # Clean up session
        await session_manager.close_session(session_id)
        print(f"[{session_id}] Session closed")
def main():
    """Main entry point: parse CLI options and run the server with uvicorn.

    Settings are handed to the lifespan handler through ``VIBEVOICE_*``
    environment variables. Each option defaults to the value already present
    in the environment (falling back to the built-in default), so running
    without a flag no longer overwrites a variable the user exported.
    """
    import argparse

    env = os.environ
    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument(
        "--model-path",
        type=str,
        default=env.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR"),
        help="Path to VibeVoice ASR model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=env.get("VIBEVOICE_DEVICE", "cuda"),
        help="Device (cuda/cpu)",
    )
    parser.add_argument(
        "--attn-impl",
        type=str,
        default=env.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2"),
        help="Attention implementation passed to the ASR worker",
    )
    # argparse runs `type` on string defaults, so an invalid value in the
    # environment is reported as a normal argument error.
    parser.add_argument(
        "--max-sessions",
        type=int,
        default=env.get("VIBEVOICE_MAX_SESSIONS", "10"),
        help="Max concurrent sessions",
    )
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")
    args = parser.parse_args()

    # Set environment variables for lifespan
    env["VIBEVOICE_MODEL_PATH"] = args.model_path
    env["VIBEVOICE_DEVICE"] = args.device
    env["VIBEVOICE_ATTN_IMPL"] = args.attn_impl
    env["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions)
    if args.no_preload:
        env["VIBEVOICE_PRELOAD_MODEL"] = "false"
    else:
        # Keep an existing setting; only fill in the default when unset.
        env.setdefault("VIBEVOICE_PRELOAD_MODEL", "true")

    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )


if __name__ == "__main__":
    main()