Add: VibeVoice ASR セットアップスクリプト一式
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s

This commit is contained in:
koide 2026-02-24 01:21:33 +00:00
parent 2d753f114f
commit 1fb76254e9
15 changed files with 4531 additions and 0 deletions

View File

@ -0,0 +1,61 @@
# VibeVoice-ASR for DGX Spark (ARM64, Blackwell GB10, sm_121)
# Based on NVIDIA PyTorch container for CUDA 13.1 compatibility
# NOTE(review): TARGETARCH is declared for buildx multi-arch builds but is
# never referenced below — confirm it is still needed.
ARG TARGETARCH
FROM nvcr.io/nvidia/pytorch:25.11-py3 AS base
LABEL maintainer="VibeVoice-ASR DGX Spark Setup"
LABEL description="VibeVoice-ASR optimized for DGX Spark (ARM64, CUDA 13.1)"
# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# PyTorch CUDA settings for DGX Spark
ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
ENV USE_LIBUV=0
# Set working directory
WORKDIR /workspace
# Install system dependencies including FFmpeg for demo
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
# Install flash-attn if not already present.
# '|| true' deliberately tolerates build failures (e.g. no prebuilt wheel for
# this arch); the ASR code is expected to fall back to another attention impl.
RUN pip install --no-cache-dir flash-attn --no-build-isolation || true
# Clone and install VibeVoice (editable install so patched files take effect)
RUN git clone https://github.com/microsoft/VibeVoice.git /workspace/VibeVoice && \
    cd /workspace/VibeVoice && \
    pip install --no-cache-dir -e .
# Create test script and patched demo with MKV support
# (overwrites the upstream demo with the patched version)
COPY test_vibevoice.py /workspace/test_vibevoice.py
COPY vibevoice_asr_gradio_demo_patched.py /workspace/VibeVoice/demo/vibevoice_asr_gradio_demo.py
# Install real-time ASR dependencies
COPY requirements-realtime.txt /workspace/requirements-realtime.txt
RUN pip install --no-cache-dir -r /workspace/requirements-realtime.txt
# Copy real-time ASR module and startup scripts
COPY realtime/ /workspace/VibeVoice/realtime/
COPY static/ /workspace/VibeVoice/static/
COPY run_all.sh /workspace/VibeVoice/run_all.sh
COPY run_realtime.sh /workspace/VibeVoice/run_realtime.sh
RUN chmod +x /workspace/VibeVoice/run_all.sh /workspace/VibeVoice/run_realtime.sh
# Set default working directory to VibeVoice
WORKDIR /workspace/VibeVoice
# Expose Gradio port (7860) and WebSocket/FastAPI port (8000)
EXPOSE 7860
EXPOSE 8000
# Default command: Launch Gradio demo with MKV support
CMD ["python", "demo/vibevoice_asr_gradio_demo.py", "--model_path", "microsoft/VibeVoice-ASR", "--host", "0.0.0.0"]

View File

@ -0,0 +1,7 @@
"""
VibeVoice Realtime ASR Module
WebSocket-based real-time speech recognition using VibeVoice ASR.
"""
__version__ = "0.1.0"

View File

@ -0,0 +1,358 @@
"""
ASR Worker for real-time transcription.
Wraps the existing VibeVoiceASRInference for async/streaming operation.
"""
import sys
import os
import asyncio
import threading
import time
import queue
from typing import AsyncGenerator, Optional, List, Callable
from dataclasses import dataclass
import numpy as np
import torch
# Add parent directory and demo directory to path for importing existing code
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _parent_dir)
sys.path.insert(0, os.path.join(_parent_dir, "demo"))
from .models import (
TranscriptionResult,
TranscriptionSegment,
MessageType,
SessionConfig,
)
@dataclass
class InferenceRequest:
    """Request for ASR inference.

    NOTE(review): not referenced anywhere in this module's visible code —
    presumably consumed by a queueing layer; confirm before removing.
    """

    audio: np.ndarray             # mono audio samples; assumed float32 in [-1, 1] — TODO confirm
    sample_rate: int              # sample rate of `audio` in Hz
    context_info: Optional[str]   # optional free-form context passed to the model
    request_time: float           # time.time() when the request was created
    segment_start_sec: float      # segment start, seconds since session start
    segment_end_sec: float        # segment end, seconds since session start
class ASRWorker:
    """
    ASR Worker that wraps VibeVoiceASRInference for real-time use.

    Features:
    - Async interface for WebSocket integration
    - Streaming output via TextIteratorStreamer
    - Request queuing for handling concurrent segments
    - Graceful model loading and error handling
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-ASR",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the ASR worker.

        Args:
            model_path: Path to VibeVoice ASR model
            device: Device to run inference on
            dtype: Model data type
            attn_implementation: Attention implementation
        """
        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.attn_implementation = attn_implementation
        self._inference = None
        self._is_loaded = False
        # threading.Lock (not asyncio) because load_model() may be called
        # from a worker thread as well as the event loop.
        self._load_lock = threading.Lock()
        # Inference queue for serializing requests: the underlying model is
        # not re-entrant, so only one segment is transcribed at a time.
        # NOTE(review): creating an asyncio.Semaphore outside a running event
        # loop is only safe on Python >= 3.10 — confirm minimum version.
        self._inference_semaphore = asyncio.Semaphore(1)

    def load_model(self) -> bool:
        """
        Load the ASR model.

        Idempotent and thread-safe; returns immediately if already loaded.

        Returns:
            True if model loaded successfully
        """
        with self._load_lock:
            if self._is_loaded:
                return True
            try:
                # Import here to avoid circular imports and allow lazy loading
                # In Docker, the file is copied as vibevoice_asr_gradio_demo.py
                try:
                    from vibevoice_asr_gradio_demo import VibeVoiceASRInference
                except ImportError:
                    from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference
                print(f"Loading VibeVoice ASR model from {self.model_path}...")
                self._inference = VibeVoiceASRInference(
                    model_path=self.model_path,
                    device=self.device,
                    dtype=self.dtype,
                    attn_implementation=self.attn_implementation,
                )
                self._is_loaded = True
                print("ASR model loaded successfully")
                return True
            except Exception as e:
                # Loading failures are reported to the caller as False rather
                # than raised, so the server can keep running.
                print(f"Failed to load ASR model: {e}")
                import traceback
                traceback.print_exc()
                return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    async def transcribe_segment(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
        on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
    ) -> TranscriptionResult:
        """
        Transcribe an audio segment asynchronously.

        Lazy-loads the model on first use; requests are serialized through an
        internal semaphore so at most one inference runs at a time.

        Args:
            audio: Audio data as float32 array
            sample_rate: Audio sample rate
            context_info: Optional context for transcription
            segment_start_sec: Start time of segment in session
            segment_end_sec: End time of segment in session
            config: Session configuration
            on_partial: Callback for partial results

        Returns:
            Final transcription result (type == ERROR with empty text when
            the model could not be loaded)
        """
        if not self._is_loaded:
            if not self.load_model():
                return TranscriptionResult(
                    type=MessageType.ERROR,
                    text="",
                    is_final=True,
                    latency_ms=0,
                )
        config = config or SessionConfig()
        request_time = time.time()
        # Serialize inference requests
        async with self._inference_semaphore:
            return await self._run_inference(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                request_time=request_time,
                on_partial=on_partial,
            )

    async def _run_inference(
        self,
        audio: np.ndarray,
        sample_rate: int,
        context_info: Optional[str],
        segment_start_sec: float,
        segment_end_sec: float,
        config: SessionConfig,
        request_time: float,
        on_partial: Optional[Callable[[TranscriptionResult], None]],
    ) -> TranscriptionResult:
        """Run the actual inference in a thread pool.

        The blocking model call runs in a background thread; this coroutine
        consumes the streamer for partial results, then joins the thread and
        converts the raw result dict into a TranscriptionResult.
        """
        from transformers import TextIteratorStreamer
        # Create streamer for partial results
        streamer = None
        if config.return_partial_results and on_partial:
            streamer = TextIteratorStreamer(
                self._inference.processor.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
        # Result container for thread (plain dict; assignment is atomic enough
        # since we join the thread before reading it)
        result_container = {"result": None, "error": None}

        def run_inference():
            try:
                # Save audio to temp file (required by current implementation)
                import tempfile
                import soundfile as sf
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    temp_path = f.name
                # Write audio as 16-bit PCM; clip guards against overflow at ±1.0
                audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
                sf.write(temp_path, audio_int16, sample_rate, subtype='PCM_16')
                try:
                    result = self._inference.transcribe(
                        audio_path=temp_path,
                        max_new_tokens=config.max_new_tokens,
                        temperature=config.temperature,
                        context_info=context_info,
                        streamer=streamer,
                    )
                    result_container["result"] = result
                finally:
                    # Clean up temp file
                    # NOTE(review): bare except — should be `except OSError:`
                    try:
                        os.unlink(temp_path)
                    except:
                        pass
            except Exception as e:
                result_container["error"] = str(e)
                import traceback
                traceback.print_exc()

        # Start inference in background thread
        inference_thread = threading.Thread(target=run_inference)
        inference_thread.start()
        # Stream partial results if enabled
        partial_text = ""
        if streamer and on_partial:
            try:
                # NOTE(review): iterating the streamer blocks the event loop
                # between tokens; consider run_in_executor if this stalls
                # other sessions.
                for new_text in streamer:
                    partial_text += new_text
                    partial_result = TranscriptionResult(
                        type=MessageType.PARTIAL_RESULT,
                        text=partial_text,
                        is_final=False,
                        latency_ms=(time.time() - request_time) * 1000,
                    )
                    # Call callback (may be async)
                    if asyncio.iscoroutinefunction(on_partial):
                        await on_partial(partial_result)
                    else:
                        on_partial(partial_result)
            except Exception as e:
                print(f"Error during streaming: {e}")
        # Wait for completion (also blocks the event loop; see note above)
        inference_thread.join()
        latency_ms = (time.time() - request_time) * 1000
        if result_container["error"]:
            return TranscriptionResult(
                type=MessageType.ERROR,
                text=f"Error: {result_container['error']}",
                is_final=True,
                latency_ms=latency_ms,
            )
        result = result_container["result"]
        # Convert segments to our format
        segments = []
        for seg in result.get("segments", []):
            # Adjust timestamps relative to session
            seg_start = seg.get("start_time", 0)
            seg_end = seg.get("end_time", 0)
            # If segment has relative timestamps, adjust to absolute
            if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
                adjusted_start = segment_start_sec + seg_start
                adjusted_end = segment_start_sec + seg_end
            else:
                # Non-numeric timestamps: fall back to the whole-segment span
                adjusted_start = segment_start_sec
                adjusted_end = segment_end_sec
            segments.append(TranscriptionSegment(
                start_time=adjusted_start,
                end_time=adjusted_end,
                speaker_id=seg.get("speaker_id", "SPEAKER_00"),
                text=seg.get("text", ""),
            ))
        return TranscriptionResult(
            type=MessageType.FINAL_RESULT,
            text=result.get("raw_text", ""),
            is_final=True,
            segments=segments,
            latency_ms=latency_ms,
        )

    async def transcribe_stream(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
    ) -> AsyncGenerator[TranscriptionResult, None]:
        """
        Transcribe an audio segment with streaming output.

        Yields partial results followed by final result. Partial results are
        relayed through an internal queue fed by transcribe_segment's
        on_partial callback.

        Args:
            audio: Audio data
            sample_rate: Sample rate
            context_info: Optional context
            segment_start_sec: Segment start time
            segment_end_sec: Segment end time
            config: Session config

        Yields:
            TranscriptionResult objects (partial and final)
        """
        result_queue: asyncio.Queue = asyncio.Queue()

        async def on_partial(result: TranscriptionResult):
            await result_queue.put(result)

        # Start transcription task
        transcribe_task = asyncio.create_task(
            self.transcribe_segment(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                on_partial=on_partial,
            )
        )
        # Yield partial results as they come; poll with a short timeout so we
        # notice task completion promptly
        while not transcribe_task.done():
            try:
                result = await asyncio.wait_for(result_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue
        # Drain any remaining partial results
        while not result_queue.empty():
            yield await result_queue.get()
        # Yield final result (re-raises any exception from the task)
        final_result = await transcribe_task
        yield final_result

View File

@ -0,0 +1,246 @@
"""
Audio buffer management for real-time ASR.
Implements a ring buffer for efficient audio chunk management with overlap support.
"""
import numpy as np
from typing import Optional, Tuple
from dataclasses import dataclass
import threading
@dataclass
class AudioChunkInfo:
    """Information about an extracted audio chunk."""

    audio: np.ndarray   # float32 mono samples (a copy, safe to hold)
    start_sample: int   # absolute sample index since session start
    end_sample: int     # absolute end sample index (exclusive)
    start_sec: float    # start_sample / sample_rate
    end_sec: float      # end_sample / sample_rate
class AudioBuffer:
"""
Ring buffer for managing audio chunks with overlap support.
Features:
- Efficient memory management with fixed-size buffer
- Overlap handling for continuous processing
- Thread-safe operations
- Automatic sample rate tracking
"""
def __init__(
self,
sample_rate: int = 16000,
chunk_duration_sec: float = 3.0,
overlap_sec: float = 0.5,
max_buffer_sec: float = 60.0,
):
"""
Initialize the audio buffer.
Args:
sample_rate: Audio sample rate in Hz
chunk_duration_sec: Duration of each processing chunk
overlap_sec: Overlap between consecutive chunks
max_buffer_sec: Maximum buffer duration (older data will be discarded)
"""
self.sample_rate = sample_rate
self.chunk_size = int(chunk_duration_sec * sample_rate)
self.overlap_size = int(overlap_sec * sample_rate)
self.max_buffer_size = int(max_buffer_sec * sample_rate)
# Main buffer (pre-allocated)
self._buffer = np.zeros(self.max_buffer_size, dtype=np.float32)
self._write_pos = 0 # Next position to write
self._read_pos = 0 # Position of unprocessed data start
self._total_samples_received = 0 # Total samples since session start
self._lock = threading.Lock()
@property
def samples_available(self) -> int:
"""Number of unprocessed samples in buffer."""
with self._lock:
return self._write_pos - self._read_pos
@property
def duration_available_sec(self) -> float:
"""Duration of unprocessed audio in seconds."""
return self.samples_available / self.sample_rate
@property
def total_duration_sec(self) -> float:
"""Total duration of audio received since session start."""
return self._total_samples_received / self.sample_rate
def append(self, audio_chunk: np.ndarray) -> int:
"""
Append audio chunk to the buffer.
Args:
audio_chunk: Audio data as float32 array (range: -1.0 to 1.0)
Returns:
Number of samples actually appended
"""
if audio_chunk.dtype != np.float32:
audio_chunk = audio_chunk.astype(np.float32)
# Ensure 1D
if audio_chunk.ndim > 1:
audio_chunk = audio_chunk.flatten()
with self._lock:
chunk_len = len(audio_chunk)
# Check if we need to shift buffer (running out of space)
if self._write_pos + chunk_len > self.max_buffer_size:
self._compact_buffer()
# Still not enough space? Discard old unprocessed data
if self._write_pos + chunk_len > self.max_buffer_size:
overflow = (self._write_pos + chunk_len) - self.max_buffer_size
self._read_pos = min(self._read_pos + overflow, self._write_pos)
self._compact_buffer()
# Write to buffer
end_pos = self._write_pos + chunk_len
self._buffer[self._write_pos:end_pos] = audio_chunk
self._write_pos = end_pos
self._total_samples_received += chunk_len
return chunk_len
def _compact_buffer(self) -> None:
"""Move unprocessed data to the beginning of the buffer."""
if self._read_pos > 0:
unprocessed_len = self._write_pos - self._read_pos
if unprocessed_len > 0:
self._buffer[:unprocessed_len] = self._buffer[self._read_pos:self._write_pos]
self._write_pos = unprocessed_len
self._read_pos = 0
def get_chunk_for_inference(self, min_duration_sec: float = 0.5) -> Optional[AudioChunkInfo]:
"""
Get the next chunk for ASR inference.
Returns a chunk of audio when enough data is available.
The chunk includes overlap from the previous chunk for context.
Args:
min_duration_sec: Minimum duration required to return a chunk
Returns:
AudioChunkInfo if enough data is available, None otherwise
"""
min_samples = int(min_duration_sec * self.sample_rate)
with self._lock:
available = self._write_pos - self._read_pos
if available < min_samples:
return None
# Calculate chunk boundaries
chunk_start = self._read_pos
chunk_end = min(self._read_pos + self.chunk_size, self._write_pos)
actual_chunk_size = chunk_end - chunk_start
# Extract audio
audio = self._buffer[chunk_start:chunk_end].copy()
# Calculate timestamps based on total samples received
base_sample = self._total_samples_received - (self._write_pos - chunk_start)
start_sec = base_sample / self.sample_rate
end_sec = (base_sample + actual_chunk_size) / self.sample_rate
return AudioChunkInfo(
audio=audio,
start_sample=base_sample,
end_sample=base_sample + actual_chunk_size,
start_sec=start_sec,
end_sec=end_sec,
)
def mark_processed(self, samples: int) -> None:
"""
Mark samples as processed, advancing the read position.
Keeps overlap_size samples for context in the next chunk.
Args:
samples: Number of samples that were processed
"""
with self._lock:
# Advance read position but keep overlap for context
advance = max(0, samples - self.overlap_size)
self._read_pos = min(self._read_pos + advance, self._write_pos)
def get_segment(self, start_sec: float, end_sec: float) -> Optional[np.ndarray]:
"""
Get a specific time segment from the buffer.
Args:
start_sec: Start time in seconds (relative to session start)
end_sec: End time in seconds
Returns:
Audio segment if available, None otherwise
"""
start_sample = int(start_sec * self.sample_rate)
end_sample = int(end_sec * self.sample_rate)
with self._lock:
# Calculate buffer positions
buffer_start_sample = self._total_samples_received - self._write_pos
buffer_end_sample = self._total_samples_received
# Check if segment is in buffer
if start_sample < buffer_start_sample or end_sample > buffer_end_sample:
return None
# Convert to buffer indices
buf_start = start_sample - buffer_start_sample
buf_end = end_sample - buffer_start_sample
return self._buffer[buf_start:buf_end].copy()
def get_all_unprocessed(self) -> Optional[AudioChunkInfo]:
"""
Get all unprocessed audio.
Returns:
AudioChunkInfo with all unprocessed audio, or None if empty
"""
with self._lock:
if self._write_pos <= self._read_pos:
return None
audio = self._buffer[self._read_pos:self._write_pos].copy()
base_sample = self._total_samples_received - (self._write_pos - self._read_pos)
start_sec = base_sample / self.sample_rate
end_sec = self._total_samples_received / self.sample_rate
return AudioChunkInfo(
audio=audio,
start_sample=base_sample,
end_sample=self._total_samples_received,
start_sec=start_sec,
end_sec=end_sec,
)
def clear(self) -> None:
"""Clear the buffer and reset all positions."""
with self._lock:
self._buffer.fill(0)
self._write_pos = 0
self._read_pos = 0
self._total_samples_received = 0
def reset_read_position(self) -> None:
"""Reset read position to current write position (skip all unprocessed)."""
with self._lock:
self._read_pos = self._write_pos

View File

@ -0,0 +1,154 @@
"""
Data models for real-time ASR WebSocket communication.
"""
from enum import Enum
from typing import Optional, List, Dict, Any
from dataclasses import dataclass, field, asdict
import time
class MessageType(str, Enum):
    """WebSocket message types (the str values are the wire format)."""

    # Client -> Server
    AUDIO_CHUNK = "audio_chunk"
    CONFIG = "config"
    START = "start"
    STOP = "stop"
    # Server -> Client
    PARTIAL_RESULT = "partial_result"
    FINAL_RESULT = "final_result"
    VAD_EVENT = "vad_event"
    ERROR = "error"
    STATUS = "status"
class VADEventType(str, Enum):
    """VAD event types emitted when speech activity starts or ends."""

    SPEECH_START = "speech_start"
    SPEECH_END = "speech_end"
@dataclass
class SessionConfig:
    """Configuration for a real-time ASR session."""

    # Audio parameters
    sample_rate: int = 16000
    chunk_duration_sec: float = 3.0
    overlap_sec: float = 0.5
    # VAD parameters
    vad_threshold: float = 0.5
    min_speech_duration_ms: int = 250
    min_silence_duration_ms: int = 500
    min_volume_threshold: float = 0.01  # Minimum RMS volume (0.0-1.0) to consider as potential speech
    # ASR parameters
    max_new_tokens: int = 512
    temperature: float = 0.0
    context_info: Optional[str] = None
    # Behavior
    return_partial_results: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionConfig":
        """Build a config from a dict, silently ignoring unknown keys."""
        known_fields = cls.__dataclass_fields__
        accepted = {name: value for name, value in data.items() if name in known_fields}
        return cls(**accepted)
@dataclass
class TranscriptionSegment:
    """A single transcription segment with speaker and timing metadata."""

    start_time: float
    end_time: float
    speaker_id: str
    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the segment to a JSON-friendly dict."""
        return {
            "start_time": self.start_time,
            "end_time": self.end_time,
            "speaker_id": self.speaker_id,
            "text": self.text,
        }
@dataclass
class TranscriptionResult:
    """Transcription result message sent to the client."""

    type: MessageType
    text: str
    is_final: bool
    segments: List[TranscriptionSegment] = field(default_factory=list)
    latency_ms: float = 0.0
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum collapsed to its value)."""
        payload: Dict[str, Any] = {"type": self.type.value}
        payload["text"] = self.text
        payload["is_final"] = self.is_final
        payload["segments"] = [segment.to_dict() for segment in self.segments]
        payload["latency_ms"] = self.latency_ms
        payload["timestamp"] = self.timestamp
        return payload
@dataclass
class VADEvent:
    """VAD event message (speech start/end notification)."""

    type: MessageType = MessageType.VAD_EVENT
    event: VADEventType = VADEventType.SPEECH_START
    timestamp: float = field(default_factory=time.time)
    audio_timestamp_sec: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums collapsed to values)."""
        body = asdict(self)
        body["type"] = self.type.value
        body["event"] = self.event.value
        return body
@dataclass
class StatusMessage:
    """Status message sent to the client."""

    type: MessageType = MessageType.STATUS
    status: str = ""
    message: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum collapsed to its value)."""
        body = asdict(self)
        body["type"] = self.type.value
        return body
@dataclass
class ErrorMessage:
    """Error message sent to the client."""

    type: MessageType = MessageType.ERROR
    error: str = ""
    code: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum collapsed to its value)."""
        body = asdict(self)
        body["type"] = self.type.value
        return body
@dataclass
class SpeechSegment:
    """Detected speech segment from VAD."""

    start_sample: int        # absolute sample index since session start
    end_sample: int          # absolute end sample index (exclusive)
    start_sec: float         # start time in seconds since session start
    end_sec: float           # end time in seconds since session start
    confidence: float = 1.0  # detector confidence; 1.0 when not provided

View File

@ -0,0 +1,300 @@
"""
FastAPI WebSocket server for real-time ASR.
Provides WebSocket endpoint for streaming audio and receiving transcriptions.
"""
import os
import sys
import asyncio
import json
import time
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from .models import (
SessionConfig,
TranscriptionResult,
VADEvent,
StatusMessage,
ErrorMessage,
MessageType,
)
from .asr_worker import ASRWorker
from .session_manager import SessionManager
# Global instances
# Populated by the lifespan() startup hook; remain None until the app starts.
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    On startup: builds the shared ASRWorker (optionally pre-loading the
    model) and the SessionManager from environment variables.
    On shutdown: stops the session manager.
    """
    global asr_worker, session_manager
    # Startup
    print("Starting VibeVoice Realtime ASR Server...")
    # Get model path from environment or use default
    model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = os.environ.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")
    # Initialize ASR worker (shared by all sessions)
    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )
    # Pre-load model (optional, can be lazy-loaded on first request)
    preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true"
    if preload:
        print("Pre-loading ASR model...")
        asr_worker.load_model()
    # Initialize session manager
    max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10"))
    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=max_sessions,
    )
    await session_manager.start()
    print("Server ready!")
    yield
    # Shutdown
    print("Shutting down...")
    await session_manager.stop()
# Create FastAPI app
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)
# Mount static files (browser client); skipped gracefully if the
# static/ directory does not exist in this deployment
static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
@app.get("/")
async def root():
"""Root endpoint with API info."""
return {
"service": "VibeVoice Realtime ASR",
"version": "0.1.0",
"endpoints": {
"websocket": "/ws/asr/{session_id}",
"health": "/health",
"stats": "/stats",
"client": "/static/realtime_client.html",
},
}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"model_loaded": asr_worker.is_loaded if asr_worker else False,
"active_sessions": len(session_manager._sessions) if session_manager else 0,
}
@app.get("/stats")
async def get_stats():
"""Get server statistics."""
if session_manager is None:
raise HTTPException(status_code=503, detail="Server not initialized")
return session_manager.get_stats()
@app.websocket("/ws/asr/{session_id}")
async def websocket_asr(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for real-time ASR.

    Protocol:
    1. Client connects and optionally sends config message
    2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
    3. Server sends JSON messages with transcription results

    Message types (server -> client):
    - partial_result: Intermediate transcription
    - final_result: Complete transcription for a segment
    - vad_event: Speech start/end events
    - error: Error messages
    - status: Status updates

    NOTE(review): assumes session_manager is initialized (lifespan has run);
    a connection during startup would hit AttributeError on None.
    """
    await websocket.accept()
    # Send connection confirmation
    await websocket.send_json(
        StatusMessage(
            status="connected",
            message=f"Session {session_id} connected",
        ).to_dict()
    )

    # Result callback: forwards transcription results to the client;
    # send failures are logged, not raised, so inference is not aborted
    async def on_result(result: TranscriptionResult):
        try:
            await websocket.send_json(result.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send result: {e}")

    # VAD event callback (same error policy as on_result)
    async def on_vad_event(event: VADEvent):
        try:
            await websocket.send_json(event.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send VAD event: {e}")

    # Create session (None means the manager is at capacity)
    session = await session_manager.create_session(
        session_id=session_id,
        on_result=on_result,
        on_vad_event=on_vad_event,
    )
    if session is None:
        await websocket.send_json(
            ErrorMessage(
                error="Maximum sessions reached",
                code="MAX_SESSIONS",
            ).to_dict()
        )
        await websocket.close()
        return
    await websocket.send_json(
        StatusMessage(
            status="ready",
            message="Session ready for audio",
        ).to_dict()
    )
    try:
        while True:
            # Receive message (raw ASGI event: either bytes or text)
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
            # Handle binary audio data
            if "bytes" in message:
                audio_data = message["bytes"]
                try:
                    await session.process_audio_chunk(audio_data)
                except Exception as e:
                    # Keep the session alive on per-chunk failures
                    print(f"[{session_id}] Error processing audio: {e}")
                    import traceback
                    traceback.print_exc()
            # Handle JSON control messages
            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    msg_type = data.get("type")
                    if msg_type == "config":
                        # Update session config
                        config = SessionConfig.from_dict(data.get("config", {}))
                        session.update_config(config)
                        await websocket.send_json(
                            StatusMessage(
                                status="config_updated",
                                message="Configuration updated",
                            ).to_dict()
                        )
                    elif msg_type == "stop":
                        # Flush remaining audio, confirm, then end the loop
                        await session.flush()
                        await websocket.send_json(
                            StatusMessage(
                                status="stopped",
                                message="Session stopped",
                            ).to_dict()
                        )
                        break
                    elif msg_type == "ping":
                        # Lightweight keepalive
                        await websocket.send_json({"type": "pong", "timestamp": time.time()})
                except json.JSONDecodeError:
                    await websocket.send_json(
                        ErrorMessage(
                            error="Invalid JSON",
                            code="INVALID_JSON",
                        ).to_dict()
                    )
    except WebSocketDisconnect:
        print(f"[{session_id}] Client disconnected")
    except Exception as e:
        print(f"[{session_id}] Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up session regardless of how the loop ended
        await session_manager.close_session(session_id)
        print(f"[{session_id}] Session closed")
def main():
    """Main entry point.

    Parses CLI flags, forwards them to the lifespan hook via environment
    variables (uvicorn re-enters the app, so env vars are the handoff
    mechanism), then runs the server.
    """
    import argparse
    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--model-path", type=str, default="microsoft/VibeVoice-ASR",
                        help="Path to VibeVoice ASR model")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--max-sessions", type=int, default=10, help="Max concurrent sessions")
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")
    args = parser.parse_args()
    # Set environment variables for lifespan
    os.environ["VIBEVOICE_MODEL_PATH"] = args.model_path
    os.environ["VIBEVOICE_DEVICE"] = args.device
    os.environ["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions)
    os.environ["VIBEVOICE_PRELOAD_MODEL"] = "false" if args.no_preload else "true"
    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,401 @@
"""
Session manager for real-time ASR.
Manages multiple concurrent client sessions with resource isolation.
"""
import asyncio
import time
import uuid
from typing import Dict, Optional, Callable, Any
from dataclasses import dataclass, field
import threading
from .models import (
SessionConfig,
TranscriptionResult,
VADEvent,
StatusMessage,
ErrorMessage,
MessageType,
SpeechSegment,
)
from .audio_buffer import AudioBuffer
from .vad_processor import VADProcessor
from .asr_worker import ASRWorker
@dataclass
class SessionStats:
    """Statistics for a session."""

    created_at: float = field(default_factory=time.time)    # session creation time (epoch seconds)
    last_activity: float = field(default_factory=time.time) # last audio chunk received (epoch seconds)
    audio_received_sec: float = 0.0    # total duration of audio received
    chunks_received: int = 0           # number of audio chunks received
    segments_transcribed: int = 0      # number of speech segments sent to ASR
    total_latency_ms: float = 0.0      # cumulative ASR latency across segments
class RealtimeSession:
"""
A single real-time ASR session.
Manages audio buffering, VAD, and ASR for one client connection.
"""
    def __init__(
        self,
        session_id: str,
        asr_worker: ASRWorker,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ):
        """
        Initialize a session.

        Args:
            session_id: Unique session identifier
            asr_worker: Shared ASR worker instance
            config: Session configuration
            on_result: Callback for transcription results
            on_vad_event: Callback for VAD events
        """
        self.session_id = session_id
        self.asr_worker = asr_worker
        self.config = config or SessionConfig()
        self.on_result = on_result
        self.on_vad_event = on_vad_event
        # Components: one buffer and one VAD per session (not shared)
        self.audio_buffer = AudioBuffer(
            sample_rate=self.config.sample_rate,
            chunk_duration_sec=self.config.chunk_duration_sec,
            overlap_sec=self.config.overlap_sec,
        )
        self.vad_processor = VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )
        # State
        self.is_active = True
        self.stats = SessionStats()
        # NOTE(review): _processing_lock and _pending_tasks are initialized
        # here but not used in the methods visible in this file — confirm
        # whether they are still needed.
        self._processing_lock = asyncio.Lock()
        self._pending_tasks: list = []
    async def process_audio_chunk(self, audio_data: bytes) -> None:
        """
        Process an incoming audio chunk.

        Decodes the PCM bytes, appends to the session buffer, runs VAD, then
        transcribes each completed speech segment sequentially (so results
        arrive in order).

        Args:
            audio_data: Raw PCM audio data (16-bit, 16kHz, mono)
        """
        if not self.is_active:
            return
        self.stats.last_activity = time.time()
        self.stats.chunks_received += 1
        # Convert bytes to float32 array (local import: numpy is not a
        # top-level dependency of this module)
        import numpy as np
        audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate
        # Add to buffer
        self.audio_buffer.append(audio_float)
        # Process with VAD
        segments, events = self.vad_processor.process(audio_float)
        # Send VAD events
        if self.on_vad_event:
            for event in events:
                await self._send_callback(self.on_vad_event, event)
        # Process completed speech segments; a failure on one segment must
        # not drop the remaining ones
        for segment in segments:
            try:
                await self._transcribe_segment(segment)
            except Exception as e:
                print(f"[Session {self.session_id}] Transcription error: {e}")
                import traceback
                traceback.print_exc()
    async def _transcribe_segment(self, segment: SpeechSegment) -> None:
        """Transcribe a detected speech segment.

        Fetches the segment's audio from the ring buffer (it may already have
        been overwritten, in which case the segment is skipped), runs the
        shared ASR worker, and forwards partial/final results via on_result.
        """
        # Get audio for segment from buffer
        audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec)
        if audio is None or len(audio) == 0:
            print(f"[Session {self.session_id}] Could not retrieve audio for segment")
            return
        self.stats.segments_transcribed += 1

        async def on_partial(result: TranscriptionResult):
            if self.on_result:
                await self._send_callback(self.on_result, result)

        # Run transcription (partials only wired up when the config asks for them)
        result = await self.asr_worker.transcribe_segment(
            audio=audio,
            sample_rate=self.config.sample_rate,
            context_info=self.config.context_info,
            segment_start_sec=segment.start_sec,
            segment_end_sec=segment.end_sec,
            config=self.config,
            on_partial=on_partial if self.config.return_partial_results else None,
        )
        self.stats.total_latency_ms += result.latency_ms
        # Send final result
        if self.on_result:
            await self._send_callback(self.on_result, result)
async def _send_callback(self, callback: Callable, data: Any) -> None:
    """
    Invoke a user-supplied callback with *data*, supporting both sync and
    async callables.

    The previous implementation dispatched on
    ``asyncio.iscoroutinefunction(callback)``, which misses async callables
    that are not plain coroutine functions — e.g. a ``functools.partial``
    wrapping a coroutine function, or an object with an async ``__call__``.
    Calling first and awaiting the result if it is a coroutine covers all
    of these while behaving identically for plain sync and async functions.

    Errors are logged rather than propagated so a misbehaving callback
    cannot kill the audio-processing loop.
    """
    try:
        outcome = callback(data)
        if asyncio.iscoroutine(outcome):
            await outcome
    except Exception as e:
        print(f"[Session {self.session_id}] Callback error: {e}")
async def flush(self) -> None:
    """
    Drain remaining audio when the session ends.

    First closes out any speech the VAD still considers active, then
    transcribes leftover buffered audio if a meaningful amount (more than
    half a second) was never covered by a VAD segment.
    """
    pending = self.vad_processor.force_end_speech()
    if pending:
        await self._transcribe_segment(pending)
    # Anything still sitting in the buffer that the VAD never flagged.
    leftover = self.audio_buffer.get_all_unprocessed()
    if not leftover:
        return
    if len(leftover.audio) > self.config.sample_rate * 0.5:
        forced = SpeechSegment(
            start_sample=leftover.start_sample,
            end_sample=leftover.end_sample,
            start_sec=leftover.start_sec,
            end_sec=leftover.end_sec,
        )
        await self._transcribe_segment(forced)
def update_config(self, new_config: SessionConfig) -> None:
    """Update session configuration (partial update supported).

    NOTE(review): "partial" updates are detected by comparing each incoming
    field against its default value (0.5, 250, 500, 0.01).  A client that
    explicitly sends the default value to *reset* a previously changed
    setting is silently ignored — confirm whether SessionConfig can carry
    explicit "field was set" information instead of this sentinel scheme.
    """
    # Merge with existing config - only update non-default values
    if new_config.vad_threshold != 0.5:
        self.config.vad_threshold = new_config.vad_threshold
    if new_config.min_speech_duration_ms != 250:
        self.config.min_speech_duration_ms = new_config.min_speech_duration_ms
    if new_config.min_silence_duration_ms != 500:
        self.config.min_silence_duration_ms = new_config.min_silence_duration_ms
    if new_config.min_volume_threshold != 0.01:
        self.config.min_volume_threshold = new_config.min_volume_threshold
    if new_config.context_info is not None:
        self.config.context_info = new_config.context_info
    # Recreate VAD processor with new parameters — this discards the VAD's
    # internal state, so any speech currently in progress is dropped.
    self.vad_processor = VADProcessor(
        sample_rate=self.config.sample_rate,
        threshold=self.config.vad_threshold,
        min_speech_duration_ms=self.config.min_speech_duration_ms,
        min_silence_duration_ms=self.config.min_silence_duration_ms,
        min_volume_threshold=self.config.min_volume_threshold,
    )
    print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, "
          f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, "
          f"min_volume={self.config.min_volume_threshold}")
def close(self) -> None:
    """Mark the session inactive and release its audio/VAD state."""
    # Flip the flag first so process_audio_chunk() rejects further input.
    self.is_active = False
    self.vad_processor.reset()
    self.audio_buffer.clear()
def get_stats(self) -> Dict:
    """Return a snapshot of this session's runtime statistics."""
    transcribed = self.stats.segments_transcribed
    # Average latency is only meaningful once at least one segment finished.
    avg_latency_ms = (
        self.stats.total_latency_ms / transcribed if transcribed > 0 else 0
    )
    snapshot = {
        "session_id": self.session_id,
        "created_at": self.stats.created_at,
        "last_activity": self.stats.last_activity,
        "duration_sec": time.time() - self.stats.created_at,
        "audio_received_sec": self.stats.audio_received_sec,
        "chunks_received": self.stats.chunks_received,
        "segments_transcribed": transcribed,
        "avg_latency_ms": avg_latency_ms,
        "is_active": self.is_active,
        "vad_speech_active": self.vad_processor.is_speech_active,
    }
    return snapshot
class SessionManager:
    """
    Manages multiple concurrent ASR sessions.

    Features:
    - Session creation and cleanup
    - Resource limiting (max concurrent sessions)
    - Idle session timeout
    - Shared ASR worker management
    """

    def __init__(
        self,
        asr_worker: ASRWorker,
        max_concurrent_sessions: int = 10,
        session_timeout_sec: float = 300.0,
    ):
        """
        Initialize the session manager.

        Args:
            asr_worker: Shared ASR worker passed to every session
            max_concurrent_sessions: Maximum number of concurrent sessions
            session_timeout_sec: Timeout after which an idle session is closed
        """
        self.asr_worker = asr_worker
        self.max_sessions = max_concurrent_sessions
        self.session_timeout = session_timeout_sec
        # Active sessions keyed by session id; mutations go through _lock.
        self._sessions: Dict[str, RealtimeSession] = {}
        self._lock = asyncio.Lock()
        # Cleanup task — created in start(), cancelled in stop().
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the session manager (spawns the idle-session cleanup loop)."""
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Stop the session manager and close all sessions.

        NOTE(review): sessions are closed without flush(), so any buffered
        speech is dropped on shutdown — confirm this is intended.
        """
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: we cancelled it ourselves.
                pass
        async with self._lock:
            for session in self._sessions.values():
                session.close()
            self._sessions.clear()

    async def create_session(
        self,
        session_id: Optional[str] = None,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ) -> Optional[RealtimeSession]:
        """
        Create a new session.

        Args:
            session_id: Optional session ID (generated if not provided)
            config: Session configuration
            on_result: Callback for transcription results
            on_vad_event: Callback for VAD events

        Returns:
            Created session, or None if the concurrent-session limit
            is reached. If session_id already exists, the existing
            session is returned instead of creating a new one.
        """
        async with self._lock:
            # Check session limit
            if len(self._sessions) >= self.max_sessions:
                return None
            # Generate session ID if not provided (first 8 hex chars of a UUID4)
            if session_id is None:
                session_id = str(uuid.uuid4())[:8]
            # Check for duplicate — idempotent create returns the existing session
            if session_id in self._sessions:
                return self._sessions[session_id]
            # Create session
            session = RealtimeSession(
                session_id=session_id,
                asr_worker=self.asr_worker,
                config=config,
                on_result=on_result,
                on_vad_event=on_vad_event,
            )
            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[RealtimeSession]:
        """Get a session by ID, or None if it does not exist."""
        async with self._lock:
            return self._sessions.get(session_id)

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        The session is flushed (remaining speech transcribed) before close.
        Note the flush — which can involve slow ASR work — runs while the
        manager lock is held, blocking other session operations meanwhile.

        Args:
            session_id: Session to close

        Returns:
            True if session was found and closed
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.flush()
                session.close()
                return True
            return False

    async def _cleanup_loop(self) -> None:
        """Background task to clean up idle sessions."""
        while True:
            try:
                await asyncio.sleep(60)  # Check every minute
                current_time = time.time()
                sessions_to_close = []
                # Collect candidates under the lock, but close them outside
                # it: close_session() re-acquires the (non-reentrant) lock.
                async with self._lock:
                    for session_id, session in self._sessions.items():
                        idle_time = current_time - session.stats.last_activity
                        if idle_time > self.session_timeout:
                            sessions_to_close.append(session_id)
                for session_id in sessions_to_close:
                    print(f"Closing idle session: {session_id}")
                    await self.close_session(session_id)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors.
                print(f"Cleanup error: {e}")

    def get_stats(self) -> Dict:
        """Get manager statistics.

        Sync method with no awaits, so on a single event loop the session
        dict cannot change while it is being iterated (no lock needed).
        """
        return {
            "active_sessions": len(self._sessions),
            "max_sessions": self.max_sessions,
            "session_timeout_sec": self.session_timeout,
            "sessions": {
                sid: session.get_stats()
                for sid, session in self._sessions.items()
            },
        }

View File

@ -0,0 +1,295 @@
"""
Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version).
Detects speech segments in real-time audio streams.
Uses ONNX runtime to avoid torchaudio dependency issues.
"""
import numpy as np
from typing import List, Optional, Tuple
from dataclasses import dataclass
import threading
import os
import urllib.request
from .models import SpeechSegment, VADEvent, VADEventType, MessageType
@dataclass
class VADState:
    """Internal state of the VAD processor (mutated under VADProcessor._lock)."""
    # True while a speech segment is currently open.
    is_speech_active: bool = False
    # Absolute sample index at which the open speech segment began.
    speech_start_sample: int = 0
    # Absolute sample index at which trailing silence began; 0 doubles as
    # the "no silence pending" sentinel.
    silence_start_sample: int = 0
    # Most recent model speech probability (for monitoring/debugging).
    last_speech_prob: float = 0.0
    # Running count of all samples ever fed to process().
    total_samples_processed: int = 0
class VADProcessor:
    """
    Voice Activity Detection using Silero-VAD (ONNX version).

    Features:
    - Real-time speech detection
    - Configurable thresholds for speech/silence duration
    - Event generation for speech start/end
    - Thread-safe operations
    - No torchaudio dependency (uses ONNX runtime)
    """

    # Silero VAD ONNX model URL.
    # NOTE(review): downloaded over HTTPS but without checksum verification;
    # tracking the "master" branch also means the file can change upstream.
    ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        window_size_samples: int = 512,
        min_volume_threshold: float = 0.01,
    ):
        """
        Initialize the VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 16000 for Silero-VAD)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger speech_start
            min_silence_duration_ms: Minimum silence duration to trigger speech_end
            window_size_samples: VAD window size (512 for 16kHz = 32ms)
            min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech

        Raises:
            ValueError: If sample_rate is not 16000.
            RuntimeError: If the ONNX model cannot be loaded.
        """
        if sample_rate != 16000:
            raise ValueError("Silero-VAD requires 16kHz sample rate")
        self.sample_rate = sample_rate
        self.threshold = threshold
        # Convert millisecond thresholds to sample counts once, up front.
        self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000)
        self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)
        self.window_size = window_size_samples
        self.min_volume_threshold = min_volume_threshold
        # Load ONNX model
        self._session = None
        self._load_model()
        # ONNX model state - single state tensor (size depends on model version)
        # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128)
        self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)
        # State
        self._state = VADState()
        self._lock = threading.Lock()
        # Pending speech segment (being accumulated)
        self._pending_segment_start: Optional[int] = None

    def _get_model_path(self) -> str:
        """Get path to ONNX model, downloading if necessary.

        The model is cached under ~/.cache/silero-vad/silero_vad.onnx and
        only downloaded on first use.
        """
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad")
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, "silero_vad.onnx")
        if not os.path.exists(model_path):
            print(f"Downloading Silero-VAD ONNX model to {model_path}...")
            urllib.request.urlretrieve(self.ONNX_MODEL_URL, model_path)
            print("Download complete.")
        return model_path

    def _load_model(self) -> None:
        """Load Silero-VAD ONNX model (CPU-only inference session).

        Raises:
            RuntimeError: If onnxruntime is missing or the model fails to load.
        """
        try:
            import onnxruntime as ort
            model_path = self._get_model_path()
            # CPU provider is sufficient: the model is tiny and runs per
            # 32 ms window.
            self._session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Silero-VAD ONNX model loaded from {model_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}")

    def _run_inference(self, audio_window: np.ndarray) -> float:
        """Run VAD inference on a single window.

        Returns the model's speech probability for the window and updates
        the recurrent state tensor carried between calls.
        """
        # Prepare input: model expects shape (batch=1, samples).
        audio_input = audio_window.reshape(1, -1).astype(np.float32)
        sr_input = np.array([self.sample_rate], dtype=np.int64)
        # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c'
        outputs = self._session.run(
            ['output', 'stateN'],
            {
                'input': audio_input,
                'sr': sr_input,
                'state': self._state_tensor,
            }
        )
        # Update recurrent state for the next window.
        speech_prob = outputs[0][0][0]
        self._state_tensor = outputs[1]
        return float(speech_prob)

    def reset(self) -> None:
        """Reset VAD state for a new session."""
        with self._lock:
            self._state = VADState()
            self._pending_segment_start = None
            # Reset the model's recurrent state tensor as well.
            self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

    def process(
        self,
        audio_chunk: np.ndarray,
        return_events: bool = True,
    ) -> Tuple[List[SpeechSegment], List[VADEvent]]:
        """
        Process an audio chunk and detect speech segments.

        The chunk is split into fixed-size windows; each window is scored
        by the model (after an RMS gate) and fed into a two-state machine
        (silence <-> speech) with hysteresis via min_speech_samples /
        min_silence_samples. Trailing samples shorter than one window are
        not analyzed in this call, but ARE counted toward
        total_samples_processed.

        Args:
            audio_chunk: Audio data as float32 array
            return_events: Whether to return VAD events

        Returns:
            Tuple of (completed_segments, events)
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)
        completed_segments: List[SpeechSegment] = []
        events: List[VADEvent] = []
        with self._lock:
            # Process in windows; sample indices below are absolute
            # (relative to the start of the stream, not this chunk).
            chunk_start_sample = self._state.total_samples_processed
            num_windows = len(audio_chunk) // self.window_size
            for i in range(num_windows):
                window_start = i * self.window_size
                window_end = window_start + self.window_size
                window = audio_chunk[window_start:window_end]
                # Check volume (RMS) threshold first — cheap gate that
                # skips model inference on near-silent windows.
                rms = np.sqrt(np.mean(window ** 2))
                if rms < self.min_volume_threshold:
                    # Volume too low, treat as silence
                    speech_prob = 0.0
                else:
                    # Get speech probability from VAD model
                    speech_prob = self._run_inference(window)
                self._state.last_speech_prob = speech_prob
                current_sample = chunk_start_sample + window_end
                is_speech = speech_prob >= self.threshold
                # State machine for speech detection
                if is_speech:
                    if not self._state.is_speech_active:
                        # Potential speech start — remember the start of
                        # this window; one non-speech window resets it, so
                        # speech must be continuous for min_speech_samples.
                        if self._pending_segment_start is None:
                            self._pending_segment_start = current_sample - self.window_size
                        # Check if speech duration exceeds minimum
                        speech_duration = current_sample - self._pending_segment_start
                        if speech_duration >= self.min_speech_samples:
                            self._state.is_speech_active = True
                            self._state.speech_start_sample = self._pending_segment_start
                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_START,
                                    audio_timestamp_sec=self._pending_segment_start / self.sample_rate,
                                ))
                    else:
                        # Continue speech, reset silence counter
                        self._state.silence_start_sample = 0
                else:
                    if self._state.is_speech_active:
                        # Potential speech end — record where silence began
                        # (0 is the "no silence pending" sentinel).
                        if self._state.silence_start_sample == 0:
                            self._state.silence_start_sample = current_sample
                        # Check if silence duration exceeds minimum
                        silence_duration = current_sample - self._state.silence_start_sample
                        if silence_duration >= self.min_silence_samples:
                            # Speech ended - create completed segment; the
                            # segment ends where silence began, so trailing
                            # silence is excluded.
                            segment = SpeechSegment(
                                start_sample=self._state.speech_start_sample,
                                end_sample=self._state.silence_start_sample,
                                start_sec=self._state.speech_start_sample / self.sample_rate,
                                end_sec=self._state.silence_start_sample / self.sample_rate,
                            )
                            completed_segments.append(segment)
                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_END,
                                    audio_timestamp_sec=self._state.silence_start_sample / self.sample_rate,
                                ))
                            # Reset state
                            self._state.is_speech_active = False
                            self._state.speech_start_sample = 0
                            self._state.silence_start_sample = 0
                            self._pending_segment_start = None
                    else:
                        # No speech, reset pending
                        self._pending_segment_start = None
            # Update total samples processed (includes any sub-window tail).
            self._state.total_samples_processed += len(audio_chunk)
        return completed_segments, events

    def force_end_speech(self) -> Optional[SpeechSegment]:
        """
        Force end of current speech segment (e.g., when session ends).

        Returns:
            Completed speech segment if speech was active, None otherwise.
            The segment ends at the last processed sample.
        """
        with self._lock:
            if self._state.is_speech_active:
                segment = SpeechSegment(
                    start_sample=self._state.speech_start_sample,
                    end_sample=self._state.total_samples_processed,
                    start_sec=self._state.speech_start_sample / self.sample_rate,
                    end_sec=self._state.total_samples_processed / self.sample_rate,
                )
                self._state.is_speech_active = False
                self._state.speech_start_sample = 0
                self._state.silence_start_sample = 0
                self._pending_segment_start = None
                return segment
            return None

    @property
    def is_speech_active(self) -> bool:
        """Check if speech is currently active."""
        with self._lock:
            return self._state.is_speech_active

    @property
    def last_speech_probability(self) -> float:
        """Get the last computed speech probability."""
        with self._lock:
            return self._state.last_speech_prob

    @property
    def current_speech_duration_sec(self) -> float:
        """Get duration of current speech segment (if active)."""
        with self._lock:
            if not self._state.is_speech_active:
                return 0.0
            return (self._state.total_samples_processed - self._state.speech_start_sample) / self.sample_rate

View File

@ -0,0 +1,7 @@
# Real-time ASR dependencies
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
websockets>=11.0
numpy>=1.24.0
soundfile>=0.12.0
onnxruntime

View File

@ -0,0 +1,73 @@
#!/bin/bash
# Run both Gradio demo and Realtime ASR server
#
# Usage:
#   ./run_all.sh
#
# Ports:
#   - 7860: Gradio UI (batch ASR)
#   - 8000: WebSocket API (realtime ASR)
set -e
cd "$(dirname "$0")"

# Configuration (overridable via environment)
GRADIO_HOST="${GRADIO_HOST:-0.0.0.0}"
GRADIO_PORT="${GRADIO_PORT:-7860}"
REALTIME_HOST="${REALTIME_HOST:-0.0.0.0}"
REALTIME_PORT="${REALTIME_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"

echo "=========================================="
echo "VibeVoice ASR - All Services"
echo "=========================================="
echo ""
echo "Starting services:"
echo " - Gradio UI: http://$GRADIO_HOST:$GRADIO_PORT"
echo " - Realtime ASR: http://$REALTIME_HOST:$REALTIME_PORT"
echo " - Test Client: http://$REALTIME_HOST:$REALTIME_PORT/static/realtime_client.html"
echo ""
echo "Model: $MODEL_PATH"
echo "=========================================="
echo ""

# Trap to clean up background processes on exit.
cleanup() {
    # Clear the traps first: with `trap cleanup EXIT INT TERM`, a SIGINT
    # would otherwise run cleanup once for INT and again for EXIT.
    trap - EXIT INT TERM
    echo ""
    echo "Shutting down..."
    kill $REALTIME_PID 2>/dev/null || true
    kill $GRADIO_PID 2>/dev/null || true
    wait
    echo "All services stopped."
}
trap cleanup EXIT INT TERM

# Start Realtime ASR server in background
echo "[1/2] Starting Realtime ASR server..."
python -m realtime.server \
    --host "$REALTIME_HOST" \
    --port "$REALTIME_PORT" \
    --model-path "$MODEL_PATH" \
    --no-preload &
REALTIME_PID=$!

# Wait a moment for the server to initialize
sleep 2

# Start Gradio demo in background
echo "[2/2] Starting Gradio demo..."
python demo/vibevoice_asr_gradio_demo.py \
    --host "$GRADIO_HOST" \
    --port "$GRADIO_PORT" \
    --model_path "$MODEL_PATH" &
GRADIO_PID=$!

echo ""
echo "Both services started. Press Ctrl+C to stop."
echo ""

# Wait for either background job to exit. Plain `wait -n` (no PID args)
# works on bash >= 4.3; passing PIDs to `wait -n` requires bash >= 5.1
# and is an error on older shells. Our two jobs are the only ones running.
wait -n
# If one exits, the trap will clean up the other

View File

@ -0,0 +1,41 @@
#!/bin/bash
# Run VibeVoice Realtime ASR Server
#
# Usage:
#   ./run_realtime.sh [options]
#
# Options are passed to the server (see --help for details)
set -e
cd "$(dirname "$0")"

# Defaults, each overridable through the environment.
HOST="${VIBEVOICE_HOST:-0.0.0.0}"
PORT="${VIBEVOICE_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"
DEVICE="${VIBEVOICE_DEVICE:-cuda}"
MAX_SESSIONS="${VIBEVOICE_MAX_SESSIONS:-10}"

# Banner showing the effective configuration.
cat <<EOF
==========================================
VibeVoice Realtime ASR Server
==========================================
Host: $HOST
Port: $PORT
Model: $MODEL_PATH
Device: $DEVICE
Max Sessions: $MAX_SESSIONS
==========================================

Web client: http://$HOST:$PORT/static/realtime_client.html
WebSocket: ws://$HOST:$PORT/ws/asr/{session_id}

EOF

# Launch the server; any extra CLI arguments are forwarded verbatim.
python -m realtime.server \
    --host "$HOST" \
    --port "$PORT" \
    --model-path "$MODEL_PATH" \
    --device "$DEVICE" \
    --max-sessions "$MAX_SESSIONS" \
    "$@"

View File

@ -0,0 +1,287 @@
#!/bin/bash
# VibeVoice-ASR Setup Script for DGX Spark
# Downloads and builds the VibeVoice-ASR container
#
# Usage:
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s build
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s serve
set -e

# Where the setup files are fetched from and where they are installed.
BASE_URL="https://docs.techswan.online/scripts/vibevoice-asr"
INSTALL_DIR="${VIBEVOICE_DIR:-$HOME/vibevoice-asr}"
# Docker image/container identifiers used by build_image and run_container.
IMAGE_NAME="vibevoice-asr:dgx-spark"
CONTAINER_NAME="vibevoice-asr"

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Shared logging primitive: colored tag followed by the message.
#   $1 = ANSI color, $2 = tag, $3 = message
_log() {
    echo -e "${1}[${2}]${NC} ${3}"
}
log_info()  { _log "$GREEN"  "INFO"  "$1"; }
log_warn()  { _log "$YELLOW" "WARN"  "$1"; }
log_error() { _log "$RED"    "ERROR" "$1"; }
log_step()  { _log "$BLUE"   "STEP"  "$1"; }
# Fetch a single file, creating parent directories as needed.
#   $1 = source URL, $2 = destination path
download_file() {
    local url="$1"
    local dest="$2"
    local dir=$(dirname "$dest")
    mkdir -p "$dir"
    if command -v curl &> /dev/null; then
        # -f: fail on HTTP errors. Without it, curl happily saves the
        # server's error page (e.g. a 404 HTML body) as the target file,
        # silently corrupting the install.
        curl -fsSL "$url" -o "$dest"
    elif command -v wget &> /dev/null; then
        wget -q "$url" -O "$dest"
    else
        log_error "curl or wget is required"
        exit 1
    fi
}
# Pull every file of the distribution from $BASE_URL into $INSTALL_DIR.
download_files() {
    log_step "Downloading VibeVoice-ASR files to $INSTALL_DIR..."
    mkdir -p "$INSTALL_DIR"
    cd "$INSTALL_DIR"
    # Core files
    local core_files=(
        Dockerfile
        requirements-realtime.txt
        test_vibevoice.py
        vibevoice_asr_gradio_demo_patched.py
        run_realtime.sh
        run_all.sh
    )
    local f
    for f in "${core_files[@]}"; do
        log_info "Downloading $f..."
        download_file "$BASE_URL/$f" "$INSTALL_DIR/$f"
    done
    # Realtime module
    mkdir -p "$INSTALL_DIR/realtime"
    local module_files=(
        __init__.py
        models.py
        server.py
        asr_worker.py
        session_manager.py
        audio_buffer.py
        vad_processor.py
    )
    for f in "${module_files[@]}"; do
        log_info "Downloading realtime/$f..."
        download_file "$BASE_URL/realtime/$f" "$INSTALL_DIR/realtime/$f"
    done
    # Static files
    mkdir -p "$INSTALL_DIR/static"
    log_info "Downloading static/realtime_client.html..."
    download_file "$BASE_URL/static/realtime_client.html" "$INSTALL_DIR/static/realtime_client.html"
    # Make scripts executable
    chmod +x "$INSTALL_DIR/run_realtime.sh" "$INSTALL_DIR/run_all.sh"
    log_info "All files downloaded to $INSTALL_DIR"
}
# Verify Docker / GPU tooling before attempting a build.
check_prerequisites() {
    log_step "Checking prerequisites..."
    # Docker is mandatory.
    command -v docker &> /dev/null || { log_error "Docker is not installed"; exit 1; }
    # NVIDIA container runtime (best-effort detection via `docker info`).
    docker info 2>/dev/null | grep -q "Runtimes.*nvidia" \
        || log_warn "NVIDIA Docker runtime may not be configured"
    # GPU visibility on the host.
    if command -v nvidia-smi &> /dev/null; then
        log_info "GPU detected:"
        nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    else
        log_warn "nvidia-smi not found on host"
    fi
    log_info "Prerequisites check complete"
}
# Build the container image from the downloaded Dockerfile.
build_image() {
    log_step "Building Docker image: ${IMAGE_NAME}"
    log_info "This may take several minutes..."
    cd "$INSTALL_DIR"
    # --network=host lets the build reach package mirrors on restricted hosts.
    docker build --network=host -t "$IMAGE_NAME" -f Dockerfile .
    log_info "Docker image built successfully: ${IMAGE_NAME}"
}
# Launch the built image in one of several modes:
# interactive (default) | test | demo | realtime | serve
run_container() {
    local mode="${1:-interactive}"
    log_step "Running container in ${mode} mode..."
    # Stop existing container if running
    if docker ps -q -f name="$CONTAINER_NAME" | grep -q .; then
        log_warn "Stopping existing container..."
        docker stop "$CONTAINER_NAME" 2>/dev/null || true
    fi
    # Remove existing container
    docker rm "$CONTAINER_NAME" 2>/dev/null || true
    # Common Docker options for DGX Spark
    local docker_opts=(
        --gpus all
        --ipc=host
        --network=host
        --ulimit memlock=-1:-1
        --ulimit stack=-1:-1
        -e "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
        -v "$HOME/.cache/huggingface:/root/.cache/huggingface"
        --name "$CONTAINER_NAME"
    )
    case "$mode" in
        interactive)
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" bash
            ;;
        test)
            docker run --rm "${docker_opts[@]}" "$IMAGE_NAME" python /workspace/test_vibevoice.py
            ;;
        demo)
            log_info "Starting Gradio demo on port 7860..."
            log_info "Access the demo at: http://localhost:7860"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME"
            ;;
        realtime)
            log_info "Starting Realtime ASR server on port 8000..."
            log_info "WebSocket API: ws://localhost:8000/ws/asr/{session_id}"
            log_info "Test client: http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" \
                python -m realtime.server --host 0.0.0.0 --port 8000
            ;;
        serve)
            log_info "Starting all services..."
            log_info " Gradio demo: http://localhost:7860"
            log_info " Realtime ASR: http://localhost:8000"
            log_info " Test client: http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" ./run_all.sh
            ;;
        *)
            log_error "Unknown mode: $mode"
            exit 1
            ;;
    esac
}
# Print CLI usage: the curl-pipe invocations, subcommands, and env overrides.
show_usage() {
    echo "VibeVoice-ASR Setup for DGX Spark"
    echo ""
    echo "Usage:"
    echo " curl -sL $BASE_URL/setup.sh | bash # Download only"
    echo " curl -sL $BASE_URL/setup.sh | bash -s build # Download and build"
    echo " curl -sL $BASE_URL/setup.sh | bash -s demo # Download, build, run demo"
    echo " curl -sL $BASE_URL/setup.sh | bash -s serve # Download, build, run all"
    echo ""
    echo "Commands:"
    echo " (default) Download files only"
    echo " build Download and build Docker image"
    echo " demo Download, build, and start Gradio demo (port 7860)"
    echo " realtime Download, build, and start Realtime ASR (port 8000)"
    echo " serve Download, build, and start both services"
    echo " run Run container interactively (after build)"
    echo ""
    echo "Environment variables:"
    echo " VIBEVOICE_DIR Installation directory (default: ~/vibevoice-asr)"
    echo ""
    echo "After installation, you can also run:"
    echo " cd ~/vibevoice-asr"
    echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
}
# Top-level command dispatcher.
#   $1 - subcommand (download|build|run|test|demo|realtime|serve|help);
#        defaults to "download".
main() {
    local command="${1:-download}"
    echo ""
    echo "=========================================="
    echo " VibeVoice-ASR Setup for DGX Spark"
    echo "=========================================="
    echo ""
    case "$command" in
        download)
            # Fetch files only; the user builds/runs manually afterwards.
            download_files
            echo ""
            log_info "Done! Next steps:"
            echo " cd $INSTALL_DIR"
            echo " docker build -t vibevoice-asr:dgx-spark ."
            echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        build)
            download_files
            check_prerequisites
            build_image
            echo ""
            log_info "Done! To run:"
            echo " cd $INSTALL_DIR"
            echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        run)
            # Requires a prior build; fail fast if the install dir is missing.
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container interactive
            ;;
        test)
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container test
            ;;
        demo)
            download_files
            check_prerequisites
            build_image
            run_container demo
            ;;
        realtime)
            download_files
            check_prerequisites
            build_image
            run_container realtime
            ;;
        serve)
            download_files
            check_prerequisites
            build_image
            run_container serve
            ;;
        -h|--help|help)
            show_usage
            ;;
        *)
            log_error "Unknown command: $command"
            show_usage
            exit 1
            ;;
    esac
}
# Entry point: dispatch on the CLI arguments.
main "$@"

View File

@ -0,0 +1,899 @@
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>VibeVoice Realtime ASR Client</title>
<style>
:root {
--bg-primary: #1a1a2e;
--bg-secondary: #16213e;
--bg-tertiary: #0f3460;
--text-primary: #eaeaea;
--text-secondary: #a0a0a0;
--accent: #e94560;
--success: #4ade80;
--warning: #fbbf24;
--info: #60a5fa;
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: 'Segoe UI', system-ui, sans-serif;
background: var(--bg-primary);
color: var(--text-primary);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
margin-bottom: 30px;
}
h1 {
font-size: 2rem;
margin-bottom: 10px;
}
.subtitle {
color: var(--text-secondary);
}
.main-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
}
@media (max-width: 768px) {
.main-grid {
grid-template-columns: 1fr;
}
}
.card {
background: var(--bg-secondary);
border-radius: 12px;
padding: 20px;
}
.card-title {
font-size: 1.1rem;
margin-bottom: 15px;
display: flex;
align-items: center;
gap: 8px;
}
/* Controls */
.controls {
display: flex;
flex-direction: column;
gap: 15px;
}
.control-row {
display: flex;
gap: 10px;
align-items: center;
}
input[type="text"], input[type="number"] {
background: var(--bg-tertiary);
border: 1px solid rgba(255,255,255,0.1);
border-radius: 8px;
padding: 10px 15px;
color: var(--text-primary);
flex: 1;
}
input:focus {
outline: none;
border-color: var(--accent);
}
button {
background: var(--accent);
border: none;
border-radius: 8px;
padding: 12px 24px;
color: white;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
}
button:hover {
filter: brightness(1.1);
}
button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
button.secondary {
background: var(--bg-tertiary);
}
button.success {
background: var(--success);
color: #000;
}
button.stop {
background: #ef4444;
}
/* Status */
.status-bar {
display: flex;
align-items: center;
gap: 10px;
padding: 10px 15px;
background: var(--bg-tertiary);
border-radius: 8px;
margin-bottom: 15px;
}
.status-indicator {
width: 12px;
height: 12px;
border-radius: 50%;
background: var(--text-secondary);
}
.status-indicator.connected {
background: var(--success);
box-shadow: 0 0 10px var(--success);
}
.status-indicator.recording {
background: var(--accent);
box-shadow: 0 0 10px var(--accent);
animation: pulse 1s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
/* Transcription output */
.transcription-box {
background: var(--bg-tertiary);
border-radius: 8px;
padding: 15px;
min-height: 300px;
max-height: 500px;
overflow-y: auto;
font-family: 'Consolas', monospace;
line-height: 1.6;
}
.segment {
margin-bottom: 15px;
padding: 10px;
background: rgba(255,255,255,0.05);
border-radius: 6px;
border-left: 3px solid var(--accent);
}
.segment.partial {
border-left-color: var(--warning);
opacity: 0.8;
}
.segment-meta {
font-size: 0.8rem;
color: var(--text-secondary);
margin-bottom: 5px;
}
.segment-text {
font-size: 1rem;
}
/* Audio visualizer */
.visualizer-container {
height: 80px;
background: var(--bg-tertiary);
border-radius: 8px;
overflow: hidden;
margin-bottom: 15px;
}
#visualizer {
width: 100%;
height: 100%;
}
/* VAD indicator */
.vad-indicator {
display: flex;
align-items: center;
gap: 10px;
padding: 10px;
background: var(--bg-tertiary);
border-radius: 8px;
margin-bottom: 15px;
}
.vad-bar {
flex: 1;
height: 8px;
background: rgba(255,255,255,0.1);
border-radius: 4px;
overflow: hidden;
}
.vad-level {
height: 100%;
background: var(--success);
width: 0%;
transition: width 0.1s;
}
.vad-level.speech {
background: var(--accent);
}
/* Stats */
.stats {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 10px;
margin-top: 15px;
}
.stat-item {
background: var(--bg-tertiary);
padding: 10px;
border-radius: 8px;
text-align: center;
}
.stat-value {
font-size: 1.5rem;
font-weight: bold;
color: var(--accent);
}
.stat-label {
font-size: 0.75rem;
color: var(--text-secondary);
}
/* Log */
.log-box {
background: var(--bg-tertiary);
border-radius: 8px;
padding: 10px;
height: 150px;
overflow-y: auto;
font-family: 'Consolas', monospace;
font-size: 0.8rem;
}
.log-entry {
margin-bottom: 2px;
color: var(--text-secondary);
}
.log-entry.error {
color: #ef4444;
}
.log-entry.success {
color: var(--success);
}
.log-entry.info {
color: var(--info);
}
/* Settings panel */
.settings-group {
margin-top: 15px;
padding: 15px;
background: var(--bg-tertiary);
border-radius: 8px;
}
.settings-group h4 {
margin-bottom: 10px;
font-size: 0.9rem;
color: var(--text-secondary);
}
.setting-item {
margin-bottom: 12px;
}
.setting-item label {
display: block;
font-size: 0.85rem;
margin-bottom: 5px;
color: var(--text-secondary);
}
.setting-item input[type="range"] {
width: 100%;
height: 6px;
border-radius: 3px;
background: rgba(255,255,255,0.1);
outline: none;
-webkit-appearance: none;
}
.setting-item input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
width: 16px;
height: 16px;
border-radius: 50%;
background: var(--accent);
cursor: pointer;
}
.setting-value {
font-size: 0.8rem;
color: var(--accent);
float: right;
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>🎙️ VibeVoice Realtime ASR</h1>
<p class="subtitle">リアルタイム音声認識デモ</p>
</header>
<div class="status-bar">
<div class="status-indicator" id="statusIndicator"></div>
<span id="statusText">未接続</span>
</div>
<div class="main-grid">
<!-- Left column: Controls -->
<div class="card">
<h2 class="card-title">⚙️ 設定</h2>
<div class="controls">
<div class="control-row">
<input type="text" id="serverUrl" placeholder="WebSocket URL"
value="ws://localhost:8000/ws/asr/demo">
</div>
<div class="control-row">
<button id="connectBtn" onclick="toggleConnection()">接続</button>
<button id="recordBtn" class="success" onclick="toggleRecording()" disabled>
🎤 録音開始
</button>
</div>
</div>
<div class="visualizer-container">
<canvas id="visualizer"></canvas>
</div>
<div class="vad-indicator">
<span>VAD:</span>
<div class="vad-bar">
<div class="vad-level" id="vadLevel"></div>
</div>
<span id="vadStatus">待機中</span>
</div>
<div class="settings-group">
<h4>🎚️ VAD設定</h4>
<div class="setting-item">
<label>
音声検出閾値
<span class="setting-value" id="vadThresholdValue">0.5</span>
</label>
<input type="range" id="vadThreshold" min="0.1" max="0.9" step="0.1" value="0.5"
onchange="updateConfig()">
</div>
<div class="setting-item">
<label>
最小発話時間 (ms)
<span class="setting-value" id="minSpeechValue">250</span>
</label>
<input type="range" id="minSpeechDuration" min="100" max="1000" step="50" value="250"
onchange="updateConfig()">
</div>
<div class="setting-item">
<label>
無音判定時間 (ms)
<span class="setting-value" id="minSilenceValue">500</span>
</label>
<input type="range" id="minSilenceDuration" min="200" max="2000" step="100" value="500"
onchange="updateConfig()">
</div>
<div class="setting-item">
<label>
最小音量閾値
<span class="setting-value" id="minVolumeValue">0.01</span>
</label>
<input type="range" id="minVolumeThreshold" min="0.001" max="0.1" step="0.001" value="0.01"
onchange="updateConfig()">
</div>
</div>
<div class="stats">
<div class="stat-item">
<div class="stat-value" id="statDuration">0.0</div>
<div class="stat-label">録音時間 (秒)</div>
</div>
<div class="stat-item">
<div class="stat-value" id="statSegments">0</div>
<div class="stat-label">セグメント数</div>
</div>
<div class="stat-item">
<div class="stat-value" id="statLatency">-</div>
<div class="stat-label">レイテンシ (ms)</div>
</div>
</div>
<h3 class="card-title" style="margin-top: 20px;">📋 ログ</h3>
<div class="log-box" id="logBox"></div>
</div>
<!-- Right column: Transcription -->
<div class="card">
<h2 class="card-title">📝 認識結果</h2>
<div class="control-row" style="margin-bottom: 15px;">
<button class="secondary" onclick="clearTranscription()">クリア</button>
<button class="secondary" onclick="copyTranscription()">コピー</button>
</div>
<div class="transcription-box" id="transcriptionBox">
<p style="color: var(--text-secondary); text-align: center; padding: 50px;">
接続して録音を開始すると、ここに認識結果が表示されます
</p>
</div>
</div>
</div>
</div>
<script>
// State
let websocket = null;           // active WebSocket connection, or null when disconnected
let mediaStream = null;         // microphone MediaStream while recording
let audioContext = null;        // WebAudio context driving capture + analysis
let processor = null;           // ScriptProcessorNode that streams PCM frames to the server
let analyser = null;            // AnalyserNode feeding the canvas visualizer
let isRecording = false;        // true between startRecording() and stopRecording()
let recordingStartTime = null;  // Date.now() at recording start (ms), for the duration stat
let segmentCount = 0;           // number of finalized recognition segments received
let currentPartialText = '';    // in-progress hypothesis text, '' when none
// DOM elements
const statusIndicator = document.getElementById('statusIndicator');
const statusText = document.getElementById('statusText');
const connectBtn = document.getElementById('connectBtn');
const recordBtn = document.getElementById('recordBtn');
const transcriptionBox = document.getElementById('transcriptionBox');
const logBox = document.getElementById('logBox');
const vadLevel = document.getElementById('vadLevel');
const vadStatus = document.getElementById('vadStatus');
// Logging
// Append a timestamped entry to the on-page log panel and mirror it to the console.
function log(message, type = 'info') {
    const stamp = new Date().toLocaleTimeString();
    const line = document.createElement('div');
    line.className = `log-entry ${type}`;
    line.textContent = `[${stamp}] ${message}`;
    logBox.appendChild(line);
    logBox.scrollTop = logBox.scrollHeight;
    console.log(`[${type}] ${message}`);
}
// Update config display values
// Mirror each VAD slider's current value into its adjacent label element.
function updateConfigDisplay() {
    const bindings = [
        ['vadThreshold', 'vadThresholdValue'],
        ['minSpeechDuration', 'minSpeechValue'],
        ['minSilenceDuration', 'minSilenceValue'],
        ['minVolumeThreshold', 'minVolumeValue'],
    ];
    for (const [sliderId, labelId] of bindings) {
        document.getElementById(labelId).textContent =
            document.getElementById(sliderId).value;
    }
}
// Send config to server
// Refresh the slider labels, then push the current VAD settings to the
// server — but only while the WebSocket is actually open.
function updateConfig() {
    updateConfigDisplay();
    if (!websocket || websocket.readyState !== WebSocket.OPEN) {
        return;
    }
    const read = (id, parse) => parse(document.getElementById(id).value);
    const settings = {
        vad_threshold: read('vadThreshold', parseFloat),
        min_speech_duration_ms: read('minSpeechDuration', parseInt),
        min_silence_duration_ms: read('minSilenceDuration', parseInt),
        min_volume_threshold: read('minVolumeThreshold', parseFloat),
    };
    websocket.send(JSON.stringify({ type: 'config', config: settings }));
    log(`VAD設定を更新: 閾値=${settings.vad_threshold}, 最小発話=${settings.min_speech_duration_ms}ms, 無音判定=${settings.min_silence_duration_ms}ms, 最小音量=${settings.min_volume_threshold}`, 'info');
}
// Initialize config display on load so the labels match the sliders' default values.
document.addEventListener('DOMContentLoaded', updateConfigDisplay);
// Connection
// Single toggle handler for the connect/disconnect button.
function toggleConnection() {
    const isOpen = websocket && websocket.readyState === WebSocket.OPEN;
    if (isOpen) {
        disconnect();
    } else {
        connect();
    }
}
// Open a WebSocket to the ASR server (URL taken from the input field)
// and wire up the lifecycle handlers that drive the connection UI.
function connect() {
    const url = document.getElementById('serverUrl').value;
    log(`接続中: ${url}`, 'info');
    try {
        websocket = new WebSocket(url);
        websocket.onopen = () => {
            log('接続成功', 'success');
            statusIndicator.classList.add('connected');
            statusText.textContent = '接続済み';
            connectBtn.textContent = '切断';
            recordBtn.disabled = false;
        };
        websocket.onclose = () => {
            log('切断されました', 'info');
            handleDisconnect();
        };
        websocket.onerror = (event) => {
            // WebSocket "error" events carry no detail; interpolating the event
            // would only print "[object Event]". Log a readable message and dump
            // the raw event to the console for debugging.
            console.error('WebSocket error:', event);
            log('WebSocketエラーが発生しました(詳細はコンソールを確認)', 'error');
        };
        websocket.onmessage = (event) => {
            handleMessage(JSON.parse(event.data));
        };
    } catch (error) {
        // new WebSocket() throws synchronously on a malformed URL.
        log(`接続エラー: ${error}`, 'error');
    }
}
// Stop any active recording, close the socket, and reset the connection UI.
function disconnect() {
    if (isRecording) {
        stopRecording();
    }
    if (websocket) {
        websocket.close();
    }
    handleDisconnect();
}
// Reset every piece of connection-related UI state after the socket goes away.
function handleDisconnect() {
    websocket = null;
    recordBtn.disabled = true;
    connectBtn.textContent = '接続';
    statusText.textContent = '未接続';
    statusIndicator.classList.remove('connected', 'recording');
}
// Message handling
// Route a decoded server message to its handler based on the message type.
function handleMessage(data) {
    const handlers = {
        status: () => log(`ステータス: ${data.message}`, 'info'),
        partial_result: () => updatePartialResult(data),
        final_result: () => addFinalResult(data),
        vad_event: () => handleVADEvent(data),
        error: () => log(`サーバーエラー: ${data.error}`, 'error'),
        pong: () => {},  // heartbeat response; nothing to do
    };
    if (Object.prototype.hasOwnProperty.call(handlers, data.type)) {
        handlers[data.type]();
    } else {
        log(`不明なメッセージ: ${data.type}`, 'info');
    }
}
// Cache the in-progress hypothesis, refresh the live transcript display,
// and surface the reported latency when the server includes one.
function updatePartialResult(data) {
    currentPartialText = data.text;
    updateTranscriptionDisplay();
    if (!data.latency_ms) {
        return;
    }
    document.getElementById('statLatency').textContent = Math.round(data.latency_ms);
}
// Commit a finalized recognition segment to the transcript panel and update stats.
// Builds the segment with textContent (not innerHTML) so server-provided text and
// speaker metadata are rendered as plain text and cannot inject markup.
function addFinalResult(data) {
    currentPartialText = '';
    segmentCount++;
    document.getElementById('statSegments').textContent = segmentCount;
    if (data.latency_ms) {
        document.getElementById('statLatency').textContent =
            Math.round(data.latency_ms);
    }
    // Time range + speaker label from the first segment, when provided.
    let metaText = '';
    if (data.segments && data.segments.length > 0) {
        const seg = data.segments[0];
        metaText = `[${seg.start_time?.toFixed(2) || '?'}s - ${seg.end_time?.toFixed(2) || '?'}s] ${seg.speaker_id || ''}`;
    }
    const segment = document.createElement('div');
    segment.className = 'segment';
    const meta = document.createElement('div');
    meta.className = 'segment-meta';
    meta.textContent = metaText;
    const text = document.createElement('div');
    text.className = 'segment-text';
    text.textContent = data.text;
    segment.appendChild(meta);
    segment.appendChild(text);
    // Remove placeholder if exists
    const placeholder = transcriptionBox.querySelector('p');
    if (placeholder) placeholder.remove();
    transcriptionBox.appendChild(segment);
    transcriptionBox.scrollTop = transcriptionBox.scrollHeight;
    log(`認識完了: "${data.text.substring(0, 30)}..."`, 'success');
}
// Render (or remove) the live partial-hypothesis segment at the bottom of the
// transcript. The recognized text is assigned via textContent so server output
// is treated as plain text, never parsed as HTML.
function updateTranscriptionDisplay() {
    let partialDiv = transcriptionBox.querySelector('.segment.partial');
    if (currentPartialText) {
        if (!partialDiv) {
            // First partial of this utterance: build the segment skeleton once.
            partialDiv = document.createElement('div');
            partialDiv.className = 'segment partial';
            const meta = document.createElement('div');
            meta.className = 'segment-meta';
            meta.textContent = '認識中...';
            const text = document.createElement('div');
            text.className = 'segment-text';
            partialDiv.appendChild(meta);
            partialDiv.appendChild(text);
            transcriptionBox.appendChild(partialDiv);
        }
        partialDiv.querySelector('.segment-text').textContent = currentPartialText;
        transcriptionBox.scrollTop = transcriptionBox.scrollHeight;
    } else if (partialDiv) {
        // Hypothesis was finalized or cleared: drop the live segment.
        partialDiv.remove();
    }
}
// Reflect VAD speech_start / speech_end events in the level meter and status label.
function handleVADEvent(data) {
    const stamp = data.audio_timestamp_sec?.toFixed(2);
    switch (data.event) {
        case 'speech_start':
            vadLevel.classList.add('speech');
            vadStatus.textContent = '発話中';
            log(`発話開始 @ ${stamp}s`, 'info');
            break;
        case 'speech_end':
            vadLevel.classList.remove('speech');
            vadStatus.textContent = '待機中';
            log(`発話終了 @ ${stamp}s`, 'info');
            break;
    }
}
// Recording
// Single toggle handler for the record button.
async function toggleRecording() {
    if (!isRecording) {
        await startRecording();
        return;
    }
    stopRecording();
}
// Acquire the microphone, build the WebAudio capture graph
// (source -> analyser, source -> processor -> destination), and start
// streaming 16-bit mono PCM frames to the server over the open WebSocket.
async function startRecording() {
    try {
        log('マイクアクセスをリクエスト中...', 'info');
        // Request mono 16 kHz to match what the ASR server consumes.
        // NOTE(review): browsers may ignore the sampleRate constraint here;
        // the AudioContext below also requests 16 kHz — confirm the actual
        // rate if resampling artifacts appear.
        mediaStream = await navigator.mediaDevices.getUserMedia({
            audio: {
                sampleRate: 16000,
                channelCount: 1,
                echoCancellation: true,
                noiseSuppression: true,
            }
        });
        audioContext = new (window.AudioContext || window.webkitAudioContext)({
            sampleRate: 16000
        });
        const source = audioContext.createMediaStreamSource(mediaStream);
        // Analyser for visualization
        analyser = audioContext.createAnalyser();
        analyser.fftSize = 256;
        source.connect(analyser);
        // ScriptProcessor for sending audio (deprecated API, kept for broad
        // browser compatibility; AudioWorklet is the modern replacement).
        const bufferSize = 4096;
        processor = audioContext.createScriptProcessor(bufferSize, 1, 1);
        processor.onaudioprocess = (e) => {
            // Drop frames while not recording or when the socket is closed.
            if (!isRecording || !websocket || websocket.readyState !== WebSocket.OPEN) {
                return;
            }
            const inputData = e.inputBuffer.getChannelData(0);
            // Convert float [-1, 1] samples to 16-bit PCM, clamped to int16 range.
            const pcmData = new Int16Array(inputData.length);
            for (let i = 0; i < inputData.length; i++) {
                pcmData[i] = Math.max(-32768, Math.min(32767, inputData[i] * 32768));
            }
            // Send as binary
            websocket.send(pcmData.buffer);
        };
        source.connect(processor);
        // A ScriptProcessor must be connected to a destination for
        // onaudioprocess to fire in most browsers.
        processor.connect(audioContext.destination);
        isRecording = true;
        recordingStartTime = Date.now();
        statusIndicator.classList.add('recording');
        recordBtn.textContent = '⏹️ 録音停止';
        recordBtn.classList.remove('success');
        recordBtn.classList.add('stop');
        // Start visualization
        visualize();
        updateDuration();
        log('録音開始', 'success');
    } catch (error) {
        // getUserMedia rejects on permission denial or missing device.
        log(`マイクエラー: ${error}`, 'error');
    }
}
// Tear down the audio capture graph, restore the idle UI, and tell the
// server that the audio stream has ended.
function stopRecording() {
    isRecording = false;
    // Release audio resources in reverse order of creation.
    if (processor !== null) {
        processor.disconnect();
        processor = null;
    }
    if (audioContext !== null) {
        audioContext.close();
        audioContext = null;
    }
    if (mediaStream !== null) {
        for (const track of mediaStream.getTracks()) {
            track.stop();
        }
        mediaStream = null;
    }
    // Put the record button / indicator back into their idle state.
    statusIndicator.classList.remove('recording');
    recordBtn.classList.remove('stop');
    recordBtn.classList.add('success');
    recordBtn.textContent = '🎤 録音開始';
    // Notify the server that no more audio is coming.
    if (websocket !== null && websocket.readyState === WebSocket.OPEN) {
        websocket.send(JSON.stringify({ type: 'stop' }));
    }
    log('録音停止', 'info');
}
// Visualization
// Draw the live frequency-bar visualization on the canvas and feed the
// averaged spectrum level into the VAD meter, one frame per animation tick.
function visualize() {
    if (!analyser || !isRecording) return;
    const canvas = document.getElementById('visualizer');
    const ctx = canvas.getContext('2d');
    const width = canvas.width = canvas.offsetWidth;
    const height = canvas.height = canvas.offsetHeight;
    const binCount = analyser.frequencyBinCount;
    const spectrum = new Uint8Array(binCount);
    const barWidth = (width / binCount) * 2.5;
    const render = () => {
        if (!isRecording) return;
        requestAnimationFrame(render);
        analyser.getByteFrequencyData(spectrum);
        // Background wash, then one gradient bar per frequency bin.
        ctx.fillStyle = 'rgb(15, 52, 96)';
        ctx.fillRect(0, 0, width, height);
        let x = 0;
        let total = 0;
        for (let i = 0; i < binCount; i++) {
            total += spectrum[i];
            const barHeight = (spectrum[i] / 255) * height;
            const gradient = ctx.createLinearGradient(0, height, 0, height - barHeight);
            gradient.addColorStop(0, '#e94560');
            gradient.addColorStop(1, '#4ade80');
            ctx.fillStyle = gradient;
            ctx.fillRect(x, height - barHeight, barWidth, barHeight);
            x += barWidth + 1;
        }
        // Average bin energy (0-100%) drives the VAD level indicator.
        vadLevel.style.width = `${(total / binCount / 255) * 100}%`;
    };
    render();
}
// Tick the elapsed-recording-time stat once per animation frame while recording.
function updateDuration() {
    if (!isRecording) return;
    const elapsedSec = (Date.now() - recordingStartTime) / 1000;
    document.getElementById('statDuration').textContent = elapsedSec.toFixed(1);
    requestAnimationFrame(updateDuration);
}
// Utility functions
// Wipe the transcript panel back to its placeholder and reset the
// segment / latency statistics.
function clearTranscription() {
    currentPartialText = '';
    segmentCount = 0;
    transcriptionBox.innerHTML = `
        <p style="color: var(--text-secondary); text-align: center; padding: 50px;">
            接続して録音を開始すると、ここに認識結果が表示されます
        </p>
    `;
    document.getElementById('statSegments').textContent = '0';
    document.getElementById('statLatency').textContent = '-';
    log('認識結果をクリアしました', 'info');
}
// Copy all finalized segments (excluding the live partial one) to the
// clipboard as newline-separated text.
function copyTranscription() {
    const segments = transcriptionBox.querySelectorAll('.segment:not(.partial) .segment-text');
    const text = Array.from(segments).map(s => s.textContent).join('\n');
    if (!text) return;
    navigator.clipboard.writeText(text).then(() => {
        log('認識結果をコピーしました', 'success');
    }).catch((error) => {
        // Clipboard access rejects outside a secure context or without a
        // user gesture; surface the failure instead of swallowing it.
        log(`コピーに失敗しました: ${error}`, 'error');
    });
}
// Heartbeat: ping the server every 30 s so idle connections (and any
// intermediate proxies) stay alive; the server replies with "pong",
// which handleMessage() ignores.
setInterval(() => {
    if (websocket && websocket.readyState === WebSocket.OPEN) {
        websocket.send(JSON.stringify({ type: 'ping' }));
    }
}, 30000);
</script>
</body>
</html>

View File

@ -0,0 +1,173 @@
#!/usr/bin/env python3
"""
VibeVoice-ASR Test Script for DGX Spark
Tests basic functionality and GPU availability
"""
import sys
import subprocess
def test_imports():
    """Check that the vibevoice package is importable.

    Returns:
        bool: True when the import succeeds, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Testing VibeVoice imports...")
    print(banner)
    try:
        import vibevoice  # noqa: F401 -- imported only to probe availability
    except ImportError as e:
        print(f"[FAIL] Failed to import vibevoice: {e}")
        return False
    print("[OK] vibevoice imported successfully")
    return True
def test_torch_cuda():
    """Report PyTorch/CUDA status and smoke-test a small GPU tensor op.

    Returns:
        bool: True when CUDA is available and the matmul completes,
        False when CUDA is unavailable or anything raises.
    """
    print("\n" + "=" * 60)
    print("Testing PyTorch CUDA...")
    print("=" * 60)
    try:
        import torch
        print(f"[INFO] PyTorch version: {torch.__version__}")
        print(f"[INFO] CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"[INFO] CUDA version: {torch.version.cuda}")
            print(f"[INFO] GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                print(f"[INFO] GPU {i}: {props.name}")
                print(f"       Compute capability: {props.major}.{props.minor}")
                print(f"       Total memory: {props.total_memory / 1024**3:.1f} GB")
            # Quick CUDA test. CUDA kernels launch asynchronously, so force the
            # result back to the host: without a sync, a failed kernel launch
            # would go unnoticed and this check would pass spuriously.
            x = torch.randn(100, 100, device='cuda')
            y = torch.matmul(x, x)
            _ = y.sum().item()  # .item() blocks until the device work finishes
            print("[OK] CUDA tensor operations working")
            return True
        else:
            print("[WARN] CUDA not available")
            return False
    except Exception as e:
        print(f"[FAIL] PyTorch CUDA test failed: {e}")
        return False
def test_flash_attention():
    """Report whether the optional flash_attn package is installed.

    Returns:
        bool: Always True -- flash attention is optional, so a missing
        package is only a warning, never a failure.
    """
    print("\n" + "=" * 60)
    print("Testing Flash Attention...")
    print("=" * 60)
    try:
        import flash_attn
    except ImportError:
        print("[WARN] flash_attn not installed (optional)")
    else:
        print(f"[OK] flash_attn version: {flash_attn.__version__}")
    return True
def test_ffmpeg():
    """Check that the ffmpeg binary is on PATH and runs.

    Returns:
        bool: True when ``ffmpeg -version`` exits with status 0,
        False when the binary is missing, hangs, or fails.
    """
    print("\n" + "=" * 60)
    print("Testing FFmpeg...")
    print("=" * 60)
    try:
        result = subprocess.run(
            ["ffmpeg", "-version"],
            capture_output=True,
            text=True,
            timeout=30,  # guard against a hung binary blocking the whole suite
        )
    except FileNotFoundError:
        print("[FAIL] FFmpeg not found")
        return False
    except subprocess.TimeoutExpired:
        print("[FAIL] FFmpeg timed out")
        return False
    if result.returncode == 0:
        version_line = result.stdout.split('\n')[0]
        print(f"[OK] {version_line}")
        return True
    print("[FAIL] FFmpeg returned error")
    return False
def test_asr_model():
    """Attempt to load the VibeVoice ASR pipeline on GPU.

    Soft check: returns True when skipped (no GPU), when ASRPipeline is not
    exposed by the installed VibeVoice version, and when loading fails --
    those cases are only reported as warnings.
    """
    print("\n" + "=" * 60)
    print("Testing ASR Model Loading...")
    print("=" * 60)
    try:
        import torch
        if not torch.cuda.is_available():
            print("[SKIP] Skipping model test - no GPU available")
            return True
        # Instantiate the pipeline, then release it immediately.
        from vibevoice import ASRPipeline
        print("[INFO] Loading ASR pipeline...")
        pipeline = ASRPipeline()
        print("[OK] ASR pipeline loaded successfully")
        del pipeline
        torch.cuda.empty_cache()
        return True
    except ImportError as e:
        print(f"[WARN] ASRPipeline not available: {e}")
        print("[INFO] This may be normal depending on VibeVoice version")
        return True
    except Exception as e:
        print(f"[WARN] ASR model test: {e}")
        return True
def main():
    """Run every check and print a pass/fail summary.

    Returns:
        int: 0 when all checks passed, 1 otherwise (used as exit status).
    """
    print("\n")
    star_banner = "*" * 60
    print(star_banner)
    print(" VibeVoice-ASR Test Suite for DGX Spark")
    print(star_banner)
    results = {
        "imports": test_imports(),
        "torch_cuda": test_torch_cuda(),
        "flash_attention": test_flash_attention(),
        "ffmpeg": test_ffmpeg(),
        "asr_model": test_asr_model(),
    }
    print("\n")
    rule = "=" * 60
    print(rule)
    print("Test Summary")
    print(rule)
    for name, passed in results.items():
        print(f" {'[OK]' if passed else '[FAIL]'} {name}")
    print(rule)
    if all(results.values()):
        print("\nAll tests passed!")
        return 0
    print("\nSome tests failed.")
    return 1
# Script entry point: propagate the suite result (0 = all passed) as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())

File diff suppressed because it is too large Load Diff