โ๏ธ ่จญๅฎ
+๐๏ธ VAD่จญๅฎ
+๐ ใญใฐ
+ +๐ ่ช่ญ็ตๆ
++ ๆฅ็ถใใฆ้ฒ้ณใ้ๅงใใใจใใใใซ่ช่ญ็ตๆใ่กจ็คบใใใพใ +
+diff --git a/static/scripts/vibevoice-asr/Dockerfile b/static/scripts/vibevoice-asr/Dockerfile new file mode 100644 index 0000000..b01c1b6 --- /dev/null +++ b/static/scripts/vibevoice-asr/Dockerfile @@ -0,0 +1,61 @@ +# VibeVoice-ASR for DGX Spark (ARM64, Blackwell GB10, sm_121) +# Based on NVIDIA PyTorch container for CUDA 13.1 compatibility + +ARG TARGETARCH +FROM nvcr.io/nvidia/pytorch:25.11-py3 AS base + +LABEL maintainer="VibeVoice-ASR DGX Spark Setup" +LABEL description="VibeVoice-ASR optimized for DGX Spark (ARM64, CUDA 13.1)" + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 + +# PyTorch CUDA settings for DGX Spark +ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +ENV USE_LIBUV=0 + +# Set working directory +WORKDIR /workspace + +# Install system dependencies including FFmpeg for demo +RUN apt-get update && apt-get install -y --no-install-recommends \ + ffmpeg \ + git \ + curl \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Install flash-attn if not already present +RUN pip install --no-cache-dir flash-attn --no-build-isolation || true + +# Clone and install VibeVoice +RUN git clone https://github.com/microsoft/VibeVoice.git /workspace/VibeVoice && \ + cd /workspace/VibeVoice && \ + pip install --no-cache-dir -e . 
+ +# Create test script and patched demo with MKV support +COPY test_vibevoice.py /workspace/test_vibevoice.py +COPY vibevoice_asr_gradio_demo_patched.py /workspace/VibeVoice/demo/vibevoice_asr_gradio_demo.py + +# Install real-time ASR dependencies +COPY requirements-realtime.txt /workspace/requirements-realtime.txt +RUN pip install --no-cache-dir -r /workspace/requirements-realtime.txt + +# Copy real-time ASR module and startup scripts +COPY realtime/ /workspace/VibeVoice/realtime/ +COPY static/ /workspace/VibeVoice/static/ +COPY run_all.sh /workspace/VibeVoice/run_all.sh +COPY run_realtime.sh /workspace/VibeVoice/run_realtime.sh +RUN chmod +x /workspace/VibeVoice/run_all.sh /workspace/VibeVoice/run_realtime.sh + +# Set default working directory to VibeVoice +WORKDIR /workspace/VibeVoice + +# Expose Gradio port and WebSocket port +EXPOSE 7860 +EXPOSE 8000 + +# Default command: Launch Gradio demo with MKV support +CMD ["python", "demo/vibevoice_asr_gradio_demo.py", "--model_path", "microsoft/VibeVoice-ASR", "--host", "0.0.0.0"] diff --git a/static/scripts/vibevoice-asr/realtime/__init__.py b/static/scripts/vibevoice-asr/realtime/__init__.py new file mode 100644 index 0000000..0fbe8cb --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/__init__.py @@ -0,0 +1,7 @@ +""" +VibeVoice Realtime ASR Module + +WebSocket-based real-time speech recognition using VibeVoice ASR. +""" + +__version__ = "0.1.0" diff --git a/static/scripts/vibevoice-asr/realtime/asr_worker.py b/static/scripts/vibevoice-asr/realtime/asr_worker.py new file mode 100644 index 0000000..8e37d86 --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/asr_worker.py @@ -0,0 +1,358 @@ +""" +ASR Worker for real-time transcription. + +Wraps the existing VibeVoiceASRInference for async/streaming operation. 
+""" + +import sys +import os +import asyncio +import threading +import time +import queue +from typing import AsyncGenerator, Optional, List, Callable +from dataclasses import dataclass +import numpy as np +import torch + +# Add parent directory and demo directory to path for importing existing code +_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, _parent_dir) +sys.path.insert(0, os.path.join(_parent_dir, "demo")) + +from .models import ( + TranscriptionResult, + TranscriptionSegment, + MessageType, + SessionConfig, +) + + +@dataclass +class InferenceRequest: + """Request for ASR inference.""" + audio: np.ndarray + sample_rate: int + context_info: Optional[str] + request_time: float + segment_start_sec: float + segment_end_sec: float + + +class ASRWorker: + """ + ASR Worker that wraps VibeVoiceASRInference for real-time use. + + Features: + - Async interface for WebSocket integration + - Streaming output via TextIteratorStreamer + - Request queuing for handling concurrent segments + - Graceful model loading and error handling + """ + + def __init__( + self, + model_path: str = "microsoft/VibeVoice-ASR", + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + attn_implementation: str = "flash_attention_2", + ): + """ + Initialize the ASR worker. + + Args: + model_path: Path to VibeVoice ASR model + device: Device to run inference on + dtype: Model data type + attn_implementation: Attention implementation + """ + self.model_path = model_path + self.device = device + self.dtype = dtype + self.attn_implementation = attn_implementation + + self._inference = None + self._is_loaded = False + self._load_lock = threading.Lock() + + # Inference queue for serializing requests + self._inference_semaphore = asyncio.Semaphore(1) + + def load_model(self) -> bool: + """ + Load the ASR model. 
+ + Returns: + True if model loaded successfully + """ + with self._load_lock: + if self._is_loaded: + return True + + try: + # Import here to avoid circular imports and allow lazy loading + # In Docker, the file is copied as vibevoice_asr_gradio_demo.py + try: + from vibevoice_asr_gradio_demo import VibeVoiceASRInference + except ImportError: + from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference + + print(f"Loading VibeVoice ASR model from {self.model_path}...") + self._inference = VibeVoiceASRInference( + model_path=self.model_path, + device=self.device, + dtype=self.dtype, + attn_implementation=self.attn_implementation, + ) + self._is_loaded = True + print("ASR model loaded successfully") + return True + + except Exception as e: + print(f"Failed to load ASR model: {e}") + import traceback + traceback.print_exc() + return False + + @property + def is_loaded(self) -> bool: + """Check if model is loaded.""" + return self._is_loaded + + async def transcribe_segment( + self, + audio: np.ndarray, + sample_rate: int = 16000, + context_info: Optional[str] = None, + segment_start_sec: float = 0.0, + segment_end_sec: float = 0.0, + config: Optional[SessionConfig] = None, + on_partial: Optional[Callable[[TranscriptionResult], None]] = None, + ) -> TranscriptionResult: + """ + Transcribe an audio segment asynchronously. 
+ + Args: + audio: Audio data as float32 array + sample_rate: Audio sample rate + context_info: Optional context for transcription + segment_start_sec: Start time of segment in session + segment_end_sec: End time of segment in session + config: Session configuration + on_partial: Callback for partial results + + Returns: + Final transcription result + """ + if not self._is_loaded: + if not self.load_model(): + return TranscriptionResult( + type=MessageType.ERROR, + text="", + is_final=True, + latency_ms=0, + ) + + config = config or SessionConfig() + request_time = time.time() + + # Serialize inference requests + async with self._inference_semaphore: + return await self._run_inference( + audio=audio, + sample_rate=sample_rate, + context_info=context_info, + segment_start_sec=segment_start_sec, + segment_end_sec=segment_end_sec, + config=config, + request_time=request_time, + on_partial=on_partial, + ) + + async def _run_inference( + self, + audio: np.ndarray, + sample_rate: int, + context_info: Optional[str], + segment_start_sec: float, + segment_end_sec: float, + config: SessionConfig, + request_time: float, + on_partial: Optional[Callable[[TranscriptionResult], None]], + ) -> TranscriptionResult: + """Run the actual inference in a thread pool.""" + from transformers import TextIteratorStreamer + + # Create streamer for partial results + streamer = None + if config.return_partial_results and on_partial: + streamer = TextIteratorStreamer( + self._inference.processor.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + # Result container for thread + result_container = {"result": None, "error": None} + + def run_inference(): + try: + # Save audio to temp file (required by current implementation) + import tempfile + import soundfile as sf + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + temp_path = f.name + + # Write audio + audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16) + sf.write(temp_path, audio_int16, 
sample_rate, subtype='PCM_16') + + try: + result = self._inference.transcribe( + audio_path=temp_path, + max_new_tokens=config.max_new_tokens, + temperature=config.temperature, + context_info=context_info, + streamer=streamer, + ) + result_container["result"] = result + finally: + # Clean up temp file + try: + os.unlink(temp_path) + except: + pass + + except Exception as e: + result_container["error"] = str(e) + import traceback + traceback.print_exc() + + # Start inference in background thread + inference_thread = threading.Thread(target=run_inference) + inference_thread.start() + + # Stream partial results if enabled + partial_text = "" + if streamer and on_partial: + try: + for new_text in streamer: + partial_text += new_text + partial_result = TranscriptionResult( + type=MessageType.PARTIAL_RESULT, + text=partial_text, + is_final=False, + latency_ms=(time.time() - request_time) * 1000, + ) + # Call callback (may be async) + if asyncio.iscoroutinefunction(on_partial): + await on_partial(partial_result) + else: + on_partial(partial_result) + except Exception as e: + print(f"Error during streaming: {e}") + + # Wait for completion + inference_thread.join() + + latency_ms = (time.time() - request_time) * 1000 + + if result_container["error"]: + return TranscriptionResult( + type=MessageType.ERROR, + text=f"Error: {result_container['error']}", + is_final=True, + latency_ms=latency_ms, + ) + + result = result_container["result"] + + # Convert segments to our format + segments = [] + for seg in result.get("segments", []): + # Adjust timestamps relative to session + seg_start = seg.get("start_time", 0) + seg_end = seg.get("end_time", 0) + + # If segment has relative timestamps, adjust to absolute + if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)): + adjusted_start = segment_start_sec + seg_start + adjusted_end = segment_start_sec + seg_end + else: + adjusted_start = segment_start_sec + adjusted_end = segment_end_sec + + 
segments.append(TranscriptionSegment( + start_time=adjusted_start, + end_time=adjusted_end, + speaker_id=seg.get("speaker_id", "SPEAKER_00"), + text=seg.get("text", ""), + )) + + return TranscriptionResult( + type=MessageType.FINAL_RESULT, + text=result.get("raw_text", ""), + is_final=True, + segments=segments, + latency_ms=latency_ms, + ) + + async def transcribe_stream( + self, + audio: np.ndarray, + sample_rate: int = 16000, + context_info: Optional[str] = None, + segment_start_sec: float = 0.0, + segment_end_sec: float = 0.0, + config: Optional[SessionConfig] = None, + ) -> AsyncGenerator[TranscriptionResult, None]: + """ + Transcribe an audio segment with streaming output. + + Yields partial results followed by final result. + + Args: + audio: Audio data + sample_rate: Sample rate + context_info: Optional context + segment_start_sec: Segment start time + segment_end_sec: Segment end time + config: Session config + + Yields: + TranscriptionResult objects (partial and final) + """ + result_queue: asyncio.Queue = asyncio.Queue() + + async def on_partial(result: TranscriptionResult): + await result_queue.put(result) + + # Start transcription task + transcribe_task = asyncio.create_task( + self.transcribe_segment( + audio=audio, + sample_rate=sample_rate, + context_info=context_info, + segment_start_sec=segment_start_sec, + segment_end_sec=segment_end_sec, + config=config, + on_partial=on_partial, + ) + ) + + # Yield partial results as they come + while not transcribe_task.done(): + try: + result = await asyncio.wait_for(result_queue.get(), timeout=0.1) + yield result + except asyncio.TimeoutError: + continue + + # Drain any remaining partial results + while not result_queue.empty(): + yield await result_queue.get() + + # Yield final result + final_result = await transcribe_task + yield final_result diff --git a/static/scripts/vibevoice-asr/realtime/audio_buffer.py b/static/scripts/vibevoice-asr/realtime/audio_buffer.py new file mode 100644 index 0000000..855624f 
--- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/audio_buffer.py @@ -0,0 +1,246 @@ +""" +Audio buffer management for real-time ASR. + +Implements a ring buffer for efficient audio chunk management with overlap support. +""" + +import numpy as np +from typing import Optional, Tuple +from dataclasses import dataclass +import threading + + +@dataclass +class AudioChunkInfo: + """Information about an extracted audio chunk.""" + audio: np.ndarray + start_sample: int + end_sample: int + start_sec: float + end_sec: float + + +class AudioBuffer: + """ + Ring buffer for managing audio chunks with overlap support. + + Features: + - Efficient memory management with fixed-size buffer + - Overlap handling for continuous processing + - Thread-safe operations + - Automatic sample rate tracking + """ + + def __init__( + self, + sample_rate: int = 16000, + chunk_duration_sec: float = 3.0, + overlap_sec: float = 0.5, + max_buffer_sec: float = 60.0, + ): + """ + Initialize the audio buffer. + + Args: + sample_rate: Audio sample rate in Hz + chunk_duration_sec: Duration of each processing chunk + overlap_sec: Overlap between consecutive chunks + max_buffer_sec: Maximum buffer duration (older data will be discarded) + """ + self.sample_rate = sample_rate + self.chunk_size = int(chunk_duration_sec * sample_rate) + self.overlap_size = int(overlap_sec * sample_rate) + self.max_buffer_size = int(max_buffer_sec * sample_rate) + + # Main buffer (pre-allocated) + self._buffer = np.zeros(self.max_buffer_size, dtype=np.float32) + self._write_pos = 0 # Next position to write + self._read_pos = 0 # Position of unprocessed data start + self._total_samples_received = 0 # Total samples since session start + + self._lock = threading.Lock() + + @property + def samples_available(self) -> int: + """Number of unprocessed samples in buffer.""" + with self._lock: + return self._write_pos - self._read_pos + + @property + def duration_available_sec(self) -> float: + """Duration of unprocessed audio 
in seconds.""" + return self.samples_available / self.sample_rate + + @property + def total_duration_sec(self) -> float: + """Total duration of audio received since session start.""" + return self._total_samples_received / self.sample_rate + + def append(self, audio_chunk: np.ndarray) -> int: + """ + Append audio chunk to the buffer. + + Args: + audio_chunk: Audio data as float32 array (range: -1.0 to 1.0) + + Returns: + Number of samples actually appended + """ + if audio_chunk.dtype != np.float32: + audio_chunk = audio_chunk.astype(np.float32) + + # Ensure 1D + if audio_chunk.ndim > 1: + audio_chunk = audio_chunk.flatten() + + with self._lock: + chunk_len = len(audio_chunk) + + # Check if we need to shift buffer (running out of space) + if self._write_pos + chunk_len > self.max_buffer_size: + self._compact_buffer() + + # Still not enough space? Discard old unprocessed data + if self._write_pos + chunk_len > self.max_buffer_size: + overflow = (self._write_pos + chunk_len) - self.max_buffer_size + self._read_pos = min(self._read_pos + overflow, self._write_pos) + self._compact_buffer() + + # Write to buffer + end_pos = self._write_pos + chunk_len + self._buffer[self._write_pos:end_pos] = audio_chunk + self._write_pos = end_pos + self._total_samples_received += chunk_len + + return chunk_len + + def _compact_buffer(self) -> None: + """Move unprocessed data to the beginning of the buffer.""" + if self._read_pos > 0: + unprocessed_len = self._write_pos - self._read_pos + if unprocessed_len > 0: + self._buffer[:unprocessed_len] = self._buffer[self._read_pos:self._write_pos] + self._write_pos = unprocessed_len + self._read_pos = 0 + + def get_chunk_for_inference(self, min_duration_sec: float = 0.5) -> Optional[AudioChunkInfo]: + """ + Get the next chunk for ASR inference. + + Returns a chunk of audio when enough data is available. + The chunk includes overlap from the previous chunk for context. 
+ + Args: + min_duration_sec: Minimum duration required to return a chunk + + Returns: + AudioChunkInfo if enough data is available, None otherwise + """ + min_samples = int(min_duration_sec * self.sample_rate) + + with self._lock: + available = self._write_pos - self._read_pos + + if available < min_samples: + return None + + # Calculate chunk boundaries + chunk_start = self._read_pos + chunk_end = min(self._read_pos + self.chunk_size, self._write_pos) + actual_chunk_size = chunk_end - chunk_start + + # Extract audio + audio = self._buffer[chunk_start:chunk_end].copy() + + # Calculate timestamps based on total samples received + base_sample = self._total_samples_received - (self._write_pos - chunk_start) + start_sec = base_sample / self.sample_rate + end_sec = (base_sample + actual_chunk_size) / self.sample_rate + + return AudioChunkInfo( + audio=audio, + start_sample=base_sample, + end_sample=base_sample + actual_chunk_size, + start_sec=start_sec, + end_sec=end_sec, + ) + + def mark_processed(self, samples: int) -> None: + """ + Mark samples as processed, advancing the read position. + + Keeps overlap_size samples for context in the next chunk. + + Args: + samples: Number of samples that were processed + """ + with self._lock: + # Advance read position but keep overlap for context + advance = max(0, samples - self.overlap_size) + self._read_pos = min(self._read_pos + advance, self._write_pos) + + def get_segment(self, start_sec: float, end_sec: float) -> Optional[np.ndarray]: + """ + Get a specific time segment from the buffer. 
+ + Args: + start_sec: Start time in seconds (relative to session start) + end_sec: End time in seconds + + Returns: + Audio segment if available, None otherwise + """ + start_sample = int(start_sec * self.sample_rate) + end_sample = int(end_sec * self.sample_rate) + + with self._lock: + # Calculate buffer positions + buffer_start_sample = self._total_samples_received - self._write_pos + buffer_end_sample = self._total_samples_received + + # Check if segment is in buffer + if start_sample < buffer_start_sample or end_sample > buffer_end_sample: + return None + + # Convert to buffer indices + buf_start = start_sample - buffer_start_sample + buf_end = end_sample - buffer_start_sample + + return self._buffer[buf_start:buf_end].copy() + + def get_all_unprocessed(self) -> Optional[AudioChunkInfo]: + """ + Get all unprocessed audio. + + Returns: + AudioChunkInfo with all unprocessed audio, or None if empty + """ + with self._lock: + if self._write_pos <= self._read_pos: + return None + + audio = self._buffer[self._read_pos:self._write_pos].copy() + base_sample = self._total_samples_received - (self._write_pos - self._read_pos) + start_sec = base_sample / self.sample_rate + end_sec = self._total_samples_received / self.sample_rate + + return AudioChunkInfo( + audio=audio, + start_sample=base_sample, + end_sample=self._total_samples_received, + start_sec=start_sec, + end_sec=end_sec, + ) + + def clear(self) -> None: + """Clear the buffer and reset all positions.""" + with self._lock: + self._buffer.fill(0) + self._write_pos = 0 + self._read_pos = 0 + self._total_samples_received = 0 + + def reset_read_position(self) -> None: + """Reset read position to current write position (skip all unprocessed).""" + with self._lock: + self._read_pos = self._write_pos diff --git a/static/scripts/vibevoice-asr/realtime/models.py b/static/scripts/vibevoice-asr/realtime/models.py new file mode 100644 index 0000000..6846c87 --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/models.py 
@@ -0,0 +1,154 @@ +""" +Data models for real-time ASR WebSocket communication. +""" + +from enum import Enum +from typing import Optional, List, Dict, Any +from dataclasses import dataclass, field, asdict +import time + + +class MessageType(str, Enum): + """WebSocket message types.""" + # Client -> Server + AUDIO_CHUNK = "audio_chunk" + CONFIG = "config" + START = "start" + STOP = "stop" + + # Server -> Client + PARTIAL_RESULT = "partial_result" + FINAL_RESULT = "final_result" + VAD_EVENT = "vad_event" + ERROR = "error" + STATUS = "status" + + +class VADEventType(str, Enum): + """VAD event types.""" + SPEECH_START = "speech_start" + SPEECH_END = "speech_end" + + +@dataclass +class SessionConfig: + """Configuration for a real-time ASR session.""" + # Audio parameters + sample_rate: int = 16000 + chunk_duration_sec: float = 3.0 + overlap_sec: float = 0.5 + + # VAD parameters + vad_threshold: float = 0.5 + min_speech_duration_ms: int = 250 + min_silence_duration_ms: int = 500 + min_volume_threshold: float = 0.01 # Minimum RMS volume (0.0-1.0) to consider as potential speech + + # ASR parameters + max_new_tokens: int = 512 + temperature: float = 0.0 + context_info: Optional[str] = None + + # Behavior + return_partial_results: bool = True + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SessionConfig": + return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class TranscriptionSegment: + """A single transcription segment with metadata.""" + start_time: float + end_time: float + speaker_id: str + text: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class TranscriptionResult: + """Transcription result message.""" + type: MessageType + text: str + is_final: bool + segments: List[TranscriptionSegment] = field(default_factory=list) + latency_ms: float = 0.0 + timestamp: float = field(default_factory=time.time) + + def 
to_dict(self) -> Dict[str, Any]: + return { + "type": self.type.value, + "text": self.text, + "is_final": self.is_final, + "segments": [s.to_dict() for s in self.segments], + "latency_ms": self.latency_ms, + "timestamp": self.timestamp, + } + + +@dataclass +class VADEvent: + """VAD event message.""" + type: MessageType = MessageType.VAD_EVENT + event: VADEventType = VADEventType.SPEECH_START + timestamp: float = field(default_factory=time.time) + audio_timestamp_sec: float = 0.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.type.value, + "event": self.event.value, + "timestamp": self.timestamp, + "audio_timestamp_sec": self.audio_timestamp_sec, + } + + +@dataclass +class StatusMessage: + """Status message.""" + type: MessageType = MessageType.STATUS + status: str = "" + message: str = "" + timestamp: float = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.type.value, + "status": self.status, + "message": self.message, + "timestamp": self.timestamp, + } + + +@dataclass +class ErrorMessage: + """Error message.""" + type: MessageType = MessageType.ERROR + error: str = "" + code: str = "" + timestamp: float = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.type.value, + "error": self.error, + "code": self.code, + "timestamp": self.timestamp, + } + + +@dataclass +class SpeechSegment: + """Detected speech segment from VAD.""" + start_sample: int + end_sample: int + start_sec: float + end_sec: float + confidence: float = 1.0 diff --git a/static/scripts/vibevoice-asr/realtime/server.py b/static/scripts/vibevoice-asr/realtime/server.py new file mode 100644 index 0000000..615490b --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/server.py @@ -0,0 +1,300 @@ +""" +FastAPI WebSocket server for real-time ASR. + +Provides WebSocket endpoint for streaming audio and receiving transcriptions. 
+""" + +import os +import sys +import asyncio +import json +import time +from typing import Optional +from contextlib import asynccontextmanager + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi.staticfiles import StaticFiles +from fastapi.responses import HTMLResponse, JSONResponse +import uvicorn + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from .models import ( + SessionConfig, + TranscriptionResult, + VADEvent, + StatusMessage, + ErrorMessage, + MessageType, +) +from .asr_worker import ASRWorker +from .session_manager import SessionManager + + +# Global instances +asr_worker: Optional[ASRWorker] = None +session_manager: Optional[SessionManager] = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager.""" + global asr_worker, session_manager + + # Startup + print("Starting VibeVoice Realtime ASR Server...") + + # Get model path from environment or use default + model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR") + device = os.environ.get("VIBEVOICE_DEVICE", "cuda") + attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2") + + # Initialize ASR worker + asr_worker = ASRWorker( + model_path=model_path, + device=device, + attn_implementation=attn_impl, + ) + + # Pre-load model (optional, can be lazy-loaded on first request) + preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true" + if preload: + print("Pre-loading ASR model...") + asr_worker.load_model() + + # Initialize session manager + max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10")) + session_manager = SessionManager( + asr_worker=asr_worker, + max_concurrent_sessions=max_sessions, + ) + await session_manager.start() + + print("Server ready!") + + yield + + # Shutdown + print("Shutting down...") + await session_manager.stop() + + +# Create FastAPI app +app = FastAPI( + 
title="VibeVoice Realtime ASR", + description="Real-time speech recognition using VibeVoice ASR", + version="0.1.0", + lifespan=lifespan, +) + +# Mount static files +static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static") +if os.path.exists(static_dir): + app.mount("/static", StaticFiles(directory=static_dir), name="static") + + +@app.get("/") +async def root(): + """Root endpoint with API info.""" + return { + "service": "VibeVoice Realtime ASR", + "version": "0.1.0", + "endpoints": { + "websocket": "/ws/asr/{session_id}", + "health": "/health", + "stats": "/stats", + "client": "/static/realtime_client.html", + }, + } + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "model_loaded": asr_worker.is_loaded if asr_worker else False, + "active_sessions": len(session_manager._sessions) if session_manager else 0, + } + + +@app.get("/stats") +async def get_stats(): + """Get server statistics.""" + if session_manager is None: + raise HTTPException(status_code=503, detail="Server not initialized") + + return session_manager.get_stats() + + +@app.websocket("/ws/asr/{session_id}") +async def websocket_asr(websocket: WebSocket, session_id: str): + """ + WebSocket endpoint for real-time ASR. + + Protocol: + 1. Client connects and optionally sends config message + 2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono) + 3. 
Server sends JSON messages with transcription results + + Message types (server -> client): + - partial_result: Intermediate transcription + - final_result: Complete transcription for a segment + - vad_event: Speech start/end events + - error: Error messages + - status: Status updates + """ + await websocket.accept() + + # Send connection confirmation + await websocket.send_json( + StatusMessage( + status="connected", + message=f"Session {session_id} connected", + ).to_dict() + ) + + # Result callback + async def on_result(result: TranscriptionResult): + try: + await websocket.send_json(result.to_dict()) + except Exception as e: + print(f"[{session_id}] Failed to send result: {e}") + + # VAD event callback + async def on_vad_event(event: VADEvent): + try: + await websocket.send_json(event.to_dict()) + except Exception as e: + print(f"[{session_id}] Failed to send VAD event: {e}") + + # Create session + session = await session_manager.create_session( + session_id=session_id, + on_result=on_result, + on_vad_event=on_vad_event, + ) + + if session is None: + await websocket.send_json( + ErrorMessage( + error="Maximum sessions reached", + code="MAX_SESSIONS", + ).to_dict() + ) + await websocket.close() + return + + await websocket.send_json( + StatusMessage( + status="ready", + message="Session ready for audio", + ).to_dict() + ) + + try: + while True: + # Receive message + message = await websocket.receive() + + if message["type"] == "websocket.disconnect": + break + + # Handle binary audio data + if "bytes" in message: + audio_data = message["bytes"] + try: + await session.process_audio_chunk(audio_data) + except Exception as e: + print(f"[{session_id}] Error processing audio: {e}") + import traceback + traceback.print_exc() + + # Handle JSON control messages + elif "text" in message: + try: + data = json.loads(message["text"]) + msg_type = data.get("type") + + if msg_type == "config": + # Update session config + config = SessionConfig.from_dict(data.get("config", 
{})) + session.update_config(config) + await websocket.send_json( + StatusMessage( + status="config_updated", + message="Configuration updated", + ).to_dict() + ) + + elif msg_type == "stop": + # Flush and close + await session.flush() + await websocket.send_json( + StatusMessage( + status="stopped", + message="Session stopped", + ).to_dict() + ) + break + + elif msg_type == "ping": + await websocket.send_json({"type": "pong", "timestamp": time.time()}) + + except json.JSONDecodeError: + await websocket.send_json( + ErrorMessage( + error="Invalid JSON", + code="INVALID_JSON", + ).to_dict() + ) + + except WebSocketDisconnect: + print(f"[{session_id}] Client disconnected") + except Exception as e: + print(f"[{session_id}] Error: {e}") + import traceback + traceback.print_exc() + finally: + # Clean up session + await session_manager.close_session(session_id) + print(f"[{session_id}] Session closed") + + +def main(): + """Main entry point.""" + import argparse + + parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--model-path", type=str, default="microsoft/VibeVoice-ASR", + help="Path to VibeVoice ASR model") + parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)") + parser.add_argument("--max-sessions", type=int, default=10, help="Max concurrent sessions") + parser.add_argument("--no-preload", action="store_true", help="Don't preload model") + + args = parser.parse_args() + + # Set environment variables for lifespan + os.environ["VIBEVOICE_MODEL_PATH"] = args.model_path + os.environ["VIBEVOICE_DEVICE"] = args.device + os.environ["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions) + os.environ["VIBEVOICE_PRELOAD_MODEL"] = "false" if args.no_preload else "true" + + print(f"Starting server on {args.host}:{args.port}") + 
print(f"Model: {args.model_path}") + print(f"Device: {args.device}") + print(f"Max sessions: {args.max_sessions}") + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ) + + +if __name__ == "__main__": + main() diff --git a/static/scripts/vibevoice-asr/realtime/session_manager.py b/static/scripts/vibevoice-asr/realtime/session_manager.py new file mode 100644 index 0000000..7aaeaff --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/session_manager.py @@ -0,0 +1,401 @@ +""" +Session manager for real-time ASR. + +Manages multiple concurrent client sessions with resource isolation. +""" + +import asyncio +import time +import uuid +from typing import Dict, Optional, Callable, Any +from dataclasses import dataclass, field +import threading + +from .models import ( + SessionConfig, + TranscriptionResult, + VADEvent, + StatusMessage, + ErrorMessage, + MessageType, + SpeechSegment, +) +from .audio_buffer import AudioBuffer +from .vad_processor import VADProcessor +from .asr_worker import ASRWorker + + +@dataclass +class SessionStats: + """Statistics for a session.""" + created_at: float = field(default_factory=time.time) + last_activity: float = field(default_factory=time.time) + audio_received_sec: float = 0.0 + chunks_received: int = 0 + segments_transcribed: int = 0 + total_latency_ms: float = 0.0 + + +class RealtimeSession: + """ + A single real-time ASR session. + + Manages audio buffering, VAD, and ASR for one client connection. + """ + + def __init__( + self, + session_id: str, + asr_worker: ASRWorker, + config: Optional[SessionConfig] = None, + on_result: Optional[Callable[[TranscriptionResult], Any]] = None, + on_vad_event: Optional[Callable[[VADEvent], Any]] = None, + ): + """ + Initialize a session. 
+ + Args: + session_id: Unique session identifier + asr_worker: Shared ASR worker instance + config: Session configuration + on_result: Callback for transcription results + on_vad_event: Callback for VAD events + """ + self.session_id = session_id + self.asr_worker = asr_worker + self.config = config or SessionConfig() + self.on_result = on_result + self.on_vad_event = on_vad_event + + # Components + self.audio_buffer = AudioBuffer( + sample_rate=self.config.sample_rate, + chunk_duration_sec=self.config.chunk_duration_sec, + overlap_sec=self.config.overlap_sec, + ) + + self.vad_processor = VADProcessor( + sample_rate=self.config.sample_rate, + threshold=self.config.vad_threshold, + min_speech_duration_ms=self.config.min_speech_duration_ms, + min_silence_duration_ms=self.config.min_silence_duration_ms, + min_volume_threshold=self.config.min_volume_threshold, + ) + + # State + self.is_active = True + self.stats = SessionStats() + self._processing_lock = asyncio.Lock() + self._pending_tasks: list = [] + + async def process_audio_chunk(self, audio_data: bytes) -> None: + """ + Process an incoming audio chunk. 
+ + Args: + audio_data: Raw PCM audio data (16-bit, 16kHz, mono) + """ + if not self.is_active: + return + + self.stats.last_activity = time.time() + self.stats.chunks_received += 1 + + # Convert bytes to float32 array + import numpy as np + audio_int16 = np.frombuffer(audio_data, dtype=np.int16) + audio_float = audio_int16.astype(np.float32) / 32768.0 + + self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate + + # Add to buffer + self.audio_buffer.append(audio_float) + + # Process with VAD + segments, events = self.vad_processor.process(audio_float) + + # Send VAD events + if self.on_vad_event: + for event in events: + await self._send_callback(self.on_vad_event, event) + + # Process completed speech segments + for segment in segments: + try: + await self._transcribe_segment(segment) + except Exception as e: + print(f"[Session {self.session_id}] Transcription error: {e}") + import traceback + traceback.print_exc() + + async def _transcribe_segment(self, segment: SpeechSegment) -> None: + """Transcribe a detected speech segment.""" + # Get audio for segment from buffer + audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec) + + if audio is None or len(audio) == 0: + print(f"[Session {self.session_id}] Could not retrieve audio for segment") + return + + self.stats.segments_transcribed += 1 + + async def on_partial(result: TranscriptionResult): + if self.on_result: + await self._send_callback(self.on_result, result) + + # Run transcription + result = await self.asr_worker.transcribe_segment( + audio=audio, + sample_rate=self.config.sample_rate, + context_info=self.config.context_info, + segment_start_sec=segment.start_sec, + segment_end_sec=segment.end_sec, + config=self.config, + on_partial=on_partial if self.config.return_partial_results else None, + ) + + self.stats.total_latency_ms += result.latency_ms + + # Send final result + if self.on_result: + await self._send_callback(self.on_result, result) + + async def 
_send_callback(self, callback: Callable, data: Any) -> None: + """Send data via callback, handling both sync and async.""" + try: + if asyncio.iscoroutinefunction(callback): + await callback(data) + else: + callback(data) + except Exception as e: + print(f"[Session {self.session_id}] Callback error: {e}") + + async def flush(self) -> None: + """ + Flush any remaining audio and force transcription. + + Called when session ends to process any remaining speech. + """ + # Force end any active speech + segment = self.vad_processor.force_end_speech() + if segment: + await self._transcribe_segment(segment) + + # Also check for any unprocessed audio in buffer + chunk_info = self.audio_buffer.get_all_unprocessed() + if chunk_info and len(chunk_info.audio) > self.config.sample_rate * 0.5: + # More than 0.5 seconds of unprocessed audio + forced_segment = SpeechSegment( + start_sample=chunk_info.start_sample, + end_sample=chunk_info.end_sample, + start_sec=chunk_info.start_sec, + end_sec=chunk_info.end_sec, + ) + await self._transcribe_segment(forced_segment) + + def update_config(self, new_config: SessionConfig) -> None: + """Update session configuration (partial update supported).""" + # Merge with existing config - only update non-default values + if new_config.vad_threshold != 0.5: + self.config.vad_threshold = new_config.vad_threshold + if new_config.min_speech_duration_ms != 250: + self.config.min_speech_duration_ms = new_config.min_speech_duration_ms + if new_config.min_silence_duration_ms != 500: + self.config.min_silence_duration_ms = new_config.min_silence_duration_ms + if new_config.min_volume_threshold != 0.01: + self.config.min_volume_threshold = new_config.min_volume_threshold + if new_config.context_info is not None: + self.config.context_info = new_config.context_info + + # Recreate VAD processor with new parameters + self.vad_processor = VADProcessor( + sample_rate=self.config.sample_rate, + threshold=self.config.vad_threshold, + 
min_speech_duration_ms=self.config.min_speech_duration_ms, + min_silence_duration_ms=self.config.min_silence_duration_ms, + min_volume_threshold=self.config.min_volume_threshold, + ) + print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, " + f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, " + f"min_volume={self.config.min_volume_threshold}") + + def close(self) -> None: + """Close the session and release resources.""" + self.is_active = False + self.audio_buffer.clear() + self.vad_processor.reset() + + def get_stats(self) -> Dict: + """Get session statistics.""" + return { + "session_id": self.session_id, + "created_at": self.stats.created_at, + "last_activity": self.stats.last_activity, + "duration_sec": time.time() - self.stats.created_at, + "audio_received_sec": self.stats.audio_received_sec, + "chunks_received": self.stats.chunks_received, + "segments_transcribed": self.stats.segments_transcribed, + "avg_latency_ms": ( + self.stats.total_latency_ms / self.stats.segments_transcribed + if self.stats.segments_transcribed > 0 else 0 + ), + "is_active": self.is_active, + "vad_speech_active": self.vad_processor.is_speech_active, + } + + +class SessionManager: + """ + Manages multiple concurrent ASR sessions. + + Features: + - Session creation and cleanup + - Resource limiting (max concurrent sessions) + - Idle session timeout + - Shared ASR worker management + """ + + def __init__( + self, + asr_worker: ASRWorker, + max_concurrent_sessions: int = 10, + session_timeout_sec: float = 300.0, + ): + """ + Initialize the session manager. 
+ + Args: + asr_worker: Shared ASR worker + max_concurrent_sessions: Maximum number of concurrent sessions + session_timeout_sec: Timeout for idle sessions + """ + self.asr_worker = asr_worker + self.max_sessions = max_concurrent_sessions + self.session_timeout = session_timeout_sec + + self._sessions: Dict[str, RealtimeSession] = {} + self._lock = asyncio.Lock() + + # Cleanup task + self._cleanup_task: Optional[asyncio.Task] = None + + async def start(self) -> None: + """Start the session manager.""" + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + async def stop(self) -> None: + """Stop the session manager and close all sessions.""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + async with self._lock: + for session in self._sessions.values(): + session.close() + self._sessions.clear() + + async def create_session( + self, + session_id: Optional[str] = None, + config: Optional[SessionConfig] = None, + on_result: Optional[Callable[[TranscriptionResult], Any]] = None, + on_vad_event: Optional[Callable[[VADEvent], Any]] = None, + ) -> Optional[RealtimeSession]: + """ + Create a new session. 
+ + Args: + session_id: Optional session ID (generated if not provided) + config: Session configuration + on_result: Callback for results + on_vad_event: Callback for VAD events + + Returns: + Created session, or None if limit reached + """ + async with self._lock: + # Check session limit + if len(self._sessions) >= self.max_sessions: + return None + + # Generate session ID if not provided + if session_id is None: + session_id = str(uuid.uuid4())[:8] + + # Check for duplicate + if session_id in self._sessions: + return self._sessions[session_id] + + # Create session + session = RealtimeSession( + session_id=session_id, + asr_worker=self.asr_worker, + config=config, + on_result=on_result, + on_vad_event=on_vad_event, + ) + + self._sessions[session_id] = session + return session + + async def get_session(self, session_id: str) -> Optional[RealtimeSession]: + """Get a session by ID.""" + async with self._lock: + return self._sessions.get(session_id) + + async def close_session(self, session_id: str) -> bool: + """ + Close and remove a session. 
+ + Args: + session_id: Session to close + + Returns: + True if session was found and closed + """ + async with self._lock: + session = self._sessions.pop(session_id, None) + if session: + await session.flush() + session.close() + return True + return False + + async def _cleanup_loop(self) -> None: + """Background task to clean up idle sessions.""" + while True: + try: + await asyncio.sleep(60) # Check every minute + + current_time = time.time() + sessions_to_close = [] + + async with self._lock: + for session_id, session in self._sessions.items(): + idle_time = current_time - session.stats.last_activity + if idle_time > self.session_timeout: + sessions_to_close.append(session_id) + + for session_id in sessions_to_close: + print(f"Closing idle session: {session_id}") + await self.close_session(session_id) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Cleanup error: {e}") + + def get_stats(self) -> Dict: + """Get manager statistics.""" + return { + "active_sessions": len(self._sessions), + "max_sessions": self.max_sessions, + "session_timeout_sec": self.session_timeout, + "sessions": { + sid: session.get_stats() + for sid, session in self._sessions.items() + }, + } diff --git a/static/scripts/vibevoice-asr/realtime/vad_processor.py b/static/scripts/vibevoice-asr/realtime/vad_processor.py new file mode 100644 index 0000000..62fbcf2 --- /dev/null +++ b/static/scripts/vibevoice-asr/realtime/vad_processor.py @@ -0,0 +1,295 @@ +""" +Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version). + +Detects speech segments in real-time audio streams. +Uses ONNX runtime to avoid torchaudio dependency issues. 
+""" + +import numpy as np +from typing import List, Optional, Tuple +from dataclasses import dataclass +import threading +import os +import urllib.request + +from .models import SpeechSegment, VADEvent, VADEventType, MessageType + + +@dataclass +class VADState: + """Internal state of the VAD processor.""" + is_speech_active: bool = False + speech_start_sample: int = 0 + silence_start_sample: int = 0 + last_speech_prob: float = 0.0 + total_samples_processed: int = 0 + + +class VADProcessor: + """ + Voice Activity Detection using Silero-VAD (ONNX version). + + Features: + - Real-time speech detection + - Configurable thresholds for speech/silence duration + - Event generation for speech start/end + - Thread-safe operations + - No torchaudio dependency (uses ONNX runtime) + """ + + # Silero VAD ONNX model URL + ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx" + + def __init__( + self, + sample_rate: int = 16000, + threshold: float = 0.5, + min_speech_duration_ms: int = 250, + min_silence_duration_ms: int = 500, + window_size_samples: int = 512, + min_volume_threshold: float = 0.01, + ): + """ + Initialize the VAD processor. 
+ + Args: + sample_rate: Audio sample rate (must be 16000 for Silero-VAD) + threshold: Speech probability threshold (0.0-1.0) + min_speech_duration_ms: Minimum speech duration to trigger speech_start + min_silence_duration_ms: Minimum silence duration to trigger speech_end + window_size_samples: VAD window size (512 for 16kHz = 32ms) + min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech + """ + if sample_rate != 16000: + raise ValueError("Silero-VAD requires 16kHz sample rate") + + self.sample_rate = sample_rate + self.threshold = threshold + self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000) + self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000) + self.window_size = window_size_samples + self.min_volume_threshold = min_volume_threshold + + # Load ONNX model + self._session = None + self._load_model() + + # ONNX model state - single state tensor (size depends on model version) + # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128) + self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32) + + # State + self._state = VADState() + self._lock = threading.Lock() + + # Pending speech segment (being accumulated) + self._pending_segment_start: Optional[int] = None + + def _get_model_path(self) -> str: + """Get path to ONNX model, downloading if necessary.""" + cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad") + os.makedirs(cache_dir, exist_ok=True) + model_path = os.path.join(cache_dir, "silero_vad.onnx") + + if not os.path.exists(model_path): + print(f"Downloading Silero-VAD ONNX model to {model_path}...") + urllib.request.urlretrieve(self.ONNX_MODEL_URL, model_path) + print("Download complete.") + + return model_path + + def _load_model(self) -> None: + """Load Silero-VAD ONNX model.""" + try: + import onnxruntime as ort + + model_path = self._get_model_path() + self._session = ort.InferenceSession( + model_path, + providers=['CPUExecutionProvider'] + 
) + print(f"Silero-VAD ONNX model loaded from {model_path}") + + except Exception as e: + raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}") + + def _run_inference(self, audio_window: np.ndarray) -> float: + """Run VAD inference on a single window.""" + # Prepare input + audio_input = audio_window.reshape(1, -1).astype(np.float32) + sr_input = np.array([self.sample_rate], dtype=np.int64) + + # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c' + outputs = self._session.run( + ['output', 'stateN'], + { + 'input': audio_input, + 'sr': sr_input, + 'state': self._state_tensor, + } + ) + + # Update state + speech_prob = outputs[0][0][0] + self._state_tensor = outputs[1] + + return float(speech_prob) + + def reset(self) -> None: + """Reset VAD state for a new session.""" + with self._lock: + self._state = VADState() + self._pending_segment_start = None + # Reset state tensor + self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32) + + def process( + self, + audio_chunk: np.ndarray, + return_events: bool = True, + ) -> Tuple[List[SpeechSegment], List[VADEvent]]: + """ + Process an audio chunk and detect speech segments. 
+ + Args: + audio_chunk: Audio data as float32 array + return_events: Whether to return VAD events + + Returns: + Tuple of (completed_segments, events) + """ + if audio_chunk.dtype != np.float32: + audio_chunk = audio_chunk.astype(np.float32) + + completed_segments: List[SpeechSegment] = [] + events: List[VADEvent] = [] + + with self._lock: + # Process in windows + chunk_start_sample = self._state.total_samples_processed + num_windows = len(audio_chunk) // self.window_size + + for i in range(num_windows): + window_start = i * self.window_size + window_end = window_start + self.window_size + window = audio_chunk[window_start:window_end] + + # Check volume (RMS) threshold first + rms = np.sqrt(np.mean(window ** 2)) + if rms < self.min_volume_threshold: + # Volume too low, treat as silence + speech_prob = 0.0 + else: + # Get speech probability from VAD model + speech_prob = self._run_inference(window) + + self._state.last_speech_prob = speech_prob + + current_sample = chunk_start_sample + window_end + is_speech = speech_prob >= self.threshold + + # State machine for speech detection + if is_speech: + if not self._state.is_speech_active: + # Potential speech start + if self._pending_segment_start is None: + self._pending_segment_start = current_sample - self.window_size + + # Check if speech duration exceeds minimum + speech_duration = current_sample - self._pending_segment_start + if speech_duration >= self.min_speech_samples: + self._state.is_speech_active = True + self._state.speech_start_sample = self._pending_segment_start + + if return_events: + events.append(VADEvent( + type=MessageType.VAD_EVENT, + event=VADEventType.SPEECH_START, + audio_timestamp_sec=self._pending_segment_start / self.sample_rate, + )) + else: + # Continue speech, reset silence counter + self._state.silence_start_sample = 0 + else: + if self._state.is_speech_active: + # Potential speech end + if self._state.silence_start_sample == 0: + self._state.silence_start_sample = current_sample + + # 
Check if silence duration exceeds minimum + silence_duration = current_sample - self._state.silence_start_sample + if silence_duration >= self.min_silence_samples: + # Speech ended - create completed segment + segment = SpeechSegment( + start_sample=self._state.speech_start_sample, + end_sample=self._state.silence_start_sample, + start_sec=self._state.speech_start_sample / self.sample_rate, + end_sec=self._state.silence_start_sample / self.sample_rate, + ) + completed_segments.append(segment) + + if return_events: + events.append(VADEvent( + type=MessageType.VAD_EVENT, + event=VADEventType.SPEECH_END, + audio_timestamp_sec=self._state.silence_start_sample / self.sample_rate, + )) + + # Reset state + self._state.is_speech_active = False + self._state.speech_start_sample = 0 + self._state.silence_start_sample = 0 + self._pending_segment_start = None + else: + # No speech, reset pending + self._pending_segment_start = None + + # Update total samples processed + self._state.total_samples_processed += len(audio_chunk) + + return completed_segments, events + + def force_end_speech(self) -> Optional[SpeechSegment]: + """ + Force end of current speech segment (e.g., when session ends). 
+ + Returns: + Completed speech segment if speech was active, None otherwise + """ + with self._lock: + if self._state.is_speech_active: + segment = SpeechSegment( + start_sample=self._state.speech_start_sample, + end_sample=self._state.total_samples_processed, + start_sec=self._state.speech_start_sample / self.sample_rate, + end_sec=self._state.total_samples_processed / self.sample_rate, + ) + + self._state.is_speech_active = False + self._state.speech_start_sample = 0 + self._state.silence_start_sample = 0 + self._pending_segment_start = None + + return segment + + return None + + @property + def is_speech_active(self) -> bool: + """Check if speech is currently active.""" + with self._lock: + return self._state.is_speech_active + + @property + def last_speech_probability(self) -> float: + """Get the last computed speech probability.""" + with self._lock: + return self._state.last_speech_prob + + @property + def current_speech_duration_sec(self) -> float: + """Get duration of current speech segment (if active).""" + with self._lock: + if not self._state.is_speech_active: + return 0.0 + return (self._state.total_samples_processed - self._state.speech_start_sample) / self.sample_rate diff --git a/static/scripts/vibevoice-asr/requirements-realtime.txt b/static/scripts/vibevoice-asr/requirements-realtime.txt new file mode 100644 index 0000000..c75b3e3 --- /dev/null +++ b/static/scripts/vibevoice-asr/requirements-realtime.txt @@ -0,0 +1,7 @@ +# Real-time ASR dependencies +fastapi>=0.100.0 +uvicorn[standard]>=0.23.0 +websockets>=11.0 +numpy>=1.24.0 +soundfile>=0.12.0 +onnxruntime diff --git a/static/scripts/vibevoice-asr/run_all.sh b/static/scripts/vibevoice-asr/run_all.sh new file mode 100755 index 0000000..0c53f01 --- /dev/null +++ b/static/scripts/vibevoice-asr/run_all.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Run both Gradio demo and Realtime ASR server +# +# Usage: +# ./run_all.sh +# +# Ports: +# - 7860: Gradio UI (batch ASR) +# - 8000: WebSocket API (realtime ASR) + 
+set -e + +cd "$(dirname "$0")" + +# Configuration +GRADIO_HOST="${GRADIO_HOST:-0.0.0.0}" +GRADIO_PORT="${GRADIO_PORT:-7860}" +REALTIME_HOST="${REALTIME_HOST:-0.0.0.0}" +REALTIME_PORT="${REALTIME_PORT:-8000}" +MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}" + +echo "==========================================" +echo "VibeVoice ASR - All Services" +echo "==========================================" +echo "" +echo "Starting services:" +echo " - Gradio UI: http://$GRADIO_HOST:$GRADIO_PORT" +echo " - Realtime ASR: http://$REALTIME_HOST:$REALTIME_PORT" +echo " - Test Client: http://$REALTIME_HOST:$REALTIME_PORT/static/realtime_client.html" +echo "" +echo "Model: $MODEL_PATH" +echo "==========================================" +echo "" + +# Trap to clean up background processes on exit +cleanup() { + echo "" + echo "Shutting down..." + kill $REALTIME_PID 2>/dev/null || true + kill $GRADIO_PID 2>/dev/null || true + wait + echo "All services stopped." +} +trap cleanup EXIT INT TERM + +# Start Realtime ASR server in background +echo "[1/2] Starting Realtime ASR server..." +python -m realtime.server \ + --host "$REALTIME_HOST" \ + --port "$REALTIME_PORT" \ + --model-path "$MODEL_PATH" \ + --no-preload & +REALTIME_PID=$! + +# Wait a moment for the server to initialize +sleep 2 + +# Start Gradio demo in background +echo "[2/2] Starting Gradio demo..." +python demo/vibevoice_asr_gradio_demo.py \ + --host "$GRADIO_HOST" \ + --port "$GRADIO_PORT" \ + --model_path "$MODEL_PATH" & +GRADIO_PID=$! + +echo "" +echo "Both services started. Press Ctrl+C to stop." 
+echo "" + +# Wait for either process to exit +wait -n $REALTIME_PID $GRADIO_PID + +# If one exits, the trap will clean up the other diff --git a/static/scripts/vibevoice-asr/run_realtime.sh b/static/scripts/vibevoice-asr/run_realtime.sh new file mode 100755 index 0000000..d8b7f89 --- /dev/null +++ b/static/scripts/vibevoice-asr/run_realtime.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Run VibeVoice Realtime ASR Server +# +# Usage: +# ./run_realtime.sh [options] +# +# Options are passed to the server (see --help for details) + +set -e + +cd "$(dirname "$0")" + +# Default options +HOST="${VIBEVOICE_HOST:-0.0.0.0}" +PORT="${VIBEVOICE_PORT:-8000}" +MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}" +DEVICE="${VIBEVOICE_DEVICE:-cuda}" +MAX_SESSIONS="${VIBEVOICE_MAX_SESSIONS:-10}" + +echo "==========================================" +echo "VibeVoice Realtime ASR Server" +echo "==========================================" +echo "Host: $HOST" +echo "Port: $PORT" +echo "Model: $MODEL_PATH" +echo "Device: $DEVICE" +echo "Max Sessions: $MAX_SESSIONS" +echo "==========================================" +echo "" +echo "Web client: http://$HOST:$PORT/static/realtime_client.html" +echo "WebSocket: ws://$HOST:$PORT/ws/asr/{session_id}" +echo "" + +# Run server +python -m realtime.server \ + --host "$HOST" \ + --port "$PORT" \ + --model-path "$MODEL_PATH" \ + --device "$DEVICE" \ + --max-sessions "$MAX_SESSIONS" \ + "$@" diff --git a/static/scripts/vibevoice-asr/setup.sh b/static/scripts/vibevoice-asr/setup.sh new file mode 100644 index 0000000..df2a25e --- /dev/null +++ b/static/scripts/vibevoice-asr/setup.sh @@ -0,0 +1,287 @@ +#!/bin/bash +# VibeVoice-ASR Setup Script for DGX Spark +# Downloads and builds the VibeVoice-ASR container +# +# Usage: +# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash +# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s build +# curl -sL 
https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s serve + +set -e + +BASE_URL="https://docs.techswan.online/scripts/vibevoice-asr" +INSTALL_DIR="${VIBEVOICE_DIR:-$HOME/vibevoice-asr}" +IMAGE_NAME="vibevoice-asr:dgx-spark" +CONTAINER_NAME="vibevoice-asr" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +log_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +log_step() { + echo -e "${BLUE}[STEP]${NC} $1" +} + +download_file() { + local url="$1" + local dest="$2" + local dir=$(dirname "$dest") + + mkdir -p "$dir" + + if command -v curl &> /dev/null; then + curl -sL "$url" -o "$dest" + elif command -v wget &> /dev/null; then + wget -q "$url" -O "$dest" + else + log_error "curl or wget is required" + exit 1 + fi +} + +download_files() { + log_step "Downloading VibeVoice-ASR files to $INSTALL_DIR..." + + mkdir -p "$INSTALL_DIR" + cd "$INSTALL_DIR" + + # Core files + local files=( + "Dockerfile" + "requirements-realtime.txt" + "test_vibevoice.py" + "vibevoice_asr_gradio_demo_patched.py" + "run_realtime.sh" + "run_all.sh" + ) + + for file in "${files[@]}"; do + log_info "Downloading $file..." + download_file "$BASE_URL/$file" "$INSTALL_DIR/$file" + done + + # Realtime module + local realtime_files=( + "__init__.py" + "models.py" + "server.py" + "asr_worker.py" + "session_manager.py" + "audio_buffer.py" + "vad_processor.py" + ) + + mkdir -p "$INSTALL_DIR/realtime" + for file in "${realtime_files[@]}"; do + log_info "Downloading realtime/$file..." + download_file "$BASE_URL/realtime/$file" "$INSTALL_DIR/realtime/$file" + done + + # Static files + mkdir -p "$INSTALL_DIR/static" + log_info "Downloading static/realtime_client.html..." 
+ download_file "$BASE_URL/static/realtime_client.html" "$INSTALL_DIR/static/realtime_client.html" + + # Make scripts executable + chmod +x "$INSTALL_DIR/run_realtime.sh" "$INSTALL_DIR/run_all.sh" + + log_info "All files downloaded to $INSTALL_DIR" +} + +check_prerequisites() { + log_step "Checking prerequisites..." + + # Check Docker + if ! command -v docker &> /dev/null; then + log_error "Docker is not installed" + exit 1 + fi + + # Check NVIDIA Docker runtime + if ! docker info 2>/dev/null | grep -q "Runtimes.*nvidia"; then + log_warn "NVIDIA Docker runtime may not be configured" + fi + + # Check GPU availability + if command -v nvidia-smi &> /dev/null; then + log_info "GPU detected:" + nvidia-smi --query-gpu=name,memory.total --format=csv,noheader + else + log_warn "nvidia-smi not found on host" + fi + + log_info "Prerequisites check complete" +} + +build_image() { + log_step "Building Docker image: ${IMAGE_NAME}" + log_info "This may take several minutes..." + + cd "$INSTALL_DIR" + + docker build \ + --network=host \ + -t "$IMAGE_NAME" \ + -f Dockerfile \ + . + + log_info "Docker image built successfully: ${IMAGE_NAME}" +} + +run_container() { + local mode="${1:-interactive}" + + log_step "Running container in ${mode} mode..." + + # Stop existing container if running + if docker ps -q -f name="$CONTAINER_NAME" | grep -q .; then + log_warn "Stopping existing container..." 
+ docker stop "$CONTAINER_NAME" 2>/dev/null || true + fi + + # Remove existing container + docker rm "$CONTAINER_NAME" 2>/dev/null || true + + # Common Docker options for DGX Spark + local docker_opts=( + --gpus all + --ipc=host + --network=host + --ulimit memlock=-1:-1 + --ulimit stack=-1:-1 + -e "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" + -v "$HOME/.cache/huggingface:/root/.cache/huggingface" + --name "$CONTAINER_NAME" + ) + + if [ "$mode" = "interactive" ]; then + docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" bash + elif [ "$mode" = "test" ]; then + docker run --rm "${docker_opts[@]}" "$IMAGE_NAME" python /workspace/test_vibevoice.py + elif [ "$mode" = "demo" ]; then + log_info "Starting Gradio demo on port 7860..." + log_info "Access the demo at: http://localhost:7860" + docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" + elif [ "$mode" = "realtime" ]; then + log_info "Starting Realtime ASR server on port 8000..." + log_info "WebSocket API: ws://localhost:8000/ws/asr/{session_id}" + log_info "Test client: http://localhost:8000/static/realtime_client.html" + docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" \ + python -m realtime.server --host 0.0.0.0 --port 8000 + elif [ "$mode" = "serve" ]; then + log_info "Starting all services..." 
+ log_info " Gradio demo: http://localhost:7860" + log_info " Realtime ASR: http://localhost:8000" + log_info " Test client: http://localhost:8000/static/realtime_client.html" + docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" ./run_all.sh + else + log_error "Unknown mode: $mode" + exit 1 + fi +} + +show_usage() { + echo "VibeVoice-ASR Setup for DGX Spark" + echo "" + echo "Usage:" + echo " curl -sL $BASE_URL/setup.sh | bash # Download only" + echo " curl -sL $BASE_URL/setup.sh | bash -s build # Download and build" + echo " curl -sL $BASE_URL/setup.sh | bash -s demo # Download, build, run demo" + echo " curl -sL $BASE_URL/setup.sh | bash -s serve # Download, build, run all" + echo "" + echo "Commands:" + echo " (default) Download files only" + echo " build Download and build Docker image" + echo " demo Download, build, and start Gradio demo (port 7860)" + echo " realtime Download, build, and start Realtime ASR (port 8000)" + echo " serve Download, build, and start both services" + echo " run Run container interactively (after build)" + echo "" + echo "Environment variables:" + echo " VIBEVOICE_DIR Installation directory (default: ~/vibevoice-asr)" + echo "" + echo "After installation, you can also run:" + echo " cd ~/vibevoice-asr" + echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark" +} + +main() { + local command="${1:-download}" + + echo "" + echo "==========================================" + echo " VibeVoice-ASR Setup for DGX Spark" + echo "==========================================" + echo "" + + case "$command" in + download) + download_files + echo "" + log_info "Done! Next steps:" + echo " cd $INSTALL_DIR" + echo " docker build -t vibevoice-asr:dgx-spark ." + echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark" + ;; + build) + download_files + check_prerequisites + build_image + echo "" + log_info "Done! 
To run:" + echo " cd $INSTALL_DIR" + echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark" + ;; + run) + cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; } + run_container interactive + ;; + test) + cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; } + run_container test + ;; + demo) + download_files + check_prerequisites + build_image + run_container demo + ;; + realtime) + download_files + check_prerequisites + build_image + run_container realtime + ;; + serve) + download_files + check_prerequisites + build_image + run_container serve + ;; + -h|--help|help) + show_usage + ;; + *) + log_error "Unknown command: $command" + show_usage + exit 1 + ;; + esac +} + +main "$@" diff --git a/static/scripts/vibevoice-asr/static/realtime_client.html b/static/scripts/vibevoice-asr/static/realtime_client.html new file mode 100644 index 0000000..4ff576c --- /dev/null +++ b/static/scripts/vibevoice-asr/static/realtime_client.html @@ -0,0 +1,899 @@ + + +
+ + +リアルタイム音声認識デモ
++ 接続して録音を開始するとここに認識結果が表示されます +
+🎵 Click the play button to listen to each segment directly!
" + + for i, (label, audio_src, error_msg) in enumerate(audio_segments): + seg = segments[i] if i < len(segments) else {} + start_time = seg.get('start_time', 'N/A') + end_time = seg.get('end_time', 'N/A') + speaker_id = seg.get('speaker_id', 'N/A') + content = seg.get('text', '') + + # Format times nicely + start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time) + end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time) + + audio_segments_html += f""" +โ No audio segments available.
+This could happen if the model output doesn't contain valid time stamps.
+