Add: VibeVoice ASR セットアップスクリプト一式
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
This commit is contained in:
parent
2d753f114f
commit
1fb76254e9
61
static/scripts/vibevoice-asr/Dockerfile
Normal file
61
static/scripts/vibevoice-asr/Dockerfile
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
# VibeVoice-ASR for DGX Spark (ARM64, Blackwell GB10, sm_121)
# Based on NVIDIA PyTorch container for CUDA 13.1 compatibility

# NOTE(review): declared before FROM so it is available to the FROM line's
# scope, but it is not referenced anywhere below — presumably kept for
# buildx multi-arch builds; confirm before removing.
ARG TARGETARCH
FROM nvcr.io/nvidia/pytorch:25.11-py3 AS base

LABEL maintainer="VibeVoice-ASR DGX Spark Setup"
LABEL description="VibeVoice-ASR optimized for DGX Spark (ARM64, CUDA 13.1)"

# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# PyTorch CUDA settings for DGX Spark
ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
ENV USE_LIBUV=0

# Set working directory
WORKDIR /workspace

# Install system dependencies including FFmpeg for demo
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install flash-attn if not already present.
# "|| true" deliberately makes this best-effort: the base image may already
# ship a compatible flash-attn, and a source build can fail on this arch.
RUN pip install --no-cache-dir flash-attn --no-build-isolation || true

# Clone and install VibeVoice
RUN git clone https://github.com/microsoft/VibeVoice.git /workspace/VibeVoice && \
    cd /workspace/VibeVoice && \
    pip install --no-cache-dir -e .

# Create test script and patched demo with MKV support
# (the patched demo overwrites the upstream Gradio demo in place)
COPY test_vibevoice.py /workspace/test_vibevoice.py
COPY vibevoice_asr_gradio_demo_patched.py /workspace/VibeVoice/demo/vibevoice_asr_gradio_demo.py

# Install real-time ASR dependencies
COPY requirements-realtime.txt /workspace/requirements-realtime.txt
RUN pip install --no-cache-dir -r /workspace/requirements-realtime.txt

# Copy real-time ASR module and startup scripts
COPY realtime/ /workspace/VibeVoice/realtime/
COPY static/ /workspace/VibeVoice/static/
COPY run_all.sh /workspace/VibeVoice/run_all.sh
COPY run_realtime.sh /workspace/VibeVoice/run_realtime.sh
RUN chmod +x /workspace/VibeVoice/run_all.sh /workspace/VibeVoice/run_realtime.sh

# Set default working directory to VibeVoice
WORKDIR /workspace/VibeVoice

# Expose Gradio port and WebSocket port
EXPOSE 7860
EXPOSE 8000

# Default command: Launch Gradio demo with MKV support
CMD ["python", "demo/vibevoice_asr_gradio_demo.py", "--model_path", "microsoft/VibeVoice-ASR", "--host", "0.0.0.0"]
|
||||||
7
static/scripts/vibevoice-asr/realtime/__init__.py
Normal file
7
static/scripts/vibevoice-asr/realtime/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""
VibeVoice Realtime ASR Module

WebSocket-based real-time speech recognition using VibeVoice ASR.
"""

# Package version; bump on release.
__version__ = "0.1.0"
|
||||||
358
static/scripts/vibevoice-asr/realtime/asr_worker.py
Normal file
358
static/scripts/vibevoice-asr/realtime/asr_worker.py
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
"""
|
||||||
|
ASR Worker for real-time transcription.
|
||||||
|
|
||||||
|
Wraps the existing VibeVoiceASRInference for async/streaming operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import queue
|
||||||
|
from typing import AsyncGenerator, Optional, List, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Add parent directory and demo directory to path for importing existing code
|
||||||
|
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sys.path.insert(0, _parent_dir)
|
||||||
|
sys.path.insert(0, os.path.join(_parent_dir, "demo"))
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
TranscriptionResult,
|
||||||
|
TranscriptionSegment,
|
||||||
|
MessageType,
|
||||||
|
SessionConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InferenceRequest:
    """Request for ASR inference.

    Bundles one audio segment with the metadata needed to place the
    transcription on the session timeline.
    """
    # Mono audio samples (float32 expected by the worker's temp-file writer).
    audio: np.ndarray
    # Sample rate of `audio` in Hz.
    sample_rate: int
    # Optional free-text context passed through to the ASR model.
    context_info: Optional[str]
    # time.time() when the request was created; used for latency reporting.
    request_time: float
    # Segment start, seconds since session start.
    segment_start_sec: float
    # Segment end, seconds since session start.
    segment_end_sec: float
|
||||||
|
|
||||||
|
|
||||||
|
class ASRWorker:
    """
    ASR Worker that wraps VibeVoiceASRInference for real-time use.

    Features:
    - Async interface for WebSocket integration
    - Streaming output via TextIteratorStreamer
    - Request queuing for handling concurrent segments
    - Graceful model loading and error handling
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-ASR",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the ASR worker.

        Args:
            model_path: Path to VibeVoice ASR model
            device: Device to run inference on
            dtype: Model data type
            attn_implementation: Attention implementation
        """
        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.attn_implementation = attn_implementation

        # Lazily-created VibeVoiceASRInference instance (None until load_model).
        self._inference = None
        self._is_loaded = False
        # Guards load_model() so concurrent callers load the model only once.
        self._load_lock = threading.Lock()

        # Inference queue for serializing requests: at most one GPU inference
        # in flight at a time.
        self._inference_semaphore = asyncio.Semaphore(1)

    def load_model(self) -> bool:
        """
        Load the ASR model.

        Returns:
            True if model loaded successfully
        """
        with self._load_lock:
            if self._is_loaded:
                return True

            try:
                # Import here to avoid circular imports and allow lazy loading
                # In Docker, the file is copied as vibevoice_asr_gradio_demo.py
                try:
                    from vibevoice_asr_gradio_demo import VibeVoiceASRInference
                except ImportError:
                    from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference

                print(f"Loading VibeVoice ASR model from {self.model_path}...")
                self._inference = VibeVoiceASRInference(
                    model_path=self.model_path,
                    device=self.device,
                    dtype=self.dtype,
                    attn_implementation=self.attn_implementation,
                )
                self._is_loaded = True
                print("ASR model loaded successfully")
                return True

            except Exception as e:
                # Load failures are reported, not raised: callers get False and
                # transcribe_segment() returns an ERROR result instead.
                print(f"Failed to load ASR model: {e}")
                import traceback
                traceback.print_exc()
                return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    async def transcribe_segment(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
        on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
    ) -> TranscriptionResult:
        """
        Transcribe an audio segment asynchronously.

        Args:
            audio: Audio data as float32 array
            sample_rate: Audio sample rate
            context_info: Optional context for transcription
            segment_start_sec: Start time of segment in session
            segment_end_sec: End time of segment in session
            config: Session configuration
            on_partial: Callback for partial results

        Returns:
            Final transcription result
        """
        if not self._is_loaded:
            if not self.load_model():
                # Model could not be loaded; surface an ERROR-typed result
                # rather than raising across the WebSocket boundary.
                return TranscriptionResult(
                    type=MessageType.ERROR,
                    text="",
                    is_final=True,
                    latency_ms=0,
                )

        config = config or SessionConfig()
        request_time = time.time()

        # Serialize inference requests (one segment on the GPU at a time).
        async with self._inference_semaphore:
            return await self._run_inference(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                request_time=request_time,
                on_partial=on_partial,
            )

    async def _run_inference(
        self,
        audio: np.ndarray,
        sample_rate: int,
        context_info: Optional[str],
        segment_start_sec: float,
        segment_end_sec: float,
        config: SessionConfig,
        request_time: float,
        on_partial: Optional[Callable[[TranscriptionResult], None]],
    ) -> TranscriptionResult:
        """Run the actual inference in a thread pool."""
        from transformers import TextIteratorStreamer

        # Create streamer for partial results
        streamer = None
        if config.return_partial_results and on_partial:
            streamer = TextIteratorStreamer(
                self._inference.processor.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

        # Result container for thread (closed over by run_inference below).
        result_container = {"result": None, "error": None}

        def run_inference():
            try:
                # Save audio to temp file (required by current implementation)
                import tempfile
                import soundfile as sf

                # delete=False: the file must outlive this `with` so
                # transcribe() can read it; we unlink it ourselves below.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    temp_path = f.name

                # Write audio (float32 [-1, 1] -> int16 PCM, with clipping).
                audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
                sf.write(temp_path, audio_int16, sample_rate, subtype='PCM_16')

                try:
                    result = self._inference.transcribe(
                        audio_path=temp_path,
                        max_new_tokens=config.max_new_tokens,
                        temperature=config.temperature,
                        context_info=context_info,
                        streamer=streamer,
                    )
                    result_container["result"] = result
                finally:
                    # Clean up temp file
                    # NOTE(review): bare except also swallows KeyboardInterrupt;
                    # `except OSError:` would be safer here.
                    try:
                        os.unlink(temp_path)
                    except:
                        pass

            except Exception as e:
                result_container["error"] = str(e)
                import traceback
                traceback.print_exc()

        # Start inference in background thread
        inference_thread = threading.Thread(target=run_inference)
        inference_thread.start()

        # Stream partial results if enabled.
        # NOTE(review): iterating `streamer` blocks this coroutine (and the
        # event loop) between tokens; acceptable only because inference is
        # serialized — confirm if other coroutines must stay responsive.
        partial_text = ""
        if streamer and on_partial:
            try:
                for new_text in streamer:
                    partial_text += new_text
                    partial_result = TranscriptionResult(
                        type=MessageType.PARTIAL_RESULT,
                        text=partial_text,
                        is_final=False,
                        latency_ms=(time.time() - request_time) * 1000,
                    )
                    # Call callback (may be async)
                    if asyncio.iscoroutinefunction(on_partial):
                        await on_partial(partial_result)
                    else:
                        on_partial(partial_result)
            except Exception as e:
                print(f"Error during streaming: {e}")

        # Wait for completion.
        # NOTE(review): join() blocks the event loop when streaming is off;
        # presumably fine for this single-worker setup — verify under load.
        inference_thread.join()

        latency_ms = (time.time() - request_time) * 1000

        if result_container["error"]:
            return TranscriptionResult(
                type=MessageType.ERROR,
                text=f"Error: {result_container['error']}",
                is_final=True,
                latency_ms=latency_ms,
            )

        result = result_container["result"]

        # Convert segments to our format
        segments = []
        for seg in result.get("segments", []):
            # Adjust timestamps relative to session
            seg_start = seg.get("start_time", 0)
            seg_end = seg.get("end_time", 0)

            # If segment has relative timestamps, adjust to absolute
            if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
                adjusted_start = segment_start_sec + seg_start
                adjusted_end = segment_start_sec + seg_end
            else:
                # Non-numeric timestamps: fall back to the whole segment span.
                adjusted_start = segment_start_sec
                adjusted_end = segment_end_sec

            segments.append(TranscriptionSegment(
                start_time=adjusted_start,
                end_time=adjusted_end,
                speaker_id=seg.get("speaker_id", "SPEAKER_00"),
                text=seg.get("text", ""),
            ))

        return TranscriptionResult(
            type=MessageType.FINAL_RESULT,
            text=result.get("raw_text", ""),
            is_final=True,
            segments=segments,
            latency_ms=latency_ms,
        )

    async def transcribe_stream(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
    ) -> AsyncGenerator[TranscriptionResult, None]:
        """
        Transcribe an audio segment with streaming output.

        Yields partial results followed by final result.

        Args:
            audio: Audio data
            sample_rate: Sample rate
            context_info: Optional context
            segment_start_sec: Segment start time
            segment_end_sec: Segment end time
            config: Session config

        Yields:
            TranscriptionResult objects (partial and final)
        """
        result_queue: asyncio.Queue = asyncio.Queue()

        async def on_partial(result: TranscriptionResult):
            await result_queue.put(result)

        # Start transcription task
        transcribe_task = asyncio.create_task(
            self.transcribe_segment(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                on_partial=on_partial,
            )
        )

        # Yield partial results as they come (poll with a short timeout so we
        # also notice task completion promptly).
        while not transcribe_task.done():
            try:
                result = await asyncio.wait_for(result_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue

        # Drain any remaining partial results
        while not result_queue.empty():
            yield await result_queue.get()

        # Yield final result
        final_result = await transcribe_task
        yield final_result
|
||||||
246
static/scripts/vibevoice-asr/realtime/audio_buffer.py
Normal file
246
static/scripts/vibevoice-asr/realtime/audio_buffer.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
"""
|
||||||
|
Audio buffer management for real-time ASR.
|
||||||
|
|
||||||
|
Implements a ring buffer for efficient audio chunk management with overlap support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AudioChunkInfo:
    """Information about an extracted audio chunk.

    Sample indices and times are absolute, i.e. relative to session start,
    not to the current buffer contents.
    """
    # The extracted samples (a copy, safe to use after the buffer advances).
    audio: np.ndarray
    # First sample index of this chunk since session start.
    start_sample: int
    # One past the last sample index since session start.
    end_sample: int
    # start_sample expressed in seconds.
    start_sec: float
    # end_sample expressed in seconds.
    end_sec: float
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBuffer:
    """
    Ring buffer for managing audio chunks with overlap support.

    Features:
    - Efficient memory management with fixed-size buffer
    - Overlap handling for continuous processing
    - Thread-safe operations
    - Automatic sample rate tracking

    Invariants (all under `_lock`):
    - 0 <= _read_pos <= _write_pos <= max_buffer_size
    - buffer index i corresponds to absolute sample
      (_total_samples_received - _write_pos + i)
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        chunk_duration_sec: float = 3.0,
        overlap_sec: float = 0.5,
        max_buffer_sec: float = 60.0,
    ):
        """
        Initialize the audio buffer.

        Args:
            sample_rate: Audio sample rate in Hz
            chunk_duration_sec: Duration of each processing chunk
            overlap_sec: Overlap between consecutive chunks
            max_buffer_sec: Maximum buffer duration (older data will be discarded)
        """
        self.sample_rate = sample_rate
        self.chunk_size = int(chunk_duration_sec * sample_rate)
        self.overlap_size = int(overlap_sec * sample_rate)
        self.max_buffer_size = int(max_buffer_sec * sample_rate)

        # Main buffer (pre-allocated)
        self._buffer = np.zeros(self.max_buffer_size, dtype=np.float32)
        self._write_pos = 0  # Next position to write
        self._read_pos = 0  # Position of unprocessed data start
        self._total_samples_received = 0  # Total samples since session start

        self._lock = threading.Lock()

    @property
    def samples_available(self) -> int:
        """Number of unprocessed samples in buffer."""
        with self._lock:
            return self._write_pos - self._read_pos

    @property
    def duration_available_sec(self) -> float:
        """Duration of unprocessed audio in seconds."""
        return self.samples_available / self.sample_rate

    @property
    def total_duration_sec(self) -> float:
        """Total duration of audio received since session start."""
        return self._total_samples_received / self.sample_rate

    def append(self, audio_chunk: np.ndarray) -> int:
        """
        Append audio chunk to the buffer.

        Args:
            audio_chunk: Audio data as float32 array (range: -1.0 to 1.0)

        Returns:
            Number of samples actually appended
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)

        # Ensure 1D
        if audio_chunk.ndim > 1:
            audio_chunk = audio_chunk.flatten()

        with self._lock:
            chunk_len = len(audio_chunk)

            # Check if we need to shift buffer (running out of space)
            if self._write_pos + chunk_len > self.max_buffer_size:
                self._compact_buffer()

            # Still not enough space? Discard old unprocessed data
            # by advancing _read_pos, then compact again to free room.
            if self._write_pos + chunk_len > self.max_buffer_size:
                overflow = (self._write_pos + chunk_len) - self.max_buffer_size
                self._read_pos = min(self._read_pos + overflow, self._write_pos)
                self._compact_buffer()

            # Write to buffer
            # NOTE(review): if chunk_len > max_buffer_size this slice
            # assignment would fail; presumably chunks are far smaller than
            # the 60 s buffer — confirm against the WebSocket chunk size.
            end_pos = self._write_pos + chunk_len
            self._buffer[self._write_pos:end_pos] = audio_chunk
            self._write_pos = end_pos
            self._total_samples_received += chunk_len

            return chunk_len

    def _compact_buffer(self) -> None:
        """Move unprocessed data to the beginning of the buffer.

        Caller must hold `_lock`.
        """
        if self._read_pos > 0:
            unprocessed_len = self._write_pos - self._read_pos
            if unprocessed_len > 0:
                self._buffer[:unprocessed_len] = self._buffer[self._read_pos:self._write_pos]
            self._write_pos = unprocessed_len
            self._read_pos = 0

    def get_chunk_for_inference(self, min_duration_sec: float = 0.5) -> Optional[AudioChunkInfo]:
        """
        Get the next chunk for ASR inference.

        Returns a chunk of audio when enough data is available.
        The chunk includes overlap from the previous chunk for context.

        Args:
            min_duration_sec: Minimum duration required to return a chunk

        Returns:
            AudioChunkInfo if enough data is available, None otherwise
        """
        min_samples = int(min_duration_sec * self.sample_rate)

        with self._lock:
            available = self._write_pos - self._read_pos

            if available < min_samples:
                return None

            # Calculate chunk boundaries (at most chunk_size samples).
            chunk_start = self._read_pos
            chunk_end = min(self._read_pos + self.chunk_size, self._write_pos)
            actual_chunk_size = chunk_end - chunk_start

            # Extract audio (copy so the caller is isolated from later writes).
            audio = self._buffer[chunk_start:chunk_end].copy()

            # Calculate timestamps based on total samples received:
            # buffer index i maps to absolute sample total - write_pos + i.
            base_sample = self._total_samples_received - (self._write_pos - chunk_start)
            start_sec = base_sample / self.sample_rate
            end_sec = (base_sample + actual_chunk_size) / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=base_sample + actual_chunk_size,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def mark_processed(self, samples: int) -> None:
        """
        Mark samples as processed, advancing the read position.

        Keeps overlap_size samples for context in the next chunk.

        Args:
            samples: Number of samples that were processed
        """
        with self._lock:
            # Advance read position but keep overlap for context
            advance = max(0, samples - self.overlap_size)
            self._read_pos = min(self._read_pos + advance, self._write_pos)

    def get_segment(self, start_sec: float, end_sec: float) -> Optional[np.ndarray]:
        """
        Get a specific time segment from the buffer.

        Args:
            start_sec: Start time in seconds (relative to session start)
            end_sec: End time in seconds

        Returns:
            Audio segment if available, None otherwise (segment has already
            been discarded from the buffer, or lies in the future)
        """
        start_sample = int(start_sec * self.sample_rate)
        end_sample = int(end_sec * self.sample_rate)

        with self._lock:
            # Calculate buffer positions: buffer[0] holds absolute sample
            # (total_received - write_pos).
            buffer_start_sample = self._total_samples_received - self._write_pos
            buffer_end_sample = self._total_samples_received

            # Check if segment is in buffer
            if start_sample < buffer_start_sample or end_sample > buffer_end_sample:
                return None

            # Convert to buffer indices
            buf_start = start_sample - buffer_start_sample
            buf_end = end_sample - buffer_start_sample

            return self._buffer[buf_start:buf_end].copy()

    def get_all_unprocessed(self) -> Optional[AudioChunkInfo]:
        """
        Get all unprocessed audio.

        Returns:
            AudioChunkInfo with all unprocessed audio, or None if empty
        """
        with self._lock:
            if self._write_pos <= self._read_pos:
                return None

            audio = self._buffer[self._read_pos:self._write_pos].copy()
            base_sample = self._total_samples_received - (self._write_pos - self._read_pos)
            start_sec = base_sample / self.sample_rate
            end_sec = self._total_samples_received / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=self._total_samples_received,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def clear(self) -> None:
        """Clear the buffer and reset all positions."""
        with self._lock:
            self._buffer.fill(0)
            self._write_pos = 0
            self._read_pos = 0
            self._total_samples_received = 0

    def reset_read_position(self) -> None:
        """Reset read position to current write position (skip all unprocessed)."""
        with self._lock:
            self._read_pos = self._write_pos
|
||||||
154
static/scripts/vibevoice-asr/realtime/models.py
Normal file
154
static/scripts/vibevoice-asr/realtime/models.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Data models for real-time ASR WebSocket communication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
    """WebSocket message types.

    Subclasses str so values serialize directly as JSON strings.
    """
    # Client -> Server
    AUDIO_CHUNK = "audio_chunk"
    CONFIG = "config"
    START = "start"
    STOP = "stop"

    # Server -> Client
    PARTIAL_RESULT = "partial_result"
    FINAL_RESULT = "final_result"
    VAD_EVENT = "vad_event"
    ERROR = "error"
    STATUS = "status"
|
||||||
|
|
||||||
|
|
||||||
|
class VADEventType(str, Enum):
    """VAD (voice activity detection) event types."""
    SPEECH_START = "speech_start"
    SPEECH_END = "speech_end"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SessionConfig:
    """Configuration for a real-time ASR session.

    All fields have defaults, so ``SessionConfig()`` is a fully usable
    configuration; ``from_dict`` tolerates unknown keys.
    """
    # Audio parameters
    sample_rate: int = 16000
    chunk_duration_sec: float = 3.0
    overlap_sec: float = 0.5

    # VAD parameters
    vad_threshold: float = 0.5
    min_speech_duration_ms: int = 250
    min_silence_duration_ms: int = 500
    min_volume_threshold: float = 0.01  # Minimum RMS volume (0.0-1.0) to consider as potential speech

    # ASR parameters
    max_new_tokens: int = 512
    temperature: float = 0.0
    context_info: Optional[str] = None

    # Behavior
    return_partial_results: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this config to a plain dict."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionConfig":
        """Build a config from a dict, silently ignoring unknown keys."""
        accepted = {}
        for key, value in data.items():
            if key in cls.__dataclass_fields__:
                accepted[key] = value
        return cls(**accepted)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TranscriptionSegment:
    """A single transcription segment with metadata."""
    # Segment start, seconds since session start.
    start_time: float
    # Segment end, seconds since session start.
    end_time: float
    # Speaker label for this segment (e.g. "SPEAKER_00").
    speaker_id: str
    # Transcribed text.
    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for JSON transport."""
        return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TranscriptionResult:
    """Transcription result message (partial, final, or error)."""
    type: MessageType
    text: str
    is_final: bool
    segments: List[TranscriptionSegment] = field(default_factory=list)
    latency_ms: float = 0.0
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict (enum flattened to its value)."""
        serialized_segments = [segment.to_dict() for segment in self.segments]
        return dict(
            type=self.type.value,
            text=self.text,
            is_final=self.is_final,
            segments=serialized_segments,
            latency_ms=self.latency_ms,
            timestamp=self.timestamp,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VADEvent:
    """VAD event message (speech start/end notification)."""
    type: MessageType = MessageType.VAD_EVENT
    event: VADEventType = VADEventType.SPEECH_START
    timestamp: float = field(default_factory=time.time)
    audio_timestamp_sec: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict (enums flattened to their values)."""
        return dict(
            type=self.type.value,
            event=self.event.value,
            timestamp=self.timestamp,
            audio_timestamp_sec=self.audio_timestamp_sec,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StatusMessage:
    """Status message sent from server to client."""
    type: MessageType = MessageType.STATUS
    status: str = ""
    message: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict."""
        return dict(
            type=self.type.value,
            status=self.status,
            message=self.message,
            timestamp=self.timestamp,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ErrorMessage:
    """Error message sent from server to client."""
    type: MessageType = MessageType.ERROR
    error: str = ""
    code: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict."""
        return dict(
            type=self.type.value,
            error=self.error,
            code=self.code,
            timestamp=self.timestamp,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SpeechSegment:
    """Detected speech segment from VAD."""
    # First sample of the detected speech (absolute, session-relative).
    start_sample: int
    # One past the last sample of the detected speech.
    end_sample: int
    # start_sample in seconds.
    start_sec: float
    # end_sample in seconds.
    end_sec: float
    # Detection confidence in [0, 1]; defaults to fully confident.
    confidence: float = 1.0
|
||||||
300
static/scripts/vibevoice-asr/realtime/server.py
Normal file
300
static/scripts/vibevoice-asr/realtime/server.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
"""
|
||||||
|
FastAPI WebSocket server for real-time ASR.
|
||||||
|
|
||||||
|
Provides WebSocket endpoint for streaming audio and receiving transcriptions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
SessionConfig,
|
||||||
|
TranscriptionResult,
|
||||||
|
VADEvent,
|
||||||
|
StatusMessage,
|
||||||
|
ErrorMessage,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
|
from .asr_worker import ASRWorker
|
||||||
|
from .session_manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
# Global instances, populated by the FastAPI lifespan on startup and shared
# by every request handler in this module.
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Runs once at startup (before ``yield``): builds the shared ASR worker and
    session manager from environment variables and optionally pre-loads the
    model. Runs once at shutdown (after ``yield``): stops the session manager.
    """
    global asr_worker, session_manager

    # Startup
    print("Starting VibeVoice Realtime ASR Server...")

    # Get model path from environment or use default. ``main()`` forwards CLI
    # arguments through these same variables.
    model_path = os.environ.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = os.environ.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = os.environ.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")

    # Initialize ASR worker (one shared instance for all sessions).
    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )

    # Pre-load model (optional, can be lazy-loaded on first request).
    preload = os.environ.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true"
    if preload:
        print("Pre-loading ASR model...")
        asr_worker.load_model()

    # Initialize session manager with a concurrency cap.
    max_sessions = int(os.environ.get("VIBEVOICE_MAX_SESSIONS", "10"))
    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=max_sessions,
    )
    await session_manager.start()

    print("Server ready!")

    yield

    # Shutdown: cancel the cleanup task and close all active sessions.
    print("Shutting down...")
    await session_manager.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# Create FastAPI app; ``lifespan`` wires up the ASR worker / session manager.
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)

# Mount static files (browser client) if the sibling ``static`` directory
# exists next to this package; skipped silently otherwise.
static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint with API info."""
|
||||||
|
return {
|
||||||
|
"service": "VibeVoice Realtime ASR",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"endpoints": {
|
||||||
|
"websocket": "/ws/asr/{session_id}",
|
||||||
|
"health": "/health",
|
||||||
|
"stats": "/stats",
|
||||||
|
"client": "/static/realtime_client.html",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"model_loaded": asr_worker.is_loaded if asr_worker else False,
|
||||||
|
"active_sessions": len(session_manager._sessions) if session_manager else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
|
||||||
|
async def get_stats():
|
||||||
|
"""Get server statistics."""
|
||||||
|
if session_manager is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||||
|
|
||||||
|
return session_manager.get_stats()
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/asr/{session_id}")
async def websocket_asr(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for real-time ASR.

    Protocol:
    1. Client connects and optionally sends config message
    2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
    3. Server sends JSON messages with transcription results

    Message types (server -> client):
    - partial_result: Intermediate transcription
    - final_result: Complete transcription for a segment
    - vad_event: Speech start/end events
    - error: Error messages
    - status: Status updates
    """
    await websocket.accept()

    # Send connection confirmation before the session is actually created.
    await websocket.send_json(
        StatusMessage(
            status="connected",
            message=f"Session {session_id} connected",
        ).to_dict()
    )

    # Result callback: forwards transcription results to this client.
    # Send failures are logged, not raised, so one bad send doesn't kill
    # the processing pipeline.
    async def on_result(result: TranscriptionResult):
        try:
            await websocket.send_json(result.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send result: {e}")

    # VAD event callback: forwards speech start/end events to this client.
    async def on_vad_event(event: VADEvent):
        try:
            await websocket.send_json(event.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send VAD event: {e}")

    # Create session; returns None when the manager's session cap is reached.
    session = await session_manager.create_session(
        session_id=session_id,
        on_result=on_result,
        on_vad_event=on_vad_event,
    )

    if session is None:
        await websocket.send_json(
            ErrorMessage(
                error="Maximum sessions reached",
                code="MAX_SESSIONS",
            ).to_dict()
        )
        await websocket.close()
        return

    await websocket.send_json(
        StatusMessage(
            status="ready",
            message="Session ready for audio",
        ).to_dict()
    )

    try:
        while True:
            # Receive raw ASGI message so binary and text frames can be
            # distinguished in one loop.
            message = await websocket.receive()

            if message["type"] == "websocket.disconnect":
                break

            # Handle binary audio data (PCM 16-bit, 16kHz, mono).
            if "bytes" in message:
                audio_data = message["bytes"]
                try:
                    await session.process_audio_chunk(audio_data)
                except Exception as e:
                    # Keep the connection alive on per-chunk failures.
                    print(f"[{session_id}] Error processing audio: {e}")
                    import traceback
                    traceback.print_exc()

            # Handle JSON control messages: config / stop / ping.
            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    msg_type = data.get("type")

                    if msg_type == "config":
                        # Update session config (partial updates supported by
                        # RealtimeSession.update_config).
                        config = SessionConfig.from_dict(data.get("config", {}))
                        session.update_config(config)
                        await websocket.send_json(
                            StatusMessage(
                                status="config_updated",
                                message="Configuration updated",
                            ).to_dict()
                        )

                    elif msg_type == "stop":
                        # Flush any buffered speech, acknowledge, then leave
                        # the receive loop (connection cleanup in finally).
                        await session.flush()
                        await websocket.send_json(
                            StatusMessage(
                                status="stopped",
                                message="Session stopped",
                            ).to_dict()
                        )
                        break

                    elif msg_type == "ping":
                        await websocket.send_json({"type": "pong", "timestamp": time.time()})

                except json.JSONDecodeError:
                    await websocket.send_json(
                        ErrorMessage(
                            error="Invalid JSON",
                            code="INVALID_JSON",
                        ).to_dict()
                    )

    except WebSocketDisconnect:
        print(f"[{session_id}] Client disconnected")
    except Exception as e:
        print(f"[{session_id}] Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up session regardless of how the loop exited; this also
        # flushes remaining audio (see SessionManager.close_session).
        await session_manager.close_session(session_id)
        print(f"[{session_id}] Session closed")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI arguments, export them to the environment, and run the server."""
    import argparse

    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument(
        "--model-path",
        type=str,
        default="microsoft/VibeVoice-ASR",
        help="Path to VibeVoice ASR model",
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--max-sessions", type=int, default=10, help="Max concurrent sessions")
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")
    args = parser.parse_args()

    # The FastAPI lifespan reads its settings from the environment, so the
    # CLI values are forwarded through environment variables.
    os.environ.update(
        {
            "VIBEVOICE_MODEL_PATH": args.model_path,
            "VIBEVOICE_DEVICE": args.device,
            "VIBEVOICE_MAX_SESSIONS": str(args.max_sessions),
            "VIBEVOICE_PRELOAD_MODEL": "false" if args.no_preload else "true",
        }
    )

    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
401
static/scripts/vibevoice-asr/realtime/session_manager.py
Normal file
401
static/scripts/vibevoice-asr/realtime/session_manager.py
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
"""
|
||||||
|
Session manager for real-time ASR.
|
||||||
|
|
||||||
|
Manages multiple concurrent client sessions with resource isolation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional, Callable, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
SessionConfig,
|
||||||
|
TranscriptionResult,
|
||||||
|
VADEvent,
|
||||||
|
StatusMessage,
|
||||||
|
ErrorMessage,
|
||||||
|
MessageType,
|
||||||
|
SpeechSegment,
|
||||||
|
)
|
||||||
|
from .audio_buffer import AudioBuffer
|
||||||
|
from .vad_processor import VADProcessor
|
||||||
|
from .asr_worker import ASRWorker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SessionStats:
    """Statistics for a session."""
    # Wall-clock timestamps (epoch seconds).
    created_at: float = field(default_factory=time.time)
    last_activity: float = field(default_factory=time.time)
    # Running counters updated by RealtimeSession as audio is processed.
    audio_received_sec: float = 0.0
    chunks_received: int = 0
    segments_transcribed: int = 0
    # Sum of per-segment latencies; averaged in RealtimeSession.get_stats().
    total_latency_ms: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeSession:
    """
    A single real-time ASR session.

    Manages audio buffering, VAD, and ASR for one client connection.
    """

    def __init__(
        self,
        session_id: str,
        asr_worker: ASRWorker,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ):
        """
        Initialize a session.

        Args:
            session_id: Unique session identifier
            asr_worker: Shared ASR worker instance
            config: Session configuration
            on_result: Callback for transcription results
            on_vad_event: Callback for VAD events
        """
        self.session_id = session_id
        self.asr_worker = asr_worker
        self.config = config or SessionConfig()
        self.on_result = on_result
        self.on_vad_event = on_vad_event

        # Components: rolling audio store and speech detector, both built
        # from this session's config.
        self.audio_buffer = AudioBuffer(
            sample_rate=self.config.sample_rate,
            chunk_duration_sec=self.config.chunk_duration_sec,
            overlap_sec=self.config.overlap_sec,
        )

        self.vad_processor = VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )

        # State
        self.is_active = True
        self.stats = SessionStats()
        self._processing_lock = asyncio.Lock()
        # NOTE(review): _pending_tasks is never read or written elsewhere in
        # this class — looks like dead state; confirm before removing.
        self._pending_tasks: list = []

    async def process_audio_chunk(self, audio_data: bytes) -> None:
        """
        Process an incoming audio chunk.

        Args:
            audio_data: Raw PCM audio data (16-bit, 16kHz, mono)
        """
        if not self.is_active:
            return

        self.stats.last_activity = time.time()
        self.stats.chunks_received += 1

        # Convert bytes to float32 array in [-1, 1).
        import numpy as np
        audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate

        # Add to buffer so completed segments can be retrieved later by time range.
        self.audio_buffer.append(audio_float)

        # Process with VAD: yields completed speech segments plus start/end events.
        segments, events = self.vad_processor.process(audio_float)

        # Send VAD events to the client callback, if registered.
        if self.on_vad_event:
            for event in events:
                await self._send_callback(self.on_vad_event, event)

        # Transcribe each completed speech segment; per-segment errors are
        # logged so one bad segment doesn't abort the chunk.
        for segment in segments:
            try:
                await self._transcribe_segment(segment)
            except Exception as e:
                print(f"[Session {self.session_id}] Transcription error: {e}")
                import traceback
                traceback.print_exc()

    async def _transcribe_segment(self, segment: SpeechSegment) -> None:
        """Transcribe a detected speech segment."""
        # Get audio for segment from buffer (by time range).
        audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec)

        if audio is None or len(audio) == 0:
            # Segment may have scrolled out of the rolling buffer.
            print(f"[Session {self.session_id}] Could not retrieve audio for segment")
            return

        self.stats.segments_transcribed += 1

        # Forward intermediate (partial) results through the same client callback.
        async def on_partial(result: TranscriptionResult):
            if self.on_result:
                await self._send_callback(self.on_result, result)

        # Run transcription on the shared worker.
        result = await self.asr_worker.transcribe_segment(
            audio=audio,
            sample_rate=self.config.sample_rate,
            context_info=self.config.context_info,
            segment_start_sec=segment.start_sec,
            segment_end_sec=segment.end_sec,
            config=self.config,
            on_partial=on_partial if self.config.return_partial_results else None,
        )

        self.stats.total_latency_ms += result.latency_ms

        # Send final result.
        if self.on_result:
            await self._send_callback(self.on_result, result)

    async def _send_callback(self, callback: Callable, data: Any) -> None:
        """Send data via callback, handling both sync and async callables."""
        try:
            if asyncio.iscoroutinefunction(callback):
                await callback(data)
            else:
                callback(data)
        except Exception as e:
            # Callback failures (e.g. a closed websocket) must not break processing.
            print(f"[Session {self.session_id}] Callback error: {e}")

    async def flush(self) -> None:
        """
        Flush any remaining audio and force transcription.

        Called when session ends to process any remaining speech.
        """
        # Force end any active speech tracked by the VAD.
        segment = self.vad_processor.force_end_speech()
        if segment:
            await self._transcribe_segment(segment)

        # Also check for any unprocessed audio in buffer.
        chunk_info = self.audio_buffer.get_all_unprocessed()
        if chunk_info and len(chunk_info.audio) > self.config.sample_rate * 0.5:
            # More than 0.5 seconds of unprocessed audio: synthesize a
            # segment covering it and transcribe.
            forced_segment = SpeechSegment(
                start_sample=chunk_info.start_sample,
                end_sample=chunk_info.end_sample,
                start_sec=chunk_info.start_sec,
                end_sec=chunk_info.end_sec,
            )
            await self._transcribe_segment(forced_segment)

    def update_config(self, new_config: SessionConfig) -> None:
        """Update session configuration (partial update supported)."""
        # Merge with existing config - only update non-default values.
        # NOTE(review): this compares against hard-coded defaults, so a client
        # explicitly sending the default value (e.g. vad_threshold=0.5) is
        # silently ignored, and these literals must stay in sync with
        # SessionConfig's defaults — fragile; verify against SessionConfig.
        if new_config.vad_threshold != 0.5:
            self.config.vad_threshold = new_config.vad_threshold
        if new_config.min_speech_duration_ms != 250:
            self.config.min_speech_duration_ms = new_config.min_speech_duration_ms
        if new_config.min_silence_duration_ms != 500:
            self.config.min_silence_duration_ms = new_config.min_silence_duration_ms
        if new_config.min_volume_threshold != 0.01:
            self.config.min_volume_threshold = new_config.min_volume_threshold
        if new_config.context_info is not None:
            self.config.context_info = new_config.context_info

        # Recreate VAD processor with new parameters (this resets VAD state,
        # including any in-progress speech detection).
        self.vad_processor = VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )
        print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, "
              f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, "
              f"min_volume={self.config.min_volume_threshold}")

    def close(self) -> None:
        """Close the session and release resources."""
        self.is_active = False
        self.audio_buffer.clear()
        self.vad_processor.reset()

    def get_stats(self) -> Dict:
        """Get session statistics as a JSON-compatible dict."""
        return {
            "session_id": self.session_id,
            "created_at": self.stats.created_at,
            "last_activity": self.stats.last_activity,
            "duration_sec": time.time() - self.stats.created_at,
            "audio_received_sec": self.stats.audio_received_sec,
            "chunks_received": self.stats.chunks_received,
            "segments_transcribed": self.stats.segments_transcribed,
            # Average over transcribed segments; 0 before the first segment.
            "avg_latency_ms": (
                self.stats.total_latency_ms / self.stats.segments_transcribed
                if self.stats.segments_transcribed > 0 else 0
            ),
            "is_active": self.is_active,
            "vad_speech_active": self.vad_processor.is_speech_active,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
    """
    Manages multiple concurrent ASR sessions.

    Features:
    - Session creation and cleanup
    - Resource limiting (max concurrent sessions)
    - Idle session timeout
    - Shared ASR worker management
    """

    def __init__(
        self,
        asr_worker: ASRWorker,
        max_concurrent_sessions: int = 10,
        session_timeout_sec: float = 300.0,
    ):
        """
        Initialize the session manager.

        Args:
            asr_worker: Shared ASR worker
            max_concurrent_sessions: Maximum number of concurrent sessions
            session_timeout_sec: Timeout for idle sessions
        """
        self.asr_worker = asr_worker
        self.max_sessions = max_concurrent_sessions
        self.session_timeout = session_timeout_sec

        # Session registry, guarded by an asyncio lock.
        self._sessions: Dict[str, RealtimeSession] = {}
        self._lock = asyncio.Lock()

        # Background idle-session cleanup task (started in start()).
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the session manager's background cleanup task."""
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Stop the session manager and close all sessions."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected when cancelling our own task.
                pass

        # Close sessions directly (no flush) since the server is going down.
        async with self._lock:
            for session in self._sessions.values():
                session.close()
            self._sessions.clear()

    async def create_session(
        self,
        session_id: Optional[str] = None,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ) -> Optional[RealtimeSession]:
        """
        Create a new session.

        Args:
            session_id: Optional session ID (generated if not provided)
            config: Session configuration
            on_result: Callback for results
            on_vad_event: Callback for VAD events

        Returns:
            Created session, or None if limit reached
        """
        async with self._lock:
            # Check session limit.
            if len(self._sessions) >= self.max_sessions:
                return None

            # Generate a short session ID if not provided.
            if session_id is None:
                session_id = str(uuid.uuid4())[:8]

            # Duplicate IDs return the existing session rather than erroring.
            if session_id in self._sessions:
                return self._sessions[session_id]

            # Create session sharing this manager's ASR worker.
            session = RealtimeSession(
                session_id=session_id,
                asr_worker=self.asr_worker,
                config=config,
                on_result=on_result,
                on_vad_event=on_vad_event,
            )

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[RealtimeSession]:
        """Get a session by ID, or None if it does not exist."""
        async with self._lock:
            return self._sessions.get(session_id)

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        Flushes remaining audio (which may trigger final transcriptions)
        before releasing the session's resources.

        Args:
            session_id: Session to close

        Returns:
            True if session was found and closed
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                # NOTE: flush() runs while holding the manager lock, so other
                # session operations wait for the final transcription.
                await session.flush()
                session.close()
                return True
            return False

    async def _cleanup_loop(self) -> None:
        """Background task to clean up idle sessions."""
        while True:
            try:
                await asyncio.sleep(60)  # Check every minute

                current_time = time.time()
                sessions_to_close = []

                # Collect idle IDs under the lock, then close outside it
                # (close_session re-acquires the lock).
                async with self._lock:
                    for session_id, session in self._sessions.items():
                        idle_time = current_time - session.stats.last_activity
                        if idle_time > self.session_timeout:
                            sessions_to_close.append(session_id)

                for session_id in sessions_to_close:
                    print(f"Closing idle session: {session_id}")
                    await self.close_session(session_id)

            except asyncio.CancelledError:
                # Normal shutdown path (cancelled by stop()).
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors.
                print(f"Cleanup error: {e}")

    def get_stats(self) -> Dict:
        """Get manager statistics, including per-session stats."""
        return {
            "active_sessions": len(self._sessions),
            "max_sessions": self.max_sessions,
            "session_timeout_sec": self.session_timeout,
            "sessions": {
                sid: session.get_stats()
                for sid, session in self._sessions.items()
            },
        }
|
||||||
295
static/scripts/vibevoice-asr/realtime/vad_processor.py
Normal file
295
static/scripts/vibevoice-asr/realtime/vad_processor.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
"""
|
||||||
|
Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version).
|
||||||
|
|
||||||
|
Detects speech segments in real-time audio streams.
|
||||||
|
Uses ONNX runtime to avoid torchaudio dependency issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import threading
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
from .models import SpeechSegment, VADEvent, VADEventType, MessageType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VADState:
    """Internal state of the VAD processor."""
    # Whether a confirmed speech segment is currently open.
    is_speech_active: bool = False
    # Absolute sample offsets (since reset) of the last speech / silence onsets.
    speech_start_sample: int = 0
    silence_start_sample: int = 0
    # Most recent per-window speech probability from the model.
    last_speech_prob: float = 0.0
    # Total samples seen since reset; used to place windows on the timeline.
    total_samples_processed: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class VADProcessor:
|
||||||
|
"""
|
||||||
|
Voice Activity Detection using Silero-VAD (ONNX version).
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Real-time speech detection
|
||||||
|
- Configurable thresholds for speech/silence duration
|
||||||
|
- Event generation for speech start/end
|
||||||
|
- Thread-safe operations
|
||||||
|
- No torchaudio dependency (uses ONNX runtime)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Silero VAD ONNX model URL
|
||||||
|
ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"
|
||||||
|
|
||||||
|
    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        window_size_samples: int = 512,
        min_volume_threshold: float = 0.01,
    ):
        """
        Initialize the VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 16000 for Silero-VAD)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger speech_start
            min_silence_duration_ms: Minimum silence duration to trigger speech_end
            window_size_samples: VAD window size (512 for 16kHz = 32ms)
            min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech

        Raises:
            ValueError: if sample_rate is not 16000.
        """
        if sample_rate != 16000:
            raise ValueError("Silero-VAD requires 16kHz sample rate")

        self.sample_rate = sample_rate
        self.threshold = threshold
        # Durations converted from milliseconds to sample counts up front.
        self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000)
        self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)
        self.window_size = window_size_samples
        self.min_volume_threshold = min_volume_threshold

        # Load ONNX model (downloads to ~/.cache/silero-vad on first use).
        self._session = None
        self._load_model()

        # ONNX model state - single state tensor (size depends on model version).
        # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128).
        self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

        # Detection state machine, guarded by a lock for thread safety.
        self._state = VADState()
        self._lock = threading.Lock()

        # Pending speech segment (being accumulated before it is confirmed).
        self._pending_segment_start: Optional[int] = None
|
||||||
|
|
||||||
|
def _get_model_path(self) -> str:
|
||||||
|
"""Get path to ONNX model, downloading if necessary."""
|
||||||
|
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
model_path = os.path.join(cache_dir, "silero_vad.onnx")
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"Downloading Silero-VAD ONNX model to {model_path}...")
|
||||||
|
urllib.request.urlretrieve(self.ONNX_MODEL_URL, model_path)
|
||||||
|
print("Download complete.")
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Load Silero-VAD ONNX model."""
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
model_path = self._get_model_path()
|
||||||
|
self._session = ort.InferenceSession(
|
||||||
|
model_path,
|
||||||
|
providers=['CPUExecutionProvider']
|
||||||
|
)
|
||||||
|
print(f"Silero-VAD ONNX model loaded from {model_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}")
|
||||||
|
|
||||||
|
    def _run_inference(self, audio_window: np.ndarray) -> float:
        """Run VAD inference on a single window.

        Returns the model's speech probability for the window and updates the
        recurrent state tensor carried between calls.
        """
        # Prepare input: (1, window_size) float32 plus the sample rate.
        audio_input = audio_window.reshape(1, -1).astype(np.float32)
        sr_input = np.array([self.sample_rate], dtype=np.int64)

        # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c';
        # 'stateN' is the updated state output.
        outputs = self._session.run(
            ['output', 'stateN'],
            {
                'input': audio_input,
                'sr': sr_input,
                'state': self._state_tensor,
            }
        )

        # Carry the recurrent state forward to the next window.
        speech_prob = outputs[0][0][0]
        self._state_tensor = outputs[1]

        return float(speech_prob)
|
||||||
|
|
||||||
|
    def reset(self) -> None:
        """Reset VAD state for a new session."""
        with self._lock:
            # Fresh detection state machine and no half-open segment.
            self._state = VADState()
            self._pending_segment_start = None
            # Reset the model's recurrent state tensor (Silero VAD v5 shape).
            self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)
|
||||||
|
|
||||||
|
def process(
    self,
    audio_chunk: np.ndarray,
    return_events: bool = True,
) -> Tuple[List[SpeechSegment], List[VADEvent]]:
    """
    Run VAD over one chunk of audio and collect finished speech segments.

    Args:
        audio_chunk: Audio samples (converted to float32 if needed).
        return_events: If True, SPEECH_START/SPEECH_END events are emitted
            alongside the completed segments.

    Returns:
        Tuple of (completed_segments, events).
    """
    if audio_chunk.dtype != np.float32:
        audio_chunk = audio_chunk.astype(np.float32)

    finished: List[SpeechSegment] = []
    emitted: List[VADEvent] = []

    with self._lock:
        base = self._state.total_samples_processed
        win = self.window_size

        # Walk the chunk one full window at a time; a trailing partial
        # window is simply ignored.
        for offset in range(0, (len(audio_chunk) // win) * win, win):
            frame = audio_chunk[offset:offset + win]

            # Volume gate: frames below the RMS floor are forced to
            # "silence" without invoking the model at all.
            if np.sqrt(np.mean(frame ** 2)) < self.min_volume_threshold:
                prob = 0.0
            else:
                prob = self._run_inference(frame)

            self._state.last_speech_prob = prob
            frame_end = base + offset + win  # absolute end sample of this frame

            if prob >= self.threshold:
                if self._state.is_speech_active:
                    # Speech continues: cancel any pending silence run.
                    self._state.silence_start_sample = 0
                else:
                    # Candidate speech start — remember where the burst began.
                    if self._pending_segment_start is None:
                        self._pending_segment_start = frame_end - win
                    # Promote to confirmed speech once it lasts long enough.
                    if frame_end - self._pending_segment_start >= self.min_speech_samples:
                        self._state.is_speech_active = True
                        self._state.speech_start_sample = self._pending_segment_start
                        if return_events:
                            emitted.append(VADEvent(
                                type=MessageType.VAD_EVENT,
                                event=VADEventType.SPEECH_START,
                                audio_timestamp_sec=self._pending_segment_start / self.sample_rate,
                            ))
            else:
                if not self._state.is_speech_active:
                    # Silence while idle: drop an unconfirmed candidate.
                    self._pending_segment_start = None
                else:
                    # Record where silence began. NOTE(review): 0 doubles as
                    # the "not set" sentinel, which would collide with a
                    # silence starting at sample 0 — preserved as-is.
                    if self._state.silence_start_sample == 0:
                        self._state.silence_start_sample = frame_end
                    # Close the segment once silence persists long enough.
                    if frame_end - self._state.silence_start_sample >= self.min_silence_samples:
                        finished.append(SpeechSegment(
                            start_sample=self._state.speech_start_sample,
                            end_sample=self._state.silence_start_sample,
                            start_sec=self._state.speech_start_sample / self.sample_rate,
                            end_sec=self._state.silence_start_sample / self.sample_rate,
                        ))
                        if return_events:
                            emitted.append(VADEvent(
                                type=MessageType.VAD_EVENT,
                                event=VADEventType.SPEECH_END,
                                audio_timestamp_sec=self._state.silence_start_sample / self.sample_rate,
                            ))
                        # Back to the idle state.
                        self._state.is_speech_active = False
                        self._state.speech_start_sample = 0
                        self._state.silence_start_sample = 0
                        self._pending_segment_start = None

        self._state.total_samples_processed += len(audio_chunk)

    return finished, emitted
def force_end_speech(self) -> Optional[SpeechSegment]:
    """
    Close the current speech segment immediately.

    Intended for session teardown, where we cannot wait for trailing
    silence to end the segment naturally.

    Returns:
        The closed segment, or None when no speech was active.
    """
    with self._lock:
        if not self._state.is_speech_active:
            return None

        start = self._state.speech_start_sample
        end = self._state.total_samples_processed
        segment = SpeechSegment(
            start_sample=start,
            end_sample=end,
            start_sec=start / self.sample_rate,
            end_sec=end / self.sample_rate,
        )

        # Return to the idle state.
        self._state.is_speech_active = False
        self._state.speech_start_sample = 0
        self._state.silence_start_sample = 0
        self._pending_segment_start = None

        return segment
@property
def is_speech_active(self) -> bool:
    """Whether a confirmed speech segment is currently in progress."""
    with self._lock:
        active = self._state.is_speech_active
    return active
@property
def last_speech_probability(self) -> float:
    """Speech probability computed for the most recent window."""
    with self._lock:
        prob = self._state.last_speech_prob
    return prob
@property
def current_speech_duration_sec(self) -> float:
    """Length in seconds of the active speech segment (0.0 when idle)."""
    with self._lock:
        if not self._state.is_speech_active:
            return 0.0
        elapsed = self._state.total_samples_processed - self._state.speech_start_sample
        return elapsed / self.sample_rate
7
static/scripts/vibevoice-asr/requirements-realtime.txt
Normal file
7
static/scripts/vibevoice-asr/requirements-realtime.txt
Normal file
@ -0,0 +1,7 @@
# Real-time ASR dependencies
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
websockets>=11.0
numpy>=1.24.0
soundfile>=0.12.0
onnxruntime
73
static/scripts/vibevoice-asr/run_all.sh
Executable file
73
static/scripts/vibevoice-asr/run_all.sh
Executable file
@ -0,0 +1,73 @@
|
|||||||
|
#!/bin/bash
# Run both Gradio demo and Realtime ASR server
#
# Usage:
#   ./run_all.sh
#
# Ports:
#   - 7860: Gradio UI (batch ASR)
#   - 8000: WebSocket API (realtime ASR)

set -e

cd "$(dirname "$0")"

# Configuration — every value can be overridden via the environment.
GRADIO_HOST="${GRADIO_HOST:-0.0.0.0}"
GRADIO_PORT="${GRADIO_PORT:-7860}"
REALTIME_HOST="${REALTIME_HOST:-0.0.0.0}"
REALTIME_PORT="${REALTIME_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"

echo "=========================================="
echo "VibeVoice ASR - All Services"
echo "=========================================="
echo ""
echo "Starting services:"
echo "  - Gradio UI:    http://$GRADIO_HOST:$GRADIO_PORT"
echo "  - Realtime ASR: http://$REALTIME_HOST:$REALTIME_PORT"
echo "  - Test Client:  http://$REALTIME_HOST:$REALTIME_PORT/static/realtime_client.html"
echo ""
echo "Model: $MODEL_PATH"
echo "=========================================="
echo ""

# Clean up both background services when the script exits for any reason.
cleanup() {
    echo ""
    echo "Shutting down..."
    kill $REALTIME_PID 2>/dev/null || true
    kill $GRADIO_PID 2>/dev/null || true
    wait
    echo "All services stopped."
}
trap cleanup EXIT INT TERM

# Start Realtime ASR server in background
echo "[1/2] Starting Realtime ASR server..."
python -m realtime.server \
    --host "$REALTIME_HOST" \
    --port "$REALTIME_PORT" \
    --model-path "$MODEL_PATH" \
    --no-preload &
REALTIME_PID=$!

# Give the server a moment to bind its port before starting the UI.
sleep 2

# Start Gradio demo in background
echo "[2/2] Starting Gradio demo..."
python demo/vibevoice_asr_gradio_demo.py \
    --host "$GRADIO_HOST" \
    --port "$GRADIO_PORT" \
    --model_path "$MODEL_PATH" &
GRADIO_PID=$!

echo ""
echo "Both services started. Press Ctrl+C to stop."
echo ""

# Block until either service exits. `wait -n` without arguments waits for
# ANY background job; passing PIDs to -n requires bash >= 5.1 and errors
# on older shells, so the bare form is used for portability. The EXIT
# trap then tears down the surviving process.
wait -n

# If one exits, the trap will clean up the other
41
static/scripts/vibevoice-asr/run_realtime.sh
Executable file
41
static/scripts/vibevoice-asr/run_realtime.sh
Executable file
@ -0,0 +1,41 @@
|
|||||||
|
#!/bin/bash
# Run VibeVoice Realtime ASR Server
#
# Usage:
#   ./run_realtime.sh [options]
#
# Options are passed to the server (see --help for details)

set -e

cd "$(dirname "$0")"

# Defaults, each overridable through a VIBEVOICE_* environment variable.
HOST="${VIBEVOICE_HOST:-0.0.0.0}"
PORT="${VIBEVOICE_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"
DEVICE="${VIBEVOICE_DEVICE:-cuda}"
MAX_SESSIONS="${VIBEVOICE_MAX_SESSIONS:-10}"

# Banner with the effective configuration.
echo "=========================================="
echo "VibeVoice Realtime ASR Server"
echo "=========================================="
echo "Host:         $HOST"
echo "Port:         $PORT"
echo "Model:        $MODEL_PATH"
echo "Device:       $DEVICE"
echo "Max Sessions: $MAX_SESSIONS"
echo "=========================================="
echo ""
echo "Web client: http://$HOST:$PORT/static/realtime_client.html"
echo "WebSocket:  ws://$HOST:$PORT/ws/asr/{session_id}"
echo ""

# Launch the server; any extra CLI options are passed straight through.
server_args=(
    --host "$HOST"
    --port "$PORT"
    --model-path "$MODEL_PATH"
    --device "$DEVICE"
    --max-sessions "$MAX_SESSIONS"
)
python -m realtime.server "${server_args[@]}" "$@"
287
static/scripts/vibevoice-asr/setup.sh
Normal file
287
static/scripts/vibevoice-asr/setup.sh
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
#!/bin/bash
# VibeVoice-ASR Setup Script for DGX Spark
# Downloads and builds the VibeVoice-ASR container
#
# Usage:
#   curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash
#   curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s build
#   curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s serve

set -e

# Where the distribution is hosted and where it gets installed locally.
BASE_URL="https://docs.techswan.online/scripts/vibevoice-asr"
INSTALL_DIR="${VIBEVOICE_DIR:-$HOME/vibevoice-asr}"
IMAGE_NAME="vibevoice-asr:dgx-spark"
CONTAINER_NAME="vibevoice-asr"

# ANSI colors consumed by the log_* helpers below.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
|
# Colored log helpers: each prints a tagged, color-coded line to stdout.
log_info()  { echo -e "${GREEN}[INFO]${NC} $1"; }
log_warn()  { echo -e "${YELLOW}[WARN]${NC} $1"; }
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
log_step()  { echo -e "${BLUE}[STEP]${NC} $1"; }
download_file() {
    # Fetch a single file from $1 into path $2, creating parent dirs.
    # Prefers curl, falls back to wget; exits if neither is installed.
    local url="$1"
    local dest="$2"
    # Declare and assign separately so the command's exit status is not
    # masked by `local` (shellcheck SC2155).
    local dir
    dir=$(dirname "$dest")

    mkdir -p "$dir"

    if command -v curl &> /dev/null; then
        # -f is essential: without it curl saves HTTP error pages (404s)
        # as the downloaded file and exits 0, silently corrupting the
        # install. With -f the non-zero exit aborts under `set -e`.
        curl -fsSL "$url" -o "$dest"
    elif command -v wget &> /dev/null; then
        wget -q "$url" -O "$dest"
    else
        log_error "curl or wget is required"
        exit 1
    fi
}
download_files() {
    # Fetch the full VibeVoice-ASR distribution into $INSTALL_DIR,
    # preserving the repository layout (top level, realtime/, static/).
    log_step "Downloading VibeVoice-ASR files to $INSTALL_DIR..."

    mkdir -p "$INSTALL_DIR"
    cd "$INSTALL_DIR"

    # Top-level files.
    local file
    for file in \
        Dockerfile \
        requirements-realtime.txt \
        test_vibevoice.py \
        vibevoice_asr_gradio_demo_patched.py \
        run_realtime.sh \
        run_all.sh
    do
        log_info "Downloading $file..."
        download_file "$BASE_URL/$file" "$INSTALL_DIR/$file"
    done

    # Realtime server module.
    mkdir -p "$INSTALL_DIR/realtime"
    for file in \
        __init__.py \
        models.py \
        server.py \
        asr_worker.py \
        session_manager.py \
        audio_buffer.py \
        vad_processor.py
    do
        log_info "Downloading realtime/$file..."
        download_file "$BASE_URL/realtime/$file" "$INSTALL_DIR/realtime/$file"
    done

    # Browser test client.
    mkdir -p "$INSTALL_DIR/static"
    log_info "Downloading static/realtime_client.html..."
    download_file "$BASE_URL/static/realtime_client.html" "$INSTALL_DIR/static/realtime_client.html"

    # Entry-point scripts must be executable.
    chmod +x "$INSTALL_DIR/run_realtime.sh" "$INSTALL_DIR/run_all.sh"

    log_info "All files downloaded to $INSTALL_DIR"
}
check_prerequisites() {
    # Verify the host has Docker, the NVIDIA runtime, and a visible GPU.
    # Only a missing Docker binary is fatal; the rest emit warnings.
    log_step "Checking prerequisites..."

    if ! command -v docker &> /dev/null; then
        log_error "Docker is not installed"
        exit 1
    fi

    # Non-fatal: the runtime may still work (e.g. set as default runtime).
    docker info 2>/dev/null | grep -q "Runtimes.*nvidia" \
        || log_warn "NVIDIA Docker runtime may not be configured"

    if command -v nvidia-smi &> /dev/null; then
        log_info "GPU detected:"
        nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    else
        log_warn "nvidia-smi not found on host"
    fi

    log_info "Prerequisites check complete"
}
build_image() {
    # Build the container image from $INSTALL_DIR/Dockerfile.
    # --network=host lets the build reach package mirrors on hosts where
    # Docker's default bridge network is restricted.
    log_step "Building Docker image: ${IMAGE_NAME}"
    log_info "This may take several minutes..."

    cd "$INSTALL_DIR"

    docker build --network=host -t "$IMAGE_NAME" -f Dockerfile .

    log_info "Docker image built successfully: ${IMAGE_NAME}"
}
run_container() {
    # Launch the built image in one of several modes:
    #   interactive | test | demo | realtime | serve
    local mode="${1:-interactive}"

    log_step "Running container in ${mode} mode..."

    # Free the container name: stop a running instance, remove leftovers.
    if docker ps -q -f name="$CONTAINER_NAME" | grep -q .; then
        log_warn "Stopping existing container..."
        docker stop "$CONTAINER_NAME" 2>/dev/null || true
    fi
    docker rm "$CONTAINER_NAME" 2>/dev/null || true

    # Options shared by all modes: GPU access, host IPC/network,
    # unlimited locked memory, and the shared Hugging Face model cache.
    local docker_opts=(
        --gpus all
        --ipc=host
        --network=host
        --ulimit memlock=-1:-1
        --ulimit stack=-1:-1
        -e "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
        -v "$HOME/.cache/huggingface:/root/.cache/huggingface"
        --name "$CONTAINER_NAME"
    )

    case "$mode" in
        interactive)
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" bash
            ;;
        test)
            docker run --rm "${docker_opts[@]}" "$IMAGE_NAME" python /workspace/test_vibevoice.py
            ;;
        demo)
            log_info "Starting Gradio demo on port 7860..."
            log_info "Access the demo at: http://localhost:7860"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME"
            ;;
        realtime)
            log_info "Starting Realtime ASR server on port 8000..."
            log_info "WebSocket API: ws://localhost:8000/ws/asr/{session_id}"
            log_info "Test client: http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" \
                python -m realtime.server --host 0.0.0.0 --port 8000
            ;;
        serve)
            log_info "Starting all services..."
            log_info "  Gradio demo:  http://localhost:7860"
            log_info "  Realtime ASR: http://localhost:8000"
            log_info "  Test client:  http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" ./run_all.sh
            ;;
        *)
            log_error "Unknown mode: $mode"
            exit 1
            ;;
    esac
}
show_usage() {
    # Print command-line help (heredoc keeps the layout readable).
    cat <<EOF
VibeVoice-ASR Setup for DGX Spark

Usage:
  curl -sL $BASE_URL/setup.sh | bash            # Download only
  curl -sL $BASE_URL/setup.sh | bash -s build   # Download and build
  curl -sL $BASE_URL/setup.sh | bash -s demo    # Download, build, run demo
  curl -sL $BASE_URL/setup.sh | bash -s serve   # Download, build, run all

Commands:
  (default)   Download files only
  build       Download and build Docker image
  demo        Download, build, and start Gradio demo (port 7860)
  realtime    Download, build, and start Realtime ASR (port 8000)
  serve       Download, build, and start both services
  run         Run container interactively (after build)

Environment variables:
  VIBEVOICE_DIR   Installation directory (default: ~/vibevoice-asr)

After installation, you can also run:
  cd ~/vibevoice-asr
  docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark
EOF
}
# Shared sequence for the build/demo/realtime/serve commands:
# fetch sources, verify the host, and build the image.
bootstrap() {
    download_files
    check_prerequisites
    build_image
}

main() {
    # Dispatch on the first CLI argument (default: download only).
    local command="${1:-download}"

    echo ""
    echo "=========================================="
    echo "  VibeVoice-ASR Setup for DGX Spark"
    echo "=========================================="
    echo ""

    case "$command" in
        download)
            download_files
            echo ""
            log_info "Done! Next steps:"
            echo "  cd $INSTALL_DIR"
            echo "  docker build -t vibevoice-asr:dgx-spark ."
            echo "  docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        build)
            bootstrap
            echo ""
            log_info "Done! To run:"
            echo "  cd $INSTALL_DIR"
            echo "  docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        run)
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container interactive
            ;;
        test)
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container test
            ;;
        demo)
            bootstrap
            run_container demo
            ;;
        realtime)
            bootstrap
            run_container realtime
            ;;
        serve)
            bootstrap
            run_container serve
            ;;
        -h|--help|help)
            show_usage
            ;;
        *)
            log_error "Unknown command: $command"
            show_usage
            exit 1
            ;;
    esac
}

main "$@"
899
static/scripts/vibevoice-asr/static/realtime_client.html
Normal file
899
static/scripts/vibevoice-asr/static/realtime_client.html
Normal file
@ -0,0 +1,899 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="ja">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>VibeVoice Realtime ASR Client</title>
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--bg-primary: #1a1a2e;
|
||||||
|
--bg-secondary: #16213e;
|
||||||
|
--bg-tertiary: #0f3460;
|
||||||
|
--text-primary: #eaeaea;
|
||||||
|
--text-secondary: #a0a0a0;
|
||||||
|
--accent: #e94560;
|
||||||
|
--success: #4ade80;
|
||||||
|
--warning: #fbbf24;
|
||||||
|
--info: #60a5fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: 'Segoe UI', system-ui, sans-serif;
|
||||||
|
background: var(--bg-primary);
|
||||||
|
color: var(--text-primary);
|
||||||
|
min-height: 100vh;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
max-width: 1200px;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
header {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 2rem;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.subtitle {
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.main-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr 1fr;
|
||||||
|
gap: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 768px) {
|
||||||
|
.main-grid {
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.card {
|
||||||
|
background: var(--bg-secondary);
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-title {
|
||||||
|
font-size: 1.1rem;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Controls */
|
||||||
|
.controls {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-row {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="text"], input[type="number"] {
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border: 1px solid rgba(255,255,255,0.1);
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 10px 15px;
|
||||||
|
color: var(--text-primary);
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
button {
|
||||||
|
background: var(--accent);
|
||||||
|
border: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 12px 24px;
|
||||||
|
color: white;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
button:hover {
|
||||||
|
filter: brightness(1.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
button:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
button.secondary {
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
}
|
||||||
|
|
||||||
|
button.success {
|
||||||
|
background: var(--success);
|
||||||
|
color: #000;
|
||||||
|
}
|
||||||
|
|
||||||
|
button.stop {
|
||||||
|
background: #ef4444;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Status */
|
||||||
|
.status-bar {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 10px 15px;
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator {
|
||||||
|
width: 12px;
|
||||||
|
height: 12px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator.connected {
|
||||||
|
background: var(--success);
|
||||||
|
box-shadow: 0 0 10px var(--success);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator.recording {
|
||||||
|
background: var(--accent);
|
||||||
|
box-shadow: 0 0 10px var(--accent);
|
||||||
|
animation: pulse 1s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% { opacity: 1; }
|
||||||
|
50% { opacity: 0.5; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Transcription output */
|
||||||
|
.transcription-box {
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 15px;
|
||||||
|
min-height: 300px;
|
||||||
|
max-height: 500px;
|
||||||
|
overflow-y: auto;
|
||||||
|
font-family: 'Consolas', monospace;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segment {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
padding: 10px;
|
||||||
|
background: rgba(255,255,255,0.05);
|
||||||
|
border-radius: 6px;
|
||||||
|
border-left: 3px solid var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.segment.partial {
|
||||||
|
border-left-color: var(--warning);
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segment-meta {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
margin-bottom: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segment-text {
|
||||||
|
font-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Audio visualizer */
|
||||||
|
.visualizer-container {
|
||||||
|
height: 80px;
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#visualizer {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* VAD indicator */
|
||||||
|
.vad-indicator {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 10px;
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.vad-bar {
|
||||||
|
flex: 1;
|
||||||
|
height: 8px;
|
||||||
|
background: rgba(255,255,255,0.1);
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.vad-level {
|
||||||
|
height: 100%;
|
||||||
|
background: var(--success);
|
||||||
|
width: 0%;
|
||||||
|
transition: width 0.1s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.vad-level.speech {
|
||||||
|
background: var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Stats */
|
||||||
|
.stats {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(3, 1fr);
|
||||||
|
gap: 10px;
|
||||||
|
margin-top: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat-item {
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
padding: 10px;
|
||||||
|
border-radius: 8px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat-value {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.stat-label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Log */
|
||||||
|
.log-box {
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 10px;
|
||||||
|
height: 150px;
|
||||||
|
overflow-y: auto;
|
||||||
|
font-family: 'Consolas', monospace;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-entry {
|
||||||
|
margin-bottom: 2px;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-entry.error {
|
||||||
|
color: #ef4444;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-entry.success {
|
||||||
|
color: var(--success);
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-entry.info {
|
||||||
|
color: var(--info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Settings panel */
|
||||||
|
.settings-group {
|
||||||
|
margin-top: 15px;
|
||||||
|
padding: 15px;
|
||||||
|
background: var(--bg-tertiary);
|
||||||
|
border-radius: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-group h4 {
|
||||||
|
margin-bottom: 10px;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-item {
|
||||||
|
margin-bottom: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-item label {
|
||||||
|
display: block;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
color: var(--text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-item input[type="range"] {
|
||||||
|
width: 100%;
|
||||||
|
height: 6px;
|
||||||
|
border-radius: 3px;
|
||||||
|
background: rgba(255,255,255,0.1);
|
||||||
|
outline: none;
|
||||||
|
-webkit-appearance: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-item input[type="range"]::-webkit-slider-thumb {
|
||||||
|
-webkit-appearance: none;
|
||||||
|
width: 16px;
|
||||||
|
height: 16px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: var(--accent);
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-value {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
color: var(--accent);
|
||||||
|
float: right;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<header>
|
||||||
|
<h1>🎙️ VibeVoice Realtime ASR</h1>
|
||||||
|
<p class="subtitle">リアルタイム音声認識デモ</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<div class="status-bar">
|
||||||
|
<div class="status-indicator" id="statusIndicator"></div>
|
||||||
|
<span id="statusText">未接続</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="main-grid">
|
||||||
|
<!-- Left column: Controls -->
|
||||||
|
<div class="card">
|
||||||
|
<h2 class="card-title">⚙️ 設定</h2>
|
||||||
|
<div class="controls">
|
||||||
|
<div class="control-row">
|
||||||
|
<input type="text" id="serverUrl" placeholder="WebSocket URL"
|
||||||
|
value="ws://localhost:8000/ws/asr/demo">
|
||||||
|
</div>
|
||||||
|
<div class="control-row">
|
||||||
|
<button id="connectBtn" onclick="toggleConnection()">接続</button>
|
||||||
|
<button id="recordBtn" class="success" onclick="toggleRecording()" disabled>
|
||||||
|
🎤 録音開始
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="visualizer-container">
|
||||||
|
<canvas id="visualizer"></canvas>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="vad-indicator">
|
||||||
|
<span>VAD:</span>
|
||||||
|
<div class="vad-bar">
|
||||||
|
<div class="vad-level" id="vadLevel"></div>
|
||||||
|
</div>
|
||||||
|
<span id="vadStatus">待機中</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="settings-group">
|
||||||
|
<h4>🎚️ VAD設定</h4>
|
||||||
|
<div class="setting-item">
|
||||||
|
<label>
|
||||||
|
音声検出閾値
|
||||||
|
<span class="setting-value" id="vadThresholdValue">0.5</span>
|
||||||
|
</label>
|
||||||
|
<input type="range" id="vadThreshold" min="0.1" max="0.9" step="0.1" value="0.5"
|
||||||
|
onchange="updateConfig()">
|
||||||
|
</div>
|
||||||
|
<div class="setting-item">
|
||||||
|
<label>
|
||||||
|
最小発話時間 (ms)
|
||||||
|
<span class="setting-value" id="minSpeechValue">250</span>
|
||||||
|
</label>
|
||||||
|
<input type="range" id="minSpeechDuration" min="100" max="1000" step="50" value="250"
|
||||||
|
onchange="updateConfig()">
|
||||||
|
</div>
|
||||||
|
<div class="setting-item">
|
||||||
|
<label>
|
||||||
|
無音判定時間 (ms)
|
||||||
|
<span class="setting-value" id="minSilenceValue">500</span>
|
||||||
|
</label>
|
||||||
|
<input type="range" id="minSilenceDuration" min="200" max="2000" step="100" value="500"
|
||||||
|
onchange="updateConfig()">
|
||||||
|
</div>
|
||||||
|
<div class="setting-item">
|
||||||
|
<label>
|
||||||
|
最小音量閾値
|
||||||
|
<span class="setting-value" id="minVolumeValue">0.01</span>
|
||||||
|
</label>
|
||||||
|
<input type="range" id="minVolumeThreshold" min="0.001" max="0.1" step="0.001" value="0.01"
|
||||||
|
onchange="updateConfig()">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="stats">
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value" id="statDuration">0.0</div>
|
||||||
|
<div class="stat-label">録音時間 (秒)</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value" id="statSegments">0</div>
|
||||||
|
<div class="stat-label">セグメント数</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-item">
|
||||||
|
<div class="stat-value" id="statLatency">-</div>
|
||||||
|
<div class="stat-label">レイテンシ (ms)</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h3 class="card-title" style="margin-top: 20px;">📋 ログ</h3>
|
||||||
|
<div class="log-box" id="logBox"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Right column: Transcription -->
|
||||||
|
<div class="card">
|
||||||
|
<h2 class="card-title">📝 認識結果</h2>
|
||||||
|
<div class="control-row" style="margin-bottom: 15px;">
|
||||||
|
<button class="secondary" onclick="clearTranscription()">クリア</button>
|
||||||
|
<button class="secondary" onclick="copyTranscription()">コピー</button>
|
||||||
|
</div>
|
||||||
|
<div class="transcription-box" id="transcriptionBox">
|
||||||
|
<p style="color: var(--text-secondary); text-align: center; padding: 50px;">
|
||||||
|
接続して録音を開始すると、ここに認識結果が表示されます
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// State
|
||||||
|
let websocket = null;
|
||||||
|
let mediaStream = null;
|
||||||
|
let audioContext = null;
|
||||||
|
let processor = null;
|
||||||
|
let analyser = null;
|
||||||
|
let isRecording = false;
|
||||||
|
let recordingStartTime = null;
|
||||||
|
let segmentCount = 0;
|
||||||
|
let currentPartialText = '';
|
||||||
|
|
||||||
|
// DOM elements
|
||||||
|
const statusIndicator = document.getElementById('statusIndicator');
|
||||||
|
const statusText = document.getElementById('statusText');
|
||||||
|
const connectBtn = document.getElementById('connectBtn');
|
||||||
|
const recordBtn = document.getElementById('recordBtn');
|
||||||
|
const transcriptionBox = document.getElementById('transcriptionBox');
|
||||||
|
const logBox = document.getElementById('logBox');
|
||||||
|
const vadLevel = document.getElementById('vadLevel');
|
||||||
|
const vadStatus = document.getElementById('vadStatus');
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
// Append a timestamped entry to the on-page log box and mirror it to the console.
// `type` selects the CSS modifier class ('info' | 'success' | 'error').
function log(message, type = 'info') {
    const line = document.createElement('div');
    line.className = `log-entry ${type}`;
    const stamp = new Date().toLocaleTimeString();
    line.textContent = `[${stamp}] ${message}`;
    logBox.appendChild(line);
    // Keep the newest entry visible.
    logBox.scrollTop = logBox.scrollHeight;
    console.log(`[${type}] ${message}`);
}
|
||||||
|
|
||||||
|
// Update config display values
|
||||||
|
// Mirror each slider's current value into its display <span>.
function updateConfigDisplay() {
    // [label span id, range input id] pairs.
    const bindings = [
        ['vadThresholdValue', 'vadThreshold'],
        ['minSpeechValue', 'minSpeechDuration'],
        ['minSilenceValue', 'minSilenceDuration'],
        ['minVolumeValue', 'minVolumeThreshold'],
    ];
    for (const [labelId, inputId] of bindings) {
        document.getElementById(labelId).textContent =
            document.getElementById(inputId).value;
    }
}
|
||||||
|
|
||||||
|
// Send config to server
|
||||||
|
// Refresh the slider labels and, when connected, push the current VAD
// settings to the server as a 'config' message.
function updateConfig() {
    updateConfigDisplay();

    // Nothing to send without an open socket.
    if (!websocket || websocket.readyState !== WebSocket.OPEN) {
        return;
    }

    const read = (id) => document.getElementById(id).value;
    const settings = {
        vad_threshold: parseFloat(read('vadThreshold')),
        min_speech_duration_ms: parseInt(read('minSpeechDuration')),
        min_silence_duration_ms: parseInt(read('minSilenceDuration')),
        min_volume_threshold: parseFloat(read('minVolumeThreshold')),
    };
    websocket.send(JSON.stringify({ type: 'config', config: settings }));
    log(`VAD設定を更新: 閾値=${settings.vad_threshold}, 最小発話=${settings.min_speech_duration_ms}ms, 無音判定=${settings.min_silence_duration_ms}ms, 最小音量=${settings.min_volume_threshold}`, 'info');
}
|
||||||
|
|
||||||
|
// Initialize config display on load
|
||||||
|
document.addEventListener('DOMContentLoaded', updateConfigDisplay);
|
||||||
|
|
||||||
|
// Connection
|
||||||
|
// Connect when idle; disconnect when the socket is currently open.
function toggleConnection() {
    const isOpen = websocket && websocket.readyState === WebSocket.OPEN;
    if (isOpen) {
        disconnect();
        return;
    }
    connect();
}
|
||||||
|
|
||||||
|
// Open a WebSocket to the ASR server (URL taken from the input field)
// and wire up its lifecycle handlers.
function connect() {
    const url = document.getElementById('serverUrl').value;
    log(`接続中: ${url}`, 'info');

    try {
        const ws = new WebSocket(url);
        websocket = ws;

        ws.onopen = () => {
            log('接続成功', 'success');
            statusIndicator.classList.add('connected');
            statusText.textContent = '接続済み';
            connectBtn.textContent = '切断';
            recordBtn.disabled = false;
        };

        ws.onclose = () => {
            log('切断されました', 'info');
            handleDisconnect();
        };

        ws.onerror = (error) => {
            log(`エラー: ${error}`, 'error');
        };

        // Server messages are JSON; route them to handleMessage.
        ws.onmessage = (event) => {
            handleMessage(JSON.parse(event.data));
        };
    } catch (error) {
        // e.g. malformed URL throws synchronously from the constructor.
        log(`接続エラー: ${error}`, 'error');
    }
}
|
||||||
|
|
||||||
|
// Stop any active recording, close the socket, and reset the connection UI.
function disconnect() {
    if (isRecording) stopRecording();
    if (websocket) websocket.close();
    handleDisconnect();
}
|
||||||
|
|
||||||
|
// Reset all connection-related UI state after the socket goes away.
function handleDisconnect() {
    statusIndicator.classList.remove('connected', 'recording');
    statusText.textContent = '未接続';
    connectBtn.textContent = '接続';
    recordBtn.disabled = true;
    websocket = null;
}
|
||||||
|
|
||||||
|
// Message handling
|
||||||
|
// Route a parsed server message to the handler matching its `type` field.
function handleMessage(data) {
    const kind = data.type;

    if (kind === 'status') {
        log(`ステータス: ${data.message}`, 'info');
    } else if (kind === 'partial_result') {
        updatePartialResult(data);
    } else if (kind === 'final_result') {
        addFinalResult(data);
    } else if (kind === 'vad_event') {
        handleVADEvent(data);
    } else if (kind === 'error') {
        log(`サーバーエラー: ${data.error}`, 'error');
    } else if (kind === 'pong') {
        // Heartbeat reply — nothing to do.
    } else {
        log(`不明なメッセージ: ${data.type}`, 'info');
    }
}
|
||||||
|
|
||||||
|
// Store the in-progress transcription text, refresh the display,
// and update the latency stat when the server reports one.
function updatePartialResult(data) {
    currentPartialText = data.text;
    updateTranscriptionDisplay();

    if (data.latency_ms) {
        const latencyEl = document.getElementById('statLatency');
        latencyEl.textContent = Math.round(data.latency_ms);
    }
}
|
||||||
|
|
||||||
|
// Append a finalized transcription segment to the result pane and bump stats.
//
// Fixes:
//  - Security: the server-provided transcription text was interpolated into
//    innerHTML, so transcribed content containing markup could inject
//    HTML/script into the page. The segment is now built with
//    createElement/textContent, which renders the text verbatim.
//  - The stale "認識中..." partial element was left in place after a final
//    result (it was only cleared on the next partial update); it is now
//    removed here as well.
function addFinalResult(data) {
    currentPartialText = '';
    const stalePartial = transcriptionBox.querySelector('.segment.partial');
    if (stalePartial) stalePartial.remove();

    segmentCount++;
    document.getElementById('statSegments').textContent = segmentCount;

    if (data.latency_ms) {
        document.getElementById('statLatency').textContent =
            Math.round(data.latency_ms);
    }

    // Time range / speaker line shown above the text (empty when absent).
    let metaText = '';
    if (data.segments && data.segments.length > 0) {
        const seg = data.segments[0];
        metaText = `[${seg.start_time?.toFixed(2) || '?'}s - ${seg.end_time?.toFixed(2) || '?'}s] ${seg.speaker_id || ''}`;
    }

    // Build the segment DOM without HTML interpolation (XSS-safe).
    const segment = document.createElement('div');
    segment.className = 'segment';
    const meta = document.createElement('div');
    meta.className = 'segment-meta';
    meta.textContent = metaText;
    const text = document.createElement('div');
    text.className = 'segment-text';
    text.textContent = data.text;
    segment.append(meta, text);

    // Remove placeholder if exists
    const placeholder = transcriptionBox.querySelector('p');
    if (placeholder) placeholder.remove();

    transcriptionBox.appendChild(segment);
    transcriptionBox.scrollTop = transcriptionBox.scrollHeight;

    log(`認識完了: "${data.text.substring(0, 30)}..."`, 'success');
}
|
||||||
|
|
||||||
|
// Show, update, or remove the "recognizing" partial segment.
//
// Security fix: the partial text comes from the server and was previously
// interpolated into innerHTML, allowing markup/script injection. The element
// is now built once with createElement and the text is written via
// textContent, which renders it verbatim.
function updateTranscriptionDisplay() {
    let partialDiv = transcriptionBox.querySelector('.segment.partial');

    if (currentPartialText) {
        if (!partialDiv) {
            partialDiv = document.createElement('div');
            partialDiv.className = 'segment partial';
            const meta = document.createElement('div');
            meta.className = 'segment-meta';
            meta.textContent = '認識中...';
            const text = document.createElement('div');
            text.className = 'segment-text';
            partialDiv.append(meta, text);
            transcriptionBox.appendChild(partialDiv);
        }
        const textEl = partialDiv.querySelector('.segment-text');
        if (textEl) textEl.textContent = currentPartialText;
        transcriptionBox.scrollTop = transcriptionBox.scrollHeight;
    } else if (partialDiv) {
        partialDiv.remove();
    }
}
|
||||||
|
|
||||||
|
// Reflect server-side VAD speech start/end events in the level bar and log.
function handleVADEvent(data) {
    const stamp = data.audio_timestamp_sec?.toFixed(2);
    if (data.event === 'speech_start') {
        vadLevel.classList.add('speech');
        vadStatus.textContent = '発話中';
        log(`発話開始 @ ${stamp}s`, 'info');
    } else if (data.event === 'speech_end') {
        vadLevel.classList.remove('speech');
        vadStatus.textContent = '待機中';
        log(`発話終了 @ ${stamp}s`, 'info');
    }
}
|
||||||
|
|
||||||
|
// Recording
|
||||||
|
// Flip between recording and idle states (record button handler).
async function toggleRecording() {
    if (isRecording) {
        stopRecording();
        return;
    }
    await startRecording();
}
|
||||||
|
|
||||||
|
// Acquire the microphone, build the Web Audio capture graph, and start
// streaming 16 kHz mono 16-bit PCM frames to the server over the WebSocket.
// Graph: mic source -> analyser (visualization) and mic source ->
// ScriptProcessor (PCM conversion + send) -> destination.
// NOTE(review): ScriptProcessorNode is deprecated in favor of AudioWorklet;
// kept as-is here since replacing it would change the processing model.
async function startRecording() {
    try {
        log('マイクアクセスをリクエスト中...', 'info');

        // Request a mono 16 kHz stream; browsers may ignore sampleRate here,
        // so the AudioContext below also pins 16 kHz.
        mediaStream = await navigator.mediaDevices.getUserMedia({
            audio: {
                sampleRate: 16000,
                channelCount: 1,
                echoCancellation: true,
                noiseSuppression: true,
            }
        });

        // webkit prefix covers older Safari.
        audioContext = new (window.AudioContext || window.webkitAudioContext)({
            sampleRate: 16000
        });

        const source = audioContext.createMediaStreamSource(mediaStream);

        // Analyser for visualization
        analyser = audioContext.createAnalyser();
        analyser.fftSize = 256;
        source.connect(analyser);

        // ScriptProcessor for sending audio (4096-sample frames, 1-in/1-out)
        const bufferSize = 4096;
        processor = audioContext.createScriptProcessor(bufferSize, 1, 1);

        processor.onaudioprocess = (e) => {
            // Drop frames when not recording or the socket is gone.
            if (!isRecording || !websocket || websocket.readyState !== WebSocket.OPEN) {
                return;
            }

            const inputData = e.inputBuffer.getChannelData(0);

            // Convert float32 [-1, 1] samples to 16-bit PCM, clamped to
            // the int16 range.
            const pcmData = new Int16Array(inputData.length);
            for (let i = 0; i < inputData.length; i++) {
                pcmData[i] = Math.max(-32768, Math.min(32767, inputData[i] * 32768));
            }

            // Send as binary (raw little-endian PCM16 frames)
            websocket.send(pcmData.buffer);
        };

        source.connect(processor);
        // ScriptProcessor must be connected to a destination to fire events.
        processor.connect(audioContext.destination);

        isRecording = true;
        recordingStartTime = Date.now();
        statusIndicator.classList.add('recording');
        recordBtn.textContent = '⏹️ 録音停止';
        recordBtn.classList.remove('success');
        recordBtn.classList.add('stop');

        // Start visualization and the duration counter loops.
        visualize();
        updateDuration();

        log('録音開始', 'success');

    } catch (error) {
        // getUserMedia rejection (permission denied, no device) lands here.
        log(`マイクエラー: ${error}`, 'error');
    }
}
|
||||||
|
|
||||||
|
// Tear down the audio capture pipeline, restore the UI, and tell the
// server the utterance has ended.
function stopRecording() {
    isRecording = false;

    if (processor) {
        processor.disconnect();
        processor = null;
    }

    if (audioContext) {
        audioContext.close();
        audioContext = null;
    }

    if (mediaStream) {
        for (const track of mediaStream.getTracks()) {
            track.stop();
        }
        mediaStream = null;
    }

    statusIndicator.classList.remove('recording');
    recordBtn.textContent = '🎤 録音開始';
    recordBtn.classList.remove('stop');
    recordBtn.classList.add('success');

    // Ask the server to finalize the current segment.
    if (websocket && websocket.readyState === WebSocket.OPEN) {
        websocket.send(JSON.stringify({ type: 'stop' }));
    }

    log('録音停止', 'info');
}
|
||||||
|
|
||||||
|
// Visualization
|
||||||
|
// Draw a live frequency-bar spectrum on the #visualizer canvas and feed the
// average level into the VAD indicator bar. Runs a requestAnimationFrame
// loop that stops itself when recording ends.
function visualize() {
    if (!analyser || !isRecording) return;

    const canvas = document.getElementById('visualizer');
    const ctx = canvas.getContext('2d');
    // Match the backing buffer to the CSS size so bars aren't stretched.
    const width = canvas.width = canvas.offsetWidth;
    const height = canvas.height = canvas.offsetHeight;

    // frequencyBinCount = fftSize / 2 bins of byte magnitudes (0-255).
    const bufferLength = analyser.frequencyBinCount;
    const dataArray = new Uint8Array(bufferLength);

    function draw() {
        if (!isRecording) return;
        requestAnimationFrame(draw);

        analyser.getByteFrequencyData(dataArray);

        // Clear with the panel background color.
        ctx.fillStyle = 'rgb(15, 52, 96)';
        ctx.fillRect(0, 0, width, height);

        // 2.5x widens the bars; later bins overflow the right edge, which
        // is intentional (low frequencies dominate speech).
        const barWidth = (width / bufferLength) * 2.5;
        let x = 0;

        // Calculate average for VAD indicator
        let sum = 0;
        for (let i = 0; i < bufferLength; i++) {
            const barHeight = (dataArray[i] / 255) * height;

            // Bottom-to-top red-to-green gradient per bar.
            const gradient = ctx.createLinearGradient(0, height, 0, height - barHeight);
            gradient.addColorStop(0, '#e94560');
            gradient.addColorStop(1, '#4ade80');

            ctx.fillStyle = gradient;
            ctx.fillRect(x, height - barHeight, barWidth, barHeight);

            x += barWidth + 1;
            sum += dataArray[i];
        }

        // Update VAD level indicator (mean magnitude as a 0-100% width).
        const avgLevel = (sum / bufferLength / 255) * 100;
        vadLevel.style.width = `${avgLevel}%`;
    }

    draw();
}
|
||||||
|
|
||||||
|
// Refresh the elapsed-recording-time stat once per animation frame;
// the loop ends itself when recording stops.
function updateDuration() {
    if (!isRecording) return;

    const elapsedSec = (Date.now() - recordingStartTime) / 1000;
    document.getElementById('statDuration').textContent = elapsedSec.toFixed(1);

    requestAnimationFrame(updateDuration);
}
|
||||||
|
|
||||||
|
// Utility functions
|
||||||
|
// Reset the transcription pane and its counters to the initial state.
function clearTranscription() {
    transcriptionBox.innerHTML = `
        <p style="color: var(--text-secondary); text-align: center; padding: 50px;">
            接続して録音を開始すると、ここに認識結果が表示されます
        </p>
    `;
    segmentCount = 0;
    currentPartialText = '';
    document.getElementById('statSegments').textContent = '0';
    document.getElementById('statLatency').textContent = '-';
    log('認識結果をクリアしました', 'info');
}
|
||||||
|
|
||||||
|
// Copy all finalized segment texts (newline-separated) to the clipboard.
// Partial ("recognizing") segments are excluded.
function copyTranscription() {
    const nodes = transcriptionBox.querySelectorAll('.segment:not(.partial) .segment-text');
    const combined = Array.from(nodes, (node) => node.textContent).join('\n');

    if (!combined) return;
    navigator.clipboard.writeText(combined).then(() => {
        log('認識結果をコピーしました', 'success');
    });
}
|
||||||
|
|
||||||
|
// Heartbeat
|
||||||
|
// Heartbeat: ping the server every 30 s so idle proxies/load balancers
// don't drop the WebSocket; the server replies with 'pong' (ignored).
setInterval(() => {
    if (websocket && websocket.readyState === WebSocket.OPEN) {
        websocket.send(JSON.stringify({ type: 'ping' }));
    }
}, 30000);
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
173
static/scripts/vibevoice-asr/test_vibevoice.py
Normal file
173
static/scripts/vibevoice-asr/test_vibevoice.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
VibeVoice-ASR Test Script for DGX Spark
|
||||||
|
Tests basic functionality and GPU availability
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
def test_imports() -> bool:
    """Check that the vibevoice package is importable.

    Returns:
        True when the import succeeds, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Testing VibeVoice imports...")
    print(banner)

    try:
        import vibevoice  # noqa: F401  (import is the test itself)
    except ImportError as e:
        print(f"[FAIL] Failed to import vibevoice: {e}")
        return False
    print("[OK] vibevoice imported successfully")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch_cuda() -> bool:
    """Report PyTorch/CUDA availability and run a tiny GPU sanity op.

    Fixes over the original: the matmul result was computed but never
    inspected (so a broken kernel could still "pass"); the result shape is
    now verified. Also removed an f-string with no placeholders.

    Returns:
        True when CUDA is available and the smoke test succeeds,
        False otherwise.
    """
    print("\n" + "=" * 60)
    print("Testing PyTorch CUDA...")
    print("=" * 60)

    try:
        import torch
        print(f"[INFO] PyTorch version: {torch.__version__}")
        print(f"[INFO] CUDA available: {torch.cuda.is_available()}")

        if torch.cuda.is_available():
            print(f"[INFO] CUDA version: {torch.version.cuda}")
            print(f"[INFO] GPU count: {torch.cuda.device_count()}")

            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                print(f"[INFO] GPU {i}: {props.name}")
                print(f"  Compute capability: {props.major}.{props.minor}")
                print(f"  Total memory: {props.total_memory / 1024**3:.1f} GB")

            # Quick CUDA smoke test: actually verify the matmul result so a
            # failing kernel cannot slip through unnoticed.
            x = torch.randn(100, 100, device='cuda')
            y = torch.matmul(x, x)
            if y.shape != (100, 100):
                raise RuntimeError(f"unexpected matmul result shape: {y.shape}")
            print("[OK] CUDA tensor operations working")
            return True
        else:
            print("[WARN] CUDA not available")
            return False

    except Exception as e:
        # Broad on purpose: any CUDA/driver failure should report, not crash.
        print(f"[FAIL] PyTorch CUDA test failed: {e}")
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_flash_attention() -> bool:
    """Report whether the optional flash_attn package is present.

    Returns:
        Always True — flash attention is optional, so a missing package
        is not a failure.
    """
    print("\n" + "=" * 60)
    print("Testing Flash Attention...")
    print("=" * 60)

    try:
        import flash_attn
    except ImportError:
        print("[WARN] flash_attn not installed (optional)")
        return True  # Not required
    print(f"[OK] flash_attn version: {flash_attn.__version__}")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_ffmpeg() -> bool:
    """Check that the ffmpeg binary is on PATH and runs.

    Returns:
        True when `ffmpeg -version` exits cleanly, False when the binary
        is missing or returns a nonzero exit code.
    """
    print("\n" + "=" * 60)
    print("Testing FFmpeg...")
    print("=" * 60)

    try:
        result = subprocess.run(
            ["ffmpeg", "-version"],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        print("[FAIL] FFmpeg not found")
        return False

    if result.returncode != 0:
        print("[FAIL] FFmpeg returned error")
        return False
    # First stdout line carries the version banner.
    version_line = result.stdout.split('\n')[0]
    print(f"[OK] {version_line}")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_asr_model() -> bool:
    """Attempt to load the VibeVoice ASR pipeline (GPU only, best-effort).

    Skips when no GPU is present, and deliberately treats import or load
    failures as non-fatal since the pipeline API varies across VibeVoice
    versions.

    Returns:
        Always True (skip/warn paths included).
    """
    print("\n" + "=" * 60)
    print("Testing ASR Model Loading...")
    print("=" * 60)

    try:
        import torch
        if not torch.cuda.is_available():
            print("[SKIP] Skipping model test - no GPU available")
            return True

        # Try to load the ASR pipeline with its default (test-friendly) model.
        from vibevoice import ASRPipeline
        print("[INFO] Loading ASR pipeline...")

        pipeline = ASRPipeline()
        print("[OK] ASR pipeline loaded successfully")

        # Free GPU memory before the next test runs.
        del pipeline
        torch.cuda.empty_cache()
        return True

    except ImportError as e:
        print(f"[WARN] ASRPipeline not available: {e}")
        print("[INFO] This may be normal depending on VibeVoice version")
        return True
    except Exception as e:
        # Best-effort check: warn but do not fail the suite.
        print(f"[WARN] ASR model test: {e}")
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Run every check and print a pass/fail summary.

    Returns:
        0 when all checks pass, 1 otherwise (used as the process exit code).
    """
    print("\n")
    print("*" * 60)
    print(" VibeVoice-ASR Test Suite for DGX Spark")
    print("*" * 60)

    # Each value is the boolean outcome of one check.
    results = {
        "imports": test_imports(),
        "torch_cuda": test_torch_cuda(),
        "flash_attention": test_flash_attention(),
        "ffmpeg": test_ffmpeg(),
        "asr_model": test_asr_model(),
    }

    print("\n")
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)

    for name, passed in results.items():
        status = "[OK]" if passed else "[FAIL]"
        print(f"  {status} {name}")

    print("=" * 60)

    if all(results.values()):
        print("\nAll tests passed!")
        return 0
    print("\nSome tests failed.")
    return 1


if __name__ == "__main__":
    sys.exit(main())
|
||||||
1229
static/scripts/vibevoice-asr/vibevoice_asr_gradio_demo_patched.py
Normal file
1229
static/scripts/vibevoice-asr/vibevoice_asr_gradio_demo_patched.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user