Add: VibeVoice ASR セットアップスクリプト一式
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
This commit is contained in:
parent
2d753f114f
commit
1fb76254e9
61
static/scripts/vibevoice-asr/Dockerfile
Normal file
61
static/scripts/vibevoice-asr/Dockerfile
Normal file
@ -0,0 +1,61 @@
|
||||
# VibeVoice-ASR for DGX Spark (ARM64, Blackwell GB10, sm_121)
# Based on NVIDIA PyTorch container for CUDA 13.1 compatibility

# NOTE: TARGETARCH is declared for buildx multi-arch builds but is not
# currently consumed by any instruction below.
ARG TARGETARCH
FROM nvcr.io/nvidia/pytorch:25.11-py3 AS base

LABEL maintainer="VibeVoice-ASR DGX Spark Setup"
LABEL description="VibeVoice-ASR optimized for DGX Spark (ARM64, CUDA 13.1)"

# Set environment variables (grouped into one layer instead of one per ENV)
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# PyTorch CUDA settings for DGX Spark
ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
    USE_LIBUV=0

# Set working directory
WORKDIR /workspace

# Install system dependencies including FFmpeg for demo
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install flash-attn if not already present.
# NOTE: "|| true" makes this deliberately best-effort — if the build fails,
# the image falls back to whatever attention impl the base image ships.
RUN pip install --no-cache-dir flash-attn --no-build-isolation || true

# Clone and install VibeVoice
RUN git clone https://github.com/microsoft/VibeVoice.git /workspace/VibeVoice && \
    cd /workspace/VibeVoice && \
    pip install --no-cache-dir -e .

# Create test script and patched demo with MKV support
COPY test_vibevoice.py /workspace/test_vibevoice.py
COPY vibevoice_asr_gradio_demo_patched.py /workspace/VibeVoice/demo/vibevoice_asr_gradio_demo.py

# Install real-time ASR dependencies
COPY requirements-realtime.txt /workspace/requirements-realtime.txt
RUN pip install --no-cache-dir -r /workspace/requirements-realtime.txt

# Copy real-time ASR module and startup scripts
COPY realtime/ /workspace/VibeVoice/realtime/
COPY static/ /workspace/VibeVoice/static/
COPY run_all.sh /workspace/VibeVoice/run_all.sh
COPY run_realtime.sh /workspace/VibeVoice/run_realtime.sh
RUN chmod +x /workspace/VibeVoice/run_all.sh /workspace/VibeVoice/run_realtime.sh

# Set default working directory to VibeVoice
WORKDIR /workspace/VibeVoice

# Expose Gradio port and WebSocket port
EXPOSE 7860
EXPOSE 8000

# Default command: Launch Gradio demo with MKV support
CMD ["python", "demo/vibevoice_asr_gradio_demo.py", "--model_path", "microsoft/VibeVoice-ASR", "--host", "0.0.0.0"]
|
||||
7
static/scripts/vibevoice-asr/realtime/__init__.py
Normal file
7
static/scripts/vibevoice-asr/realtime/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
VibeVoice Realtime ASR Module
|
||||
|
||||
WebSocket-based real-time speech recognition using VibeVoice ASR.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
358
static/scripts/vibevoice-asr/realtime/asr_worker.py
Normal file
358
static/scripts/vibevoice-asr/realtime/asr_worker.py
Normal file
@ -0,0 +1,358 @@
|
||||
"""
|
||||
ASR Worker for real-time transcription.
|
||||
|
||||
Wraps the existing VibeVoiceASRInference for async/streaming operation.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import queue
|
||||
from typing import AsyncGenerator, Optional, List, Callable
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Add parent directory and demo directory to path for importing existing code
|
||||
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, _parent_dir)
|
||||
sys.path.insert(0, os.path.join(_parent_dir, "demo"))
|
||||
|
||||
from .models import (
|
||||
TranscriptionResult,
|
||||
TranscriptionSegment,
|
||||
MessageType,
|
||||
SessionConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class InferenceRequest:
    """Request for ASR inference.

    Bundles one audio segment with the metadata needed to place its
    transcription on the session timeline.

    NOTE(review): this dataclass is not referenced elsewhere in this module;
    requests are currently passed to ASRWorker methods as plain arguments.
    """
    audio: np.ndarray            # audio samples for this segment
    sample_rate: int             # sample rate of `audio` in Hz
    context_info: Optional[str]  # optional context string forwarded to the model
    request_time: float          # wall-clock time the request was created
    segment_start_sec: float     # segment start relative to session start
    segment_end_sec: float       # segment end relative to session start
||||
|
||||
|
||||
class ASRWorker:
    """
    ASR Worker that wraps VibeVoiceASRInference for real-time use.

    Features:
    - Async interface for WebSocket integration
    - Streaming output via TextIteratorStreamer
    - Request queuing for handling concurrent segments
    - Graceful model loading and error handling
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-ASR",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the ASR worker.

        Args:
            model_path: Path to VibeVoice ASR model (hub id or local path)
            device: Device to run inference on
            dtype: Model data type
            attn_implementation: Attention implementation
        """
        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.attn_implementation = attn_implementation

        self._inference = None       # VibeVoiceASRInference, created lazily
        self._is_loaded = False
        self._load_lock = threading.Lock()  # serializes the one-time model load

        # Only one segment is transcribed at a time; concurrent callers
        # queue on this semaphore.
        self._inference_semaphore = asyncio.Semaphore(1)

    def load_model(self) -> bool:
        """
        Load the ASR model (idempotent; safe to call from multiple threads).

        Returns:
            True if model loaded successfully
        """
        with self._load_lock:
            if self._is_loaded:
                return True

            try:
                # Import here to avoid circular imports and allow lazy loading.
                # In Docker the patched demo is copied over the plain name
                # vibevoice_asr_gradio_demo.py, so try that name first.
                try:
                    from vibevoice_asr_gradio_demo import VibeVoiceASRInference
                except ImportError:
                    from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference

                print(f"Loading VibeVoice ASR model from {self.model_path}...")
                self._inference = VibeVoiceASRInference(
                    model_path=self.model_path,
                    device=self.device,
                    dtype=self.dtype,
                    attn_implementation=self.attn_implementation,
                )
                self._is_loaded = True
                print("ASR model loaded successfully")
                return True

            except Exception as e:
                # Best-effort: report and return False so the worker stays
                # unloaded and a later call can retry.
                print(f"Failed to load ASR model: {e}")
                import traceback
                traceback.print_exc()
                return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    async def transcribe_segment(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
        on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
    ) -> TranscriptionResult:
        """
        Transcribe an audio segment asynchronously.

        Args:
            audio: Audio data as float32 array
            sample_rate: Audio sample rate
            context_info: Optional context for transcription
            segment_start_sec: Start time of segment in session
            segment_end_sec: End time of segment in session
            config: Session configuration
            on_partial: Callback for partial results (sync or async)

        Returns:
            Final transcription result (type ERROR with empty text if the
            model could not be loaded)
        """
        if not self._is_loaded:
            if not self.load_model():
                # Surface the failure as an ERROR result rather than raising,
                # so the WebSocket layer can forward it to the client.
                return TranscriptionResult(
                    type=MessageType.ERROR,
                    text="",
                    is_final=True,
                    latency_ms=0,
                )

        config = config or SessionConfig()
        request_time = time.time()

        # Serialize inference requests
        async with self._inference_semaphore:
            return await self._run_inference(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                request_time=request_time,
                on_partial=on_partial,
            )

    async def _run_inference(
        self,
        audio: np.ndarray,
        sample_rate: int,
        context_info: Optional[str],
        segment_start_sec: float,
        segment_end_sec: float,
        config: SessionConfig,
        request_time: float,
        on_partial: Optional[Callable[[TranscriptionResult], None]],
    ) -> TranscriptionResult:
        """Run the actual inference in a thread pool.

        Generation runs in a plain background thread; partial text is
        consumed from the streamer on the event-loop side.
        """
        from transformers import TextIteratorStreamer

        # Create streamer for partial results
        streamer = None
        if config.return_partial_results and on_partial:
            streamer = TextIteratorStreamer(
                self._inference.processor.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

        # Written by the worker thread, read after join() on this side.
        result_container = {"result": None, "error": None}

        def run_inference():
            try:
                # Save audio to temp file (required by current implementation)
                import tempfile
                import soundfile as sf

                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    temp_path = f.name

                # Write audio as 16-bit PCM, scaled and clipped from float
                audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
                sf.write(temp_path, audio_int16, sample_rate, subtype='PCM_16')

                try:
                    result = self._inference.transcribe(
                        audio_path=temp_path,
                        max_new_tokens=config.max_new_tokens,
                        temperature=config.temperature,
                        context_info=context_info,
                        streamer=streamer,
                    )
                    result_container["result"] = result
                finally:
                    # Clean up temp file; failure to unlink is non-fatal
                    try:
                        os.unlink(temp_path)
                    except:  # NOTE(review): bare except — prefer `except OSError`
                        pass

            except Exception as e:
                result_container["error"] = str(e)
                import traceback
                traceback.print_exc()

        # Start inference in background thread
        inference_thread = threading.Thread(target=run_inference)
        inference_thread.start()

        # Stream partial results if enabled.
        # NOTE(review): iterating `streamer` blocks the event loop between
        # tokens; consider draining it via run_in_executor if other
        # coroutines must stay responsive during generation.
        partial_text = ""
        if streamer and on_partial:
            try:
                for new_text in streamer:
                    partial_text += new_text
                    partial_result = TranscriptionResult(
                        type=MessageType.PARTIAL_RESULT,
                        text=partial_text,
                        is_final=False,
                        latency_ms=(time.time() - request_time) * 1000,
                    )
                    # Call callback (may be async)
                    if asyncio.iscoroutinefunction(on_partial):
                        await on_partial(partial_result)
                    else:
                        on_partial(partial_result)
            except Exception as e:
                print(f"Error during streaming: {e}")

        # Wait for completion
        inference_thread.join()

        latency_ms = (time.time() - request_time) * 1000

        if result_container["error"]:
            return TranscriptionResult(
                type=MessageType.ERROR,
                text=f"Error: {result_container['error']}",
                is_final=True,
                latency_ms=latency_ms,
            )

        result = result_container["result"]

        # Convert segments to our format
        segments = []
        for seg in result.get("segments", []):
            seg_start = seg.get("start_time", 0)
            seg_end = seg.get("end_time", 0)

            # If the segment carries numeric (segment-relative) timestamps,
            # shift them onto the session timeline; otherwise fall back to
            # the segment's own boundaries.
            if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
                adjusted_start = segment_start_sec + seg_start
                adjusted_end = segment_start_sec + seg_end
            else:
                adjusted_start = segment_start_sec
                adjusted_end = segment_end_sec

            segments.append(TranscriptionSegment(
                start_time=adjusted_start,
                end_time=adjusted_end,
                speaker_id=seg.get("speaker_id", "SPEAKER_00"),
                text=seg.get("text", ""),
            ))

        return TranscriptionResult(
            type=MessageType.FINAL_RESULT,
            text=result.get("raw_text", ""),
            is_final=True,
            segments=segments,
            latency_ms=latency_ms,
        )

    async def transcribe_stream(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
    ) -> AsyncGenerator[TranscriptionResult, None]:
        """
        Transcribe an audio segment with streaming output.

        Yields partial results followed by final result.

        Args:
            audio: Audio data
            sample_rate: Sample rate
            context_info: Optional context
            segment_start_sec: Segment start time
            segment_end_sec: Segment end time
            config: Session config

        Yields:
            TranscriptionResult objects (partial and final)
        """
        result_queue: asyncio.Queue = asyncio.Queue()

        async def on_partial(result: TranscriptionResult):
            await result_queue.put(result)

        # Start transcription task
        transcribe_task = asyncio.create_task(
            self.transcribe_segment(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                on_partial=on_partial,
            )
        )

        # Yield partial results as they come; the short timeout lets us
        # also notice task completion promptly.
        while not transcribe_task.done():
            try:
                result = await asyncio.wait_for(result_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue

        # Drain any remaining partial results
        while not result_queue.empty():
            yield await result_queue.get()

        # Yield final result
        final_result = await transcribe_task
        yield final_result
||||
246
static/scripts/vibevoice-asr/realtime/audio_buffer.py
Normal file
246
static/scripts/vibevoice-asr/realtime/audio_buffer.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""
|
||||
Audio buffer management for real-time ASR.
|
||||
|
||||
Implements a ring buffer for efficient audio chunk management with overlap support.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
class AudioChunkInfo:
    """Information about an extracted audio chunk."""
    audio: np.ndarray   # float32 samples copied out of the buffer
    start_sample: int   # absolute sample index since session start (inclusive)
    end_sample: int     # absolute sample index since session start (exclusive)
    start_sec: float    # start_sample expressed in seconds
    end_sec: float      # end_sample expressed in seconds


class AudioBuffer:
    """
    Ring buffer for managing audio chunks with overlap support.

    Features:
    - Efficient memory management with fixed-size buffer
    - Overlap handling for continuous processing
    - Thread-safe operations
    - Automatic sample rate tracking

    Buffer position `p` always corresponds to the absolute sample index
    ``_total_samples_received - _write_pos + p``, which is what the
    timestamp calculations below rely on.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        chunk_duration_sec: float = 3.0,
        overlap_sec: float = 0.5,
        max_buffer_sec: float = 60.0,
    ):
        """
        Initialize the audio buffer.

        Args:
            sample_rate: Audio sample rate in Hz
            chunk_duration_sec: Duration of each processing chunk
            overlap_sec: Overlap between consecutive chunks
            max_buffer_sec: Maximum buffer duration (older data is discarded)
        """
        self.sample_rate = sample_rate
        self.chunk_size = int(chunk_duration_sec * sample_rate)
        self.overlap_size = int(overlap_sec * sample_rate)
        self.max_buffer_size = int(max_buffer_sec * sample_rate)

        # Main buffer (pre-allocated once)
        self._buffer = np.zeros(self.max_buffer_size, dtype=np.float32)
        self._write_pos = 0  # next position to write
        self._read_pos = 0   # start of unprocessed data
        self._total_samples_received = 0  # total samples since session start

        self._lock = threading.Lock()

    @property
    def samples_available(self) -> int:
        """Number of unprocessed samples in buffer."""
        with self._lock:
            return self._write_pos - self._read_pos

    @property
    def duration_available_sec(self) -> float:
        """Duration of unprocessed audio in seconds."""
        return self.samples_available / self.sample_rate

    @property
    def total_duration_sec(self) -> float:
        """Total duration of audio received since session start."""
        # Read under the lock for consistency with the writers.
        with self._lock:
            return self._total_samples_received / self.sample_rate

    def append(self, audio_chunk: np.ndarray) -> int:
        """
        Append audio chunk to the buffer.

        If the chunk (or chunk plus buffered data) exceeds the buffer
        capacity, the OLDEST samples are discarded so the newest audio is
        always kept.

        Args:
            audio_chunk: Audio data as float32 array (range: -1.0 to 1.0)

        Returns:
            Number of samples received (the full chunk length, even when
            older samples had to be discarded to make room)
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)

        # Ensure 1D
        if audio_chunk.ndim > 1:
            audio_chunk = audio_chunk.flatten()

        with self._lock:
            chunk_len = len(audio_chunk)

            # Fix: a chunk larger than the entire buffer previously caused a
            # numpy broadcast ValueError. Keep only the newest
            # max_buffer_size samples and drop everything older.
            if chunk_len > self.max_buffer_size:
                audio_chunk = audio_chunk[-self.max_buffer_size:]
                self._read_pos = 0
                self._write_pos = 0
                stored_len = self.max_buffer_size
            else:
                stored_len = chunk_len

            # Shift unprocessed data down if the tail has run out of room
            if self._write_pos + stored_len > self.max_buffer_size:
                self._compact_buffer()

            # Still not enough space? Discard the oldest unprocessed data.
            if self._write_pos + stored_len > self.max_buffer_size:
                overflow = (self._write_pos + stored_len) - self.max_buffer_size
                self._read_pos = min(self._read_pos + overflow, self._write_pos)
                self._compact_buffer()

            # Write to buffer
            end_pos = self._write_pos + stored_len
            self._buffer[self._write_pos:end_pos] = audio_chunk
            self._write_pos = end_pos
            # Count every received sample (even dropped ones) so absolute
            # timestamps stay continuous.
            self._total_samples_received += chunk_len

        return chunk_len

    def _compact_buffer(self) -> None:
        """Move unprocessed data to the beginning of the buffer.

        Caller must hold self._lock.
        """
        if self._read_pos > 0:
            unprocessed_len = self._write_pos - self._read_pos
            if unprocessed_len > 0:
                self._buffer[:unprocessed_len] = self._buffer[self._read_pos:self._write_pos]
            self._write_pos = unprocessed_len
            self._read_pos = 0

    def get_chunk_for_inference(self, min_duration_sec: float = 0.5) -> Optional[AudioChunkInfo]:
        """
        Get the next chunk for ASR inference.

        Returns a chunk of audio when enough data is available. The read
        position is NOT advanced here — call mark_processed() afterwards.

        Args:
            min_duration_sec: Minimum duration required to return a chunk

        Returns:
            AudioChunkInfo if enough data is available, None otherwise
        """
        min_samples = int(min_duration_sec * self.sample_rate)

        with self._lock:
            available = self._write_pos - self._read_pos

            if available < min_samples:
                return None

            # Calculate chunk boundaries (at most one chunk_size worth)
            chunk_start = self._read_pos
            chunk_end = min(self._read_pos + self.chunk_size, self._write_pos)
            actual_chunk_size = chunk_end - chunk_start

            # Extract audio (copy so the caller owns it)
            audio = self._buffer[chunk_start:chunk_end].copy()

            # Absolute timestamps via the position/sample-index invariant
            base_sample = self._total_samples_received - (self._write_pos - chunk_start)
            start_sec = base_sample / self.sample_rate
            end_sec = (base_sample + actual_chunk_size) / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=base_sample + actual_chunk_size,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def mark_processed(self, samples: int) -> None:
        """
        Mark samples as processed, advancing the read position.

        Keeps overlap_size samples for context in the next chunk.

        Args:
            samples: Number of samples that were processed
        """
        with self._lock:
            # Advance read position but keep overlap for context
            advance = max(0, samples - self.overlap_size)
            self._read_pos = min(self._read_pos + advance, self._write_pos)

    def get_segment(self, start_sec: float, end_sec: float) -> Optional[np.ndarray]:
        """
        Get a specific time segment from the buffer.

        Args:
            start_sec: Start time in seconds (relative to session start)
            end_sec: End time in seconds

        Returns:
            Audio segment if still resident in the buffer, None otherwise
        """
        start_sample = int(start_sec * self.sample_rate)
        end_sample = int(end_sec * self.sample_rate)

        with self._lock:
            # Absolute sample range currently held in the buffer
            buffer_start_sample = self._total_samples_received - self._write_pos
            buffer_end_sample = self._total_samples_received

            # Check if segment is in buffer
            if start_sample < buffer_start_sample or end_sample > buffer_end_sample:
                return None

            # Convert to buffer indices
            buf_start = start_sample - buffer_start_sample
            buf_end = end_sample - buffer_start_sample

            return self._buffer[buf_start:buf_end].copy()

    def get_all_unprocessed(self) -> Optional[AudioChunkInfo]:
        """
        Get all unprocessed audio.

        Returns:
            AudioChunkInfo with all unprocessed audio, or None if empty
        """
        with self._lock:
            if self._write_pos <= self._read_pos:
                return None

            audio = self._buffer[self._read_pos:self._write_pos].copy()
            base_sample = self._total_samples_received - (self._write_pos - self._read_pos)
            start_sec = base_sample / self.sample_rate
            end_sec = self._total_samples_received / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=self._total_samples_received,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def clear(self) -> None:
        """Clear the buffer and reset all positions."""
        with self._lock:
            self._buffer.fill(0)
            self._write_pos = 0
            self._read_pos = 0
            self._total_samples_received = 0

    def reset_read_position(self) -> None:
        """Reset read position to current write position (skip all unprocessed)."""
        with self._lock:
            self._read_pos = self._write_pos
||||
154
static/scripts/vibevoice-asr/realtime/models.py
Normal file
154
static/scripts/vibevoice-asr/realtime/models.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""
|
||||
Data models for real-time ASR WebSocket communication.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field, asdict
|
||||
import time
|
||||
|
||||
|
||||
class MessageType(str, Enum):
    """WebSocket message types.

    Subclasses ``str`` so members compare equal to their string values and
    serialize naturally in JSON payloads.
    """
    # Client -> Server
    AUDIO_CHUNK = "audio_chunk"
    CONFIG = "config"
    START = "start"
    STOP = "stop"

    # Server -> Client
    PARTIAL_RESULT = "partial_result"
    FINAL_RESULT = "final_result"
    VAD_EVENT = "vad_event"
    ERROR = "error"
    STATUS = "status"
||||
|
||||
|
||||
class VADEventType(str, Enum):
    """VAD event types (speech boundary markers)."""
    SPEECH_START = "speech_start"
    SPEECH_END = "speech_end"
||||
|
||||
|
||||
@dataclass
class SessionConfig:
    """Configuration for a real-time ASR session.

    All fields have defaults, so ``SessionConfig()`` is a valid baseline;
    use :meth:`from_dict` to build one from a client-supplied payload.
    """
    # Audio parameters
    sample_rate: int = 16000
    chunk_duration_sec: float = 3.0
    overlap_sec: float = 0.5

    # VAD parameters
    vad_threshold: float = 0.5
    min_speech_duration_ms: int = 250
    min_silence_duration_ms: int = 500
    min_volume_threshold: float = 0.01  # Minimum RMS volume (0.0-1.0) to consider as potential speech

    # ASR parameters
    max_new_tokens: int = 512
    temperature: float = 0.0
    context_info: Optional[str] = None

    # Behavior
    return_partial_results: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field into a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionConfig":
        """Build a config from *data*, silently ignoring unknown keys."""
        accepted = data.keys() & cls.__dataclass_fields__.keys()
        return cls(**{name: data[name] for name in accepted})
||||
|
||||
|
||||
@dataclass
class TranscriptionSegment:
    """A single transcription segment with speaker and timing metadata."""
    start_time: float
    end_time: float
    speaker_id: str
    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for JSON transport."""
        return {
            "start_time": self.start_time,
            "end_time": self.end_time,
            "speaker_id": self.speaker_id,
            "text": self.text,
        }
||||
|
||||
|
||||
@dataclass
class TranscriptionResult:
    """Transcription result message (partial or final)."""
    type: MessageType
    text: str
    is_final: bool
    segments: List[TranscriptionSegment] = field(default_factory=list)
    latency_ms: float = 0.0
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict, flattening the enum and segments."""
        return dict(
            type=self.type.value,
            text=self.text,
            is_final=self.is_final,
            segments=[segment.to_dict() for segment in self.segments],
            latency_ms=self.latency_ms,
            timestamp=self.timestamp,
        )
||||
|
||||
|
||||
@dataclass
class VADEvent:
    """VAD event message sent to the client."""
    type: MessageType = MessageType.VAD_EVENT
    event: VADEventType = VADEventType.SPEECH_START
    timestamp: float = field(default_factory=time.time)
    audio_timestamp_sec: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict (enums flattened to their values)."""
        return dict(
            type=self.type.value,
            event=self.event.value,
            timestamp=self.timestamp,
            audio_timestamp_sec=self.audio_timestamp_sec,
        )
||||
|
||||
|
||||
@dataclass
class StatusMessage:
    """Status update message sent to the client."""
    type: MessageType = MessageType.STATUS
    status: str = ""
    message: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict."""
        return dict(
            type=self.type.value,
            status=self.status,
            message=self.message,
            timestamp=self.timestamp,
        )
||||
|
||||
|
||||
@dataclass
class ErrorMessage:
    """Error message sent to the client."""
    type: MessageType = MessageType.ERROR
    error: str = ""
    code: str = ""
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict."""
        return dict(
            type=self.type.value,
            error=self.error,
            code=self.code,
            timestamp=self.timestamp,
        )
||||
|
||||
|
||||
@dataclass
class SpeechSegment:
    """Detected speech segment from VAD."""
    start_sample: int   # absolute sample index of speech start
    end_sample: int     # absolute sample index of speech end
    start_sec: float    # start_sample expressed in seconds
    end_sec: float      # end_sample expressed in seconds
    confidence: float = 1.0  # VAD confidence — presumably in [0, 1]; TODO confirm
||||
300
static/scripts/vibevoice-asr/realtime/server.py
Normal file
300
static/scripts/vibevoice-asr/realtime/server.py
Normal file
@ -0,0 +1,300 @@
|
||||
"""
|
||||
FastAPI WebSocket server for real-time ASR.
|
||||
|
||||
Provides WebSocket endpoint for streaming audio and receiving transcriptions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
import uvicorn
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from .models import (
|
||||
SessionConfig,
|
||||
TranscriptionResult,
|
||||
VADEvent,
|
||||
StatusMessage,
|
||||
ErrorMessage,
|
||||
MessageType,
|
||||
)
|
||||
from .asr_worker import ASRWorker
|
||||
from .session_manager import SessionManager
|
||||
|
||||
|
||||
# Global instances, populated by the lifespan handler on startup
asr_worker: Optional[ASRWorker] = None
session_manager: Optional[SessionManager] = None
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the ASR worker and session manager on
    startup, tear the session manager down on shutdown."""
    global asr_worker, session_manager

    print("Starting VibeVoice Realtime ASR Server...")

    # All knobs come from the environment, with sensible defaults.
    env = os.environ
    model_path = env.get("VIBEVOICE_MODEL_PATH", "microsoft/VibeVoice-ASR")
    device = env.get("VIBEVOICE_DEVICE", "cuda")
    attn_impl = env.get("VIBEVOICE_ATTN_IMPL", "flash_attention_2")

    asr_worker = ASRWorker(
        model_path=model_path,
        device=device,
        attn_implementation=attn_impl,
    )

    # Loading up front avoids a long first-request stall; set
    # VIBEVOICE_PRELOAD_MODEL=false to defer loading to the first request.
    if env.get("VIBEVOICE_PRELOAD_MODEL", "true").lower() == "true":
        print("Pre-loading ASR model...")
        asr_worker.load_model()

    session_manager = SessionManager(
        asr_worker=asr_worker,
        max_concurrent_sessions=int(env.get("VIBEVOICE_MAX_SESSIONS", "10")),
    )
    await session_manager.start()

    print("Server ready!")

    yield

    # Shutdown
    print("Shutting down...")
    await session_manager.stop()
||||
|
||||
|
||||
# Create FastAPI app, wired to the lifespan handler defined above
app = FastAPI(
    title="VibeVoice Realtime ASR",
    description="Real-time speech recognition using VibeVoice ASR",
    version="0.1.0",
    lifespan=lifespan,
)

# Mount static files (the browser client lives here); skipped gracefully
# when the directory was not shipped alongside this module.
static_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API info."""
|
||||
return {
|
||||
"service": "VibeVoice Realtime ASR",
|
||||
"version": "0.1.0",
|
||||
"endpoints": {
|
||||
"websocket": "/ws/asr/{session_id}",
|
||||
"health": "/health",
|
||||
"stats": "/stats",
|
||||
"client": "/static/realtime_client.html",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"model_loaded": asr_worker.is_loaded if asr_worker else False,
|
||||
"active_sessions": len(session_manager._sessions) if session_manager else 0,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
"""Get server statistics."""
|
||||
if session_manager is None:
|
||||
raise HTTPException(status_code=503, detail="Server not initialized")
|
||||
|
||||
return session_manager.get_stats()
|
||||
|
||||
|
||||
@app.websocket("/ws/asr/{session_id}")
async def websocket_asr(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for real-time ASR.

    Protocol:
    1. Client connects and optionally sends config message
    2. Client sends binary audio chunks (PCM 16-bit, 16kHz, mono)
    3. Server sends JSON messages with transcription results

    Message types (server -> client):
    - partial_result: Intermediate transcription
    - final_result: Complete transcription for a segment
    - vad_event: Speech start/end events
    - error: Error messages
    - status: Status updates
    """
    await websocket.accept()

    # Send connection confirmation
    await websocket.send_json(
        StatusMessage(
            status="connected",
            message=f"Session {session_id} connected",
        ).to_dict()
    )

    # Result callback: forwards transcription results to the client.
    # Send failures are swallowed (the disconnect path handles cleanup).
    async def on_result(result: TranscriptionResult):
        try:
            await websocket.send_json(result.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send result: {e}")

    # VAD event callback: forwards speech start/end events to the client.
    async def on_vad_event(event: VADEvent):
        try:
            await websocket.send_json(event.to_dict())
        except Exception as e:
            print(f"[{session_id}] Failed to send VAD event: {e}")

    # Create session (returns None when the manager's session limit is reached;
    # an existing session with the same id is reused by the manager).
    session = await session_manager.create_session(
        session_id=session_id,
        on_result=on_result,
        on_vad_event=on_vad_event,
    )

    if session is None:
        await websocket.send_json(
            ErrorMessage(
                error="Maximum sessions reached",
                code="MAX_SESSIONS",
            ).to_dict()
        )
        await websocket.close()
        return

    await websocket.send_json(
        StatusMessage(
            status="ready",
            message="Session ready for audio",
        ).to_dict()
    )

    try:
        while True:
            # Receive message (raw ASGI event: may carry "bytes" or "text")
            message = await websocket.receive()

            if message["type"] == "websocket.disconnect":
                break

            # Handle binary audio data (expected: PCM s16le, 16 kHz, mono).
            # Processing errors are logged but do not kill the connection.
            if "bytes" in message:
                audio_data = message["bytes"]
                try:
                    await session.process_audio_chunk(audio_data)
                except Exception as e:
                    print(f"[{session_id}] Error processing audio: {e}")
                    import traceback
                    traceback.print_exc()

            # Handle JSON control messages ("config", "stop", "ping")
            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    msg_type = data.get("type")

                    if msg_type == "config":
                        # Update session config (partial update; see
                        # RealtimeSession.update_config)
                        config = SessionConfig.from_dict(data.get("config", {}))
                        session.update_config(config)
                        await websocket.send_json(
                            StatusMessage(
                                status="config_updated",
                                message="Configuration updated",
                            ).to_dict()
                        )

                    elif msg_type == "stop":
                        # Flush remaining audio, acknowledge, then exit the
                        # receive loop; the finally-block closes the session.
                        await session.flush()
                        await websocket.send_json(
                            StatusMessage(
                                status="stopped",
                                message="Session stopped",
                            ).to_dict()
                        )
                        break

                    elif msg_type == "ping":
                        # Lightweight keepalive; echoes server wall-clock time
                        await websocket.send_json({"type": "pong", "timestamp": time.time()})

                except json.JSONDecodeError:
                    await websocket.send_json(
                        ErrorMessage(
                            error="Invalid JSON",
                            code="INVALID_JSON",
                        ).to_dict()
                    )

    except WebSocketDisconnect:
        print(f"[{session_id}] Client disconnected")
    except Exception as e:
        print(f"[{session_id}] Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up session (flushes pending audio and releases resources)
        await session_manager.close_session(session_id)
        print(f"[{session_id}] Session closed")
|
||||
|
||||
|
||||
def main():
    """Main entry point.

    Parses CLI arguments, forwards them to the FastAPI lifespan handler via
    environment variables (the lifespan cannot see argparse results), and
    starts the uvicorn server.
    """
    import argparse

    parser = argparse.ArgumentParser(description="VibeVoice Realtime ASR Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--model-path", type=str, default="microsoft/VibeVoice-ASR",
                        help="Path to VibeVoice ASR model")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--max-sessions", type=int, default=10, help="Max concurrent sessions")
    parser.add_argument("--no-preload", action="store_true", help="Don't preload model")

    args = parser.parse_args()

    # Set environment variables for lifespan (read back at app startup)
    os.environ["VIBEVOICE_MODEL_PATH"] = args.model_path
    os.environ["VIBEVOICE_DEVICE"] = args.device
    os.environ["VIBEVOICE_MAX_SESSIONS"] = str(args.max_sessions)
    os.environ["VIBEVOICE_PRELOAD_MODEL"] = "false" if args.no_preload else "true"

    print(f"Starting server on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print(f"Max sessions: {args.max_sessions}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )


if __name__ == "__main__":
    main()
|
||||
401
static/scripts/vibevoice-asr/realtime/session_manager.py
Normal file
401
static/scripts/vibevoice-asr/realtime/session_manager.py
Normal file
@ -0,0 +1,401 @@
|
||||
"""
|
||||
Session manager for real-time ASR.
|
||||
|
||||
Manages multiple concurrent client sessions with resource isolation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Optional, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
import threading
|
||||
|
||||
from .models import (
|
||||
SessionConfig,
|
||||
TranscriptionResult,
|
||||
VADEvent,
|
||||
StatusMessage,
|
||||
ErrorMessage,
|
||||
MessageType,
|
||||
SpeechSegment,
|
||||
)
|
||||
from .audio_buffer import AudioBuffer
|
||||
from .vad_processor import VADProcessor
|
||||
from .asr_worker import ASRWorker
|
||||
|
||||
|
||||
@dataclass
class SessionStats:
    """Statistics for a session."""
    # Unix timestamp when the session was created
    created_at: float = field(default_factory=time.time)
    # Unix timestamp of the most recent client activity (used for idle timeout)
    last_activity: float = field(default_factory=time.time)
    # Total seconds of audio received from the client
    audio_received_sec: float = 0.0
    # Number of raw audio chunks received over the socket
    chunks_received: int = 0
    # Number of speech segments handed to the ASR worker
    segments_transcribed: int = 0
    # Sum of per-segment ASR latencies in ms (divide by segments_transcribed for the mean)
    total_latency_ms: float = 0.0
|
||||
|
||||
|
||||
class RealtimeSession:
    """
    A single real-time ASR session.

    Manages audio buffering, VAD, and ASR for one client connection.
    """

    def __init__(
        self,
        session_id: str,
        asr_worker: ASRWorker,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ):
        """
        Initialize a session.

        Args:
            session_id: Unique session identifier
            asr_worker: Shared ASR worker instance
            config: Session configuration
            on_result: Callback for transcription results
            on_vad_event: Callback for VAD events
        """
        self.session_id = session_id
        self.asr_worker = asr_worker
        self.config = config or SessionConfig()
        self.on_result = on_result
        self.on_vad_event = on_vad_event

        # Components: rolling audio store and streaming voice-activity detector
        self.audio_buffer = AudioBuffer(
            sample_rate=self.config.sample_rate,
            chunk_duration_sec=self.config.chunk_duration_sec,
            overlap_sec=self.config.overlap_sec,
        )

        self.vad_processor = VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )

        # State
        self.is_active = True
        self.stats = SessionStats()
        self._processing_lock = asyncio.Lock()
        # NOTE(review): _pending_tasks is never appended to in this class —
        # looks like a leftover; confirm before relying on it.
        self._pending_tasks: list = []

    async def process_audio_chunk(self, audio_data: bytes) -> None:
        """
        Process an incoming audio chunk.

        Buffers the audio, runs VAD, emits VAD events, and transcribes any
        speech segments the VAD has completed.

        Args:
            audio_data: Raw PCM audio data (16-bit, 16kHz, mono)
        """
        if not self.is_active:
            return

        self.stats.last_activity = time.time()
        self.stats.chunks_received += 1

        # Convert bytes to float32 array normalized to [-1.0, 1.0)
        import numpy as np
        audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate

        # Add to buffer (kept so segment audio can be retrieved later by time range)
        self.audio_buffer.append(audio_float)

        # Process with VAD: returns completed speech segments plus start/end events
        segments, events = self.vad_processor.process(audio_float)

        # Send VAD events
        if self.on_vad_event:
            for event in events:
                await self._send_callback(self.on_vad_event, event)

        # Process completed speech segments; per-segment errors are logged
        # so one bad segment does not kill the session.
        for segment in segments:
            try:
                await self._transcribe_segment(segment)
            except Exception as e:
                print(f"[Session {self.session_id}] Transcription error: {e}")
                import traceback
                traceback.print_exc()

    async def _transcribe_segment(self, segment: SpeechSegment) -> None:
        """Transcribe a detected speech segment and deliver results via callbacks."""
        # Get audio for segment from buffer (may fail if it was already evicted)
        audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec)

        if audio is None or len(audio) == 0:
            print(f"[Session {self.session_id}] Could not retrieve audio for segment")
            return

        self.stats.segments_transcribed += 1

        # Intermediate results are forwarded through the same on_result callback
        async def on_partial(result: TranscriptionResult):
            if self.on_result:
                await self._send_callback(self.on_result, result)

        # Run transcription
        result = await self.asr_worker.transcribe_segment(
            audio=audio,
            sample_rate=self.config.sample_rate,
            context_info=self.config.context_info,
            segment_start_sec=segment.start_sec,
            segment_end_sec=segment.end_sec,
            config=self.config,
            on_partial=on_partial if self.config.return_partial_results else None,
        )

        self.stats.total_latency_ms += result.latency_ms

        # Send final result
        if self.on_result:
            await self._send_callback(self.on_result, result)

    async def _send_callback(self, callback: Callable, data: Any) -> None:
        """Send data via callback, handling both sync and async callables."""
        try:
            if asyncio.iscoroutinefunction(callback):
                await callback(data)
            else:
                callback(data)
        except Exception as e:
            print(f"[Session {self.session_id}] Callback error: {e}")

    async def flush(self) -> None:
        """
        Flush any remaining audio and force transcription.

        Called when session ends to process any remaining speech.
        """
        # Force end any active speech
        segment = self.vad_processor.force_end_speech()
        if segment:
            await self._transcribe_segment(segment)

        # Also check for any unprocessed audio in buffer
        chunk_info = self.audio_buffer.get_all_unprocessed()
        if chunk_info and len(chunk_info.audio) > self.config.sample_rate * 0.5:
            # More than 0.5 seconds of unprocessed audio: wrap it in a
            # synthetic segment so it gets transcribed too.
            forced_segment = SpeechSegment(
                start_sample=chunk_info.start_sample,
                end_sample=chunk_info.end_sample,
                start_sec=chunk_info.start_sec,
                end_sec=chunk_info.end_sec,
            )
            await self._transcribe_segment(forced_segment)

    def update_config(self, new_config: SessionConfig) -> None:
        """Update session configuration (partial update supported)."""
        # Merge with existing config - only update non-default values.
        # NOTE(review): a field explicitly set to its default (e.g.
        # vad_threshold=0.5) is indistinguishable from "not provided" and
        # will be ignored — confirm this is acceptable for clients.
        if new_config.vad_threshold != 0.5:
            self.config.vad_threshold = new_config.vad_threshold
        if new_config.min_speech_duration_ms != 250:
            self.config.min_speech_duration_ms = new_config.min_speech_duration_ms
        if new_config.min_silence_duration_ms != 500:
            self.config.min_silence_duration_ms = new_config.min_silence_duration_ms
        if new_config.min_volume_threshold != 0.01:
            self.config.min_volume_threshold = new_config.min_volume_threshold
        if new_config.context_info is not None:
            self.config.context_info = new_config.context_info

        # Recreate VAD processor with new parameters (resets its internal state)
        self.vad_processor = VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )
        print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, "
              f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, "
              f"min_volume={self.config.min_volume_threshold}")

    def close(self) -> None:
        """Close the session and release resources."""
        self.is_active = False
        self.audio_buffer.clear()
        self.vad_processor.reset()

    def get_stats(self) -> Dict:
        """Get session statistics as a JSON-serializable dict."""
        return {
            "session_id": self.session_id,
            "created_at": self.stats.created_at,
            "last_activity": self.stats.last_activity,
            "duration_sec": time.time() - self.stats.created_at,
            "audio_received_sec": self.stats.audio_received_sec,
            "chunks_received": self.stats.chunks_received,
            "segments_transcribed": self.stats.segments_transcribed,
            # Mean ASR latency; guarded against division by zero
            "avg_latency_ms": (
                self.stats.total_latency_ms / self.stats.segments_transcribed
                if self.stats.segments_transcribed > 0 else 0
            ),
            "is_active": self.is_active,
            "vad_speech_active": self.vad_processor.is_speech_active,
        }
|
||||
|
||||
|
||||
class SessionManager:
    """
    Manages multiple concurrent ASR sessions.

    Features:
    - Session creation and cleanup
    - Resource limiting (max concurrent sessions)
    - Idle session timeout
    - Shared ASR worker management
    """

    def __init__(
        self,
        asr_worker: ASRWorker,
        max_concurrent_sessions: int = 10,
        session_timeout_sec: float = 300.0,
    ):
        """
        Initialize the session manager.

        Args:
            asr_worker: Shared ASR worker
            max_concurrent_sessions: Maximum number of concurrent sessions
            session_timeout_sec: Timeout for idle sessions
        """
        self.asr_worker = asr_worker
        self.max_sessions = max_concurrent_sessions
        self.session_timeout = session_timeout_sec

        # Session registry, guarded by an asyncio lock
        self._sessions: Dict[str, RealtimeSession] = {}
        self._lock = asyncio.Lock()

        # Cleanup task (started lazily in start())
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the session manager (spawns the idle-session cleanup task)."""
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Stop the session manager and close all sessions."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        async with self._lock:
            for session in self._sessions.values():
                session.close()
            self._sessions.clear()

    async def create_session(
        self,
        session_id: Optional[str] = None,
        config: Optional[SessionConfig] = None,
        on_result: Optional[Callable[[TranscriptionResult], Any]] = None,
        on_vad_event: Optional[Callable[[VADEvent], Any]] = None,
    ) -> Optional[RealtimeSession]:
        """
        Create a new session.

        Args:
            session_id: Optional session ID (generated if not provided)
            config: Session configuration
            on_result: Callback for results
            on_vad_event: Callback for VAD events

        Returns:
            Created session, or None if limit reached
        """
        async with self._lock:
            # Check session limit
            if len(self._sessions) >= self.max_sessions:
                return None

            # Generate session ID if not provided (8 hex chars of a uuid4)
            if session_id is None:
                session_id = str(uuid.uuid4())[:8]

            # Check for duplicate: an existing session is returned as-is
            # (its callbacks are NOT replaced with the new ones)
            if session_id in self._sessions:
                return self._sessions[session_id]

            # Create session
            session = RealtimeSession(
                session_id=session_id,
                asr_worker=self.asr_worker,
                config=config,
                on_result=on_result,
                on_vad_event=on_vad_event,
            )

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[RealtimeSession]:
        """Get a session by ID, or None if it does not exist."""
        async with self._lock:
            return self._sessions.get(session_id)

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        Flushes any pending audio before releasing resources.

        Args:
            session_id: Session to close

        Returns:
            True if session was found and closed
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.flush()
                session.close()
                return True
            return False

    async def _cleanup_loop(self) -> None:
        """Background task to clean up idle sessions."""
        while True:
            try:
                await asyncio.sleep(60)  # Check every minute

                current_time = time.time()
                sessions_to_close = []

                # Collect idle ids under the lock, then close outside it
                # (close_session re-acquires the lock itself)
                async with self._lock:
                    for session_id, session in self._sessions.items():
                        idle_time = current_time - session.stats.last_activity
                        if idle_time > self.session_timeout:
                            sessions_to_close.append(session_id)

                for session_id in sessions_to_close:
                    print(f"Closing idle session: {session_id}")
                    await self.close_session(session_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Cleanup error: {e}")

    def get_stats(self) -> Dict:
        """Get manager statistics, including per-session stats."""
        return {
            "active_sessions": len(self._sessions),
            "max_sessions": self.max_sessions,
            "session_timeout_sec": self.session_timeout,
            "sessions": {
                sid: session.get_stats()
                for sid, session in self._sessions.items()
            },
        }
|
||||
295
static/scripts/vibevoice-asr/realtime/vad_processor.py
Normal file
295
static/scripts/vibevoice-asr/realtime/vad_processor.py
Normal file
@ -0,0 +1,295 @@
|
||||
"""
|
||||
Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version).
|
||||
|
||||
Detects speech segments in real-time audio streams.
|
||||
Uses ONNX runtime to avoid torchaudio dependency issues.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
from .models import SpeechSegment, VADEvent, VADEventType, MessageType
|
||||
|
||||
|
||||
@dataclass
class VADState:
    """Internal state of the VAD processor."""
    # True while a confirmed speech segment is in progress
    is_speech_active: bool = False
    # Absolute sample index where the current speech segment started
    speech_start_sample: int = 0
    # Absolute sample index where trailing silence began (0 doubles as a
    # "no silence yet" sentinel)
    silence_start_sample: int = 0
    # Most recent speech probability reported by the model (0.0-1.0)
    last_speech_prob: float = 0.0
    # Total samples seen across all process() calls (absolute timeline origin)
    total_samples_processed: int = 0
|
||||
|
||||
|
||||
class VADProcessor:
    """
    Voice Activity Detection using Silero-VAD (ONNX version).

    Features:
    - Real-time speech detection
    - Configurable thresholds for speech/silence duration
    - Event generation for speech start/end
    - Thread-safe operations
    - No torchaudio dependency (uses ONNX runtime)
    """

    # Silero VAD ONNX model URL (downloaded once and cached under ~/.cache)
    ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        window_size_samples: int = 512,
        min_volume_threshold: float = 0.01,
    ):
        """
        Initialize the VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 16000 for Silero-VAD)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger speech_start
            min_silence_duration_ms: Minimum silence duration to trigger speech_end
            window_size_samples: VAD window size (512 for 16kHz = 32ms)
            min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech

        Raises:
            ValueError: If sample_rate is not 16000.
            RuntimeError: If the ONNX model cannot be loaded.
        """
        if sample_rate != 16000:
            raise ValueError("Silero-VAD requires 16kHz sample rate")

        self.sample_rate = sample_rate
        self.threshold = threshold
        # Durations converted from ms to samples for fast comparisons
        self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000)
        self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)
        self.window_size = window_size_samples
        self.min_volume_threshold = min_volume_threshold

        # Load ONNX model (downloads on first use)
        self._session = None
        self._load_model()

        # ONNX model state - single state tensor (size depends on model version)
        # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128)
        self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

        # State
        self._state = VADState()
        self._lock = threading.Lock()

        # Pending speech segment (being accumulated before min-duration is met)
        self._pending_segment_start: Optional[int] = None

    def _get_model_path(self) -> str:
        """Get path to ONNX model, downloading it into ~/.cache/silero-vad if necessary."""
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad")
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, "silero_vad.onnx")

        if not os.path.exists(model_path):
            print(f"Downloading Silero-VAD ONNX model to {model_path}...")
            urllib.request.urlretrieve(self.ONNX_MODEL_URL, model_path)
            print("Download complete.")

        return model_path

    def _load_model(self) -> None:
        """Load Silero-VAD ONNX model on the CPU execution provider."""
        try:
            import onnxruntime as ort

            model_path = self._get_model_path()
            self._session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Silero-VAD ONNX model loaded from {model_path}")

        except Exception as e:
            raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}")

    def _run_inference(self, audio_window: np.ndarray) -> float:
        """Run VAD inference on a single window; returns speech probability (0.0-1.0)."""
        # Prepare input: model expects shape (batch=1, samples) float32
        audio_input = audio_window.reshape(1, -1).astype(np.float32)
        sr_input = np.array([self.sample_rate], dtype=np.int64)

        # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c'
        outputs = self._session.run(
            ['output', 'stateN'],
            {
                'input': audio_input,
                'sr': sr_input,
                'state': self._state_tensor,
            }
        )

        # Update recurrent state so consecutive windows stay temporally coherent
        speech_prob = outputs[0][0][0]
        self._state_tensor = outputs[1]

        return float(speech_prob)

    def reset(self) -> None:
        """Reset VAD state for a new session."""
        with self._lock:
            self._state = VADState()
            self._pending_segment_start = None
            # Reset the model's recurrent state tensor as well
            self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

    def process(
        self,
        audio_chunk: np.ndarray,
        return_events: bool = True,
    ) -> Tuple[List[SpeechSegment], List[VADEvent]]:
        """
        Process an audio chunk and detect speech segments.

        Audio is consumed in fixed windows of self.window_size samples;
        a trailing partial window is skipped (those samples still advance
        the absolute sample counter).

        Args:
            audio_chunk: Audio data as float32 array
            return_events: Whether to return VAD events

        Returns:
            Tuple of (completed_segments, events)
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)

        completed_segments: List[SpeechSegment] = []
        events: List[VADEvent] = []

        with self._lock:
            # Process in windows, tracking position on the absolute timeline
            chunk_start_sample = self._state.total_samples_processed
            num_windows = len(audio_chunk) // self.window_size

            for i in range(num_windows):
                window_start = i * self.window_size
                window_end = window_start + self.window_size
                window = audio_chunk[window_start:window_end]

                # Check volume (RMS) threshold first — cheap rejection that
                # also skips the model call on near-silent windows
                rms = np.sqrt(np.mean(window ** 2))
                if rms < self.min_volume_threshold:
                    # Volume too low, treat as silence
                    speech_prob = 0.0
                else:
                    # Get speech probability from VAD model
                    speech_prob = self._run_inference(window)

                self._state.last_speech_prob = speech_prob

                current_sample = chunk_start_sample + window_end
                is_speech = speech_prob >= self.threshold

                # State machine for speech detection:
                # pending -> active (after min_speech), active -> ended (after min_silence)
                if is_speech:
                    if not self._state.is_speech_active:
                        # Potential speech start
                        if self._pending_segment_start is None:
                            self._pending_segment_start = current_sample - self.window_size

                        # Check if speech duration exceeds minimum
                        speech_duration = current_sample - self._pending_segment_start
                        if speech_duration >= self.min_speech_samples:
                            self._state.is_speech_active = True
                            self._state.speech_start_sample = self._pending_segment_start

                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_START,
                                    audio_timestamp_sec=self._pending_segment_start / self.sample_rate,
                                ))
                    else:
                        # Continue speech, reset silence counter (0 = sentinel)
                        self._state.silence_start_sample = 0
                else:
                    if self._state.is_speech_active:
                        # Potential speech end
                        if self._state.silence_start_sample == 0:
                            self._state.silence_start_sample = current_sample

                        # Check if silence duration exceeds minimum
                        silence_duration = current_sample - self._state.silence_start_sample
                        if silence_duration >= self.min_silence_samples:
                            # Speech ended - create completed segment
                            # (segment ends where silence began, so trailing
                            # silence is excluded from the segment)
                            segment = SpeechSegment(
                                start_sample=self._state.speech_start_sample,
                                end_sample=self._state.silence_start_sample,
                                start_sec=self._state.speech_start_sample / self.sample_rate,
                                end_sec=self._state.silence_start_sample / self.sample_rate,
                            )
                            completed_segments.append(segment)

                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_END,
                                    audio_timestamp_sec=self._state.silence_start_sample / self.sample_rate,
                                ))

                            # Reset state
                            self._state.is_speech_active = False
                            self._state.speech_start_sample = 0
                            self._state.silence_start_sample = 0
                            self._pending_segment_start = None
                    else:
                        # No speech, reset pending
                        self._pending_segment_start = None

            # Update total samples processed (includes any trailing partial window)
            self._state.total_samples_processed += len(audio_chunk)

        return completed_segments, events

    def force_end_speech(self) -> Optional[SpeechSegment]:
        """
        Force end of current speech segment (e.g., when session ends).

        Returns:
            Completed speech segment if speech was active, None otherwise
        """
        with self._lock:
            if self._state.is_speech_active:
                # The forced segment runs to the end of the processed audio
                segment = SpeechSegment(
                    start_sample=self._state.speech_start_sample,
                    end_sample=self._state.total_samples_processed,
                    start_sec=self._state.speech_start_sample / self.sample_rate,
                    end_sec=self._state.total_samples_processed / self.sample_rate,
                )

                self._state.is_speech_active = False
                self._state.speech_start_sample = 0
                self._state.silence_start_sample = 0
                self._pending_segment_start = None

                return segment

            return None

    @property
    def is_speech_active(self) -> bool:
        """Check if speech is currently active."""
        with self._lock:
            return self._state.is_speech_active

    @property
    def last_speech_probability(self) -> float:
        """Get the last computed speech probability."""
        with self._lock:
            return self._state.last_speech_prob

    @property
    def current_speech_duration_sec(self) -> float:
        """Get duration of current speech segment (if active)."""
        with self._lock:
            if not self._state.is_speech_active:
                return 0.0
            return (self._state.total_samples_processed - self._state.speech_start_sample) / self.sample_rate
|
||||
7
static/scripts/vibevoice-asr/requirements-realtime.txt
Normal file
7
static/scripts/vibevoice-asr/requirements-realtime.txt
Normal file
@ -0,0 +1,7 @@
|
||||
# Real-time ASR dependencies
|
||||
fastapi>=0.100.0
|
||||
uvicorn[standard]>=0.23.0
|
||||
websockets>=11.0
|
||||
numpy>=1.24.0
|
||||
soundfile>=0.12.0
|
||||
onnxruntime
|
||||
73
static/scripts/vibevoice-asr/run_all.sh
Executable file
73
static/scripts/vibevoice-asr/run_all.sh
Executable file
@ -0,0 +1,73 @@
|
||||
#!/bin/bash
# Run both Gradio demo and Realtime ASR server
#
# Usage:
#   ./run_all.sh
#
# Ports:
#   - 7860: Gradio UI (batch ASR)
#   - 8000: WebSocket API (realtime ASR)

set -e

# Run from the directory containing this script so relative paths resolve
cd "$(dirname "$0")"

# Configuration (overridable via environment variables)
GRADIO_HOST="${GRADIO_HOST:-0.0.0.0}"
GRADIO_PORT="${GRADIO_PORT:-7860}"
REALTIME_HOST="${REALTIME_HOST:-0.0.0.0}"
REALTIME_PORT="${REALTIME_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"

echo "=========================================="
echo "VibeVoice ASR - All Services"
echo "=========================================="
echo ""
echo "Starting services:"
echo "  - Gradio UI:      http://$GRADIO_HOST:$GRADIO_PORT"
echo "  - Realtime ASR:   http://$REALTIME_HOST:$REALTIME_PORT"
echo "  - Test Client:    http://$REALTIME_HOST:$REALTIME_PORT/static/realtime_client.html"
echo ""
echo "Model: $MODEL_PATH"
echo "=========================================="
echo ""

# Trap to clean up background processes on exit
cleanup() {
    echo ""
    echo "Shutting down..."
    kill $REALTIME_PID 2>/dev/null || true
    kill $GRADIO_PID 2>/dev/null || true
    wait
    echo "All services stopped."
}
trap cleanup EXIT INT TERM

# Start Realtime ASR server in background
# (--no-preload so the model is loaded lazily; the Gradio demo also loads it)
echo "[1/2] Starting Realtime ASR server..."
python -m realtime.server \
    --host "$REALTIME_HOST" \
    --port "$REALTIME_PORT" \
    --model-path "$MODEL_PATH" \
    --no-preload &
REALTIME_PID=$!

# Wait a moment for the server to initialize
sleep 2

# Start Gradio demo in background
echo "[2/2] Starting Gradio demo..."
python demo/vibevoice_asr_gradio_demo.py \
    --host "$GRADIO_HOST" \
    --port "$GRADIO_PORT" \
    --model_path "$MODEL_PATH" &
GRADIO_PID=$!

echo ""
echo "Both services started. Press Ctrl+C to stop."
echo ""

# Wait for either process to exit
# NOTE(review): `wait -n` with explicit PID arguments requires bash >= 5.1;
# on older bash this errors out — confirm the target image's bash version.
wait -n $REALTIME_PID $GRADIO_PID

# If one exits, the trap will clean up the other
|
||||
41
static/scripts/vibevoice-asr/run_realtime.sh
Executable file
41
static/scripts/vibevoice-asr/run_realtime.sh
Executable file
@ -0,0 +1,41 @@
|
||||
#!/bin/bash
# Run VibeVoice Realtime ASR Server
#
# Usage:
#   ./run_realtime.sh [options]
#
# Any extra options are forwarded verbatim to the server (see --help).

set -e

cd "$(dirname "$0")"

# Defaults, each overridable through its VIBEVOICE_* environment variable.
HOST="${VIBEVOICE_HOST:-0.0.0.0}"
PORT="${VIBEVOICE_PORT:-8000}"
MODEL_PATH="${VIBEVOICE_MODEL_PATH:-microsoft/VibeVoice-ASR}"
DEVICE="${VIBEVOICE_DEVICE:-cuda}"
MAX_SESSIONS="${VIBEVOICE_MAX_SESSIONS:-10}"

# Startup banner with connection hints.
cat <<BANNER
==========================================
VibeVoice Realtime ASR Server
==========================================
Host: $HOST
Port: $PORT
Model: $MODEL_PATH
Device: $DEVICE
Max Sessions: $MAX_SESSIONS
==========================================

Web client: http://$HOST:$PORT/static/realtime_client.html
WebSocket: ws://$HOST:$PORT/ws/asr/{session_id}

BANNER

# Launch the server, forwarding any caller-supplied arguments.
server_args=(
    --host "$HOST"
    --port "$PORT"
    --model-path "$MODEL_PATH"
    --device "$DEVICE"
    --max-sessions "$MAX_SESSIONS"
)
python -m realtime.server "${server_args[@]}" "$@"
||||
287
static/scripts/vibevoice-asr/setup.sh
Normal file
287
static/scripts/vibevoice-asr/setup.sh
Normal file
@ -0,0 +1,287 @@
|
||||
#!/bin/bash
|
||||
# VibeVoice-ASR Setup Script for DGX Spark
|
||||
# Downloads and builds the VibeVoice-ASR container
|
||||
#
|
||||
# Usage:
|
||||
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash
|
||||
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s build
|
||||
# curl -sL https://docs.techswan.online/scripts/vibevoice-asr/setup.sh | bash -s serve
|
||||
|
||||
set -e
|
||||
|
||||
BASE_URL="https://docs.techswan.online/scripts/vibevoice-asr"
|
||||
INSTALL_DIR="${VIBEVOICE_DIR:-$HOME/vibevoice-asr}"
|
||||
IMAGE_NAME="vibevoice-asr:dgx-spark"
|
||||
CONTAINER_NAME="vibevoice-asr"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
    # Green [INFO] prefix; %b expands the ANSI escapes in the color vars
    # and in the message, matching echo -e behavior.
    printf '%b[INFO]%b %b\n' "$GREEN" "$NC" "$1"
}
|
||||
|
||||
log_warn() {
    # Yellow [WARN] prefix; %b expands ANSI escapes like echo -e does.
    printf '%b[WARN]%b %b\n' "$YELLOW" "$NC" "$1"
}
|
||||
|
||||
log_error() {
    # Red [ERROR] prefix; %b expands ANSI escapes like echo -e does.
    printf '%b[ERROR]%b %b\n' "$RED" "$NC" "$1"
}
|
||||
|
||||
log_step() {
    # Blue [STEP] prefix; %b expands ANSI escapes like echo -e does.
    printf '%b[STEP]%b %b\n' "$BLUE" "$NC" "$1"
}
|
||||
|
||||
download_file() {
    # Fetch $1 (URL) into $2 (destination path), creating parent dirs.
    # Prefers curl, falls back to wget; aborts if neither is available.
    local url="$1"
    local dest="$2"
    # Declare and assign separately so a dirname failure is not masked
    # by `local`'s own exit status (shellcheck SC2155, matters with set -e).
    local dir
    dir=$(dirname "$dest")

    mkdir -p "$dir"

    if command -v curl &> /dev/null; then
        # -f: fail on HTTP errors (404/500) instead of silently saving
        # the server's error page as the downloaded file.
        curl -fsL "$url" -o "$dest"
    elif command -v wget &> /dev/null; then
        wget -q "$url" -O "$dest"
    else
        log_error "curl or wget is required"
        exit 1
    fi
}
|
||||
|
||||
download_files() {
    # Mirror the published script bundle into $INSTALL_DIR.
    log_step "Downloading VibeVoice-ASR files to $INSTALL_DIR..."

    mkdir -p "$INSTALL_DIR"
    cd "$INSTALL_DIR"

    local name

    # Core files at the bundle root.
    for name in \
        Dockerfile \
        requirements-realtime.txt \
        test_vibevoice.py \
        vibevoice_asr_gradio_demo_patched.py \
        run_realtime.sh \
        run_all.sh
    do
        log_info "Downloading $name..."
        download_file "$BASE_URL/$name" "$INSTALL_DIR/$name"
    done

    # Realtime ASR python module.
    mkdir -p "$INSTALL_DIR/realtime"
    for name in \
        __init__.py \
        models.py \
        server.py \
        asr_worker.py \
        session_manager.py \
        audio_buffer.py \
        vad_processor.py
    do
        log_info "Downloading realtime/$name..."
        download_file "$BASE_URL/realtime/$name" "$INSTALL_DIR/realtime/$name"
    done

    # Static assets (browser test client).
    mkdir -p "$INSTALL_DIR/static"
    log_info "Downloading static/realtime_client.html..."
    download_file "$BASE_URL/static/realtime_client.html" "$INSTALL_DIR/static/realtime_client.html"

    # Restore the executable bit lost by plain HTTP download.
    chmod +x "$INSTALL_DIR/run_realtime.sh" "$INSTALL_DIR/run_all.sh"

    log_info "All files downloaded to $INSTALL_DIR"
}
|
||||
|
||||
check_prerequisites() {
    # Verify Docker / GPU tooling before attempting a build.
    log_step "Checking prerequisites..."

    # Docker is mandatory.
    if ! command -v docker &> /dev/null; then
        log_error "Docker is not installed"
        exit 1
    fi

    # The NVIDIA runtime is recommended but its absence is not fatal.
    docker info 2>/dev/null | grep -q "Runtimes.*nvidia" \
        || log_warn "NVIDIA Docker runtime may not be configured"

    # Report GPU details when nvidia-smi exists on the host.
    if command -v nvidia-smi &> /dev/null; then
        log_info "GPU detected:"
        nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    else
        log_warn "nvidia-smi not found on host"
    fi

    log_info "Prerequisites check complete"
}
|
||||
|
||||
build_image() {
    # Build the container image from $INSTALL_DIR/Dockerfile.
    log_step "Building Docker image: ${IMAGE_NAME}"
    log_info "This may take several minutes..."

    cd "$INSTALL_DIR"

    # --network=host lets the build stage reach package mirrors even on
    # hosts with restricted bridge networking.
    docker build --network=host -t "$IMAGE_NAME" -f Dockerfile .

    log_info "Docker image built successfully: ${IMAGE_NAME}"
}
|
||||
|
||||
run_container() {
    # Start the VibeVoice container in one of several modes:
    # interactive (default), test, demo, realtime, or serve.
    local mode="${1:-interactive}"

    log_step "Running container in ${mode} mode..."

    # Replace any previously running container with the same name.
    if docker ps -q -f name="$CONTAINER_NAME" | grep -q .; then
        log_warn "Stopping existing container..."
        docker stop "$CONTAINER_NAME" 2>/dev/null || true
    fi
    docker rm "$CONTAINER_NAME" 2>/dev/null || true

    # Shared options tuned for DGX Spark: GPU access, host networking/IPC,
    # unlimited locked memory/stack, and a shared HuggingFace model cache.
    local docker_opts=(
        --gpus all
        --ipc=host
        --network=host
        --ulimit memlock=-1:-1
        --ulimit stack=-1:-1
        -e "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
        -v "$HOME/.cache/huggingface:/root/.cache/huggingface"
        --name "$CONTAINER_NAME"
    )

    case "$mode" in
        interactive)
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" bash
            ;;
        test)
            docker run --rm "${docker_opts[@]}" "$IMAGE_NAME" python /workspace/test_vibevoice.py
            ;;
        demo)
            log_info "Starting Gradio demo on port 7860..."
            log_info "Access the demo at: http://localhost:7860"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME"
            ;;
        realtime)
            log_info "Starting Realtime ASR server on port 8000..."
            log_info "WebSocket API: ws://localhost:8000/ws/asr/{session_id}"
            log_info "Test client: http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" \
                python -m realtime.server --host 0.0.0.0 --port 8000
            ;;
        serve)
            log_info "Starting all services..."
            log_info " Gradio demo: http://localhost:7860"
            log_info " Realtime ASR: http://localhost:8000"
            log_info " Test client: http://localhost:8000/static/realtime_client.html"
            docker run --rm -it "${docker_opts[@]}" "$IMAGE_NAME" ./run_all.sh
            ;;
        *)
            log_error "Unknown mode: $mode"
            exit 1
            ;;
    esac
}
|
||||
|
||||
show_usage() {
    # Print CLI help for the installer; $BASE_URL and $INSTALL_DIR
    # interpolate inside the heredoc just as they did in the echo chain.
    cat <<USAGE
VibeVoice-ASR Setup for DGX Spark

Usage:
 curl -sL $BASE_URL/setup.sh | bash # Download only
 curl -sL $BASE_URL/setup.sh | bash -s build # Download and build
 curl -sL $BASE_URL/setup.sh | bash -s demo # Download, build, run demo
 curl -sL $BASE_URL/setup.sh | bash -s serve # Download, build, run all

Commands:
 (default) Download files only
 build Download and build Docker image
 demo Download, build, and start Gradio demo (port 7860)
 realtime Download, build, and start Realtime ASR (port 8000)
 serve Download, build, and start both services
 run Run container interactively (after build)

Environment variables:
 VIBEVOICE_DIR Installation directory (default: ~/vibevoice-asr)

After installation, you can also run:
 cd ~/vibevoice-asr
 docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark
USAGE
}
|
||||
|
||||
main() {
    # Dispatch on the requested subcommand; defaults to "download".
    local command="${1:-download}"

    echo ""
    echo "=========================================="
    echo " VibeVoice-ASR Setup for DGX Spark"
    echo "=========================================="
    echo ""

    case "$command" in
        download)
            download_files
            echo ""
            log_info "Done! Next steps:"
            echo " cd $INSTALL_DIR"
            echo " docker build -t vibevoice-asr:dgx-spark ."
            echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        build)
            download_files
            check_prerequisites
            build_image
            echo ""
            log_info "Done! To run:"
            echo " cd $INSTALL_DIR"
            echo " docker run --gpus all -p 7860:7860 vibevoice-asr:dgx-spark"
            ;;
        run)
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container interactive
            ;;
        test)
            cd "$INSTALL_DIR" 2>/dev/null || { log_error "Run 'build' first"; exit 1; }
            run_container test
            ;;
        demo|realtime|serve)
            # These three share the same pipeline; the subcommand name
            # doubles as the container run mode.
            download_files
            check_prerequisites
            build_image
            run_container "$command"
            ;;
        -h|--help|help)
            show_usage
            ;;
        *)
            log_error "Unknown command: $command"
            show_usage
            exit 1
            ;;
    esac
}

main "$@"
|
||||
899
static/scripts/vibevoice-asr/static/realtime_client.html
Normal file
899
static/scripts/vibevoice-asr/static/realtime_client.html
Normal file
@ -0,0 +1,899 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="ja">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>VibeVoice Realtime ASR Client</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg-primary: #1a1a2e;
|
||||
--bg-secondary: #16213e;
|
||||
--bg-tertiary: #0f3460;
|
||||
--text-primary: #eaeaea;
|
||||
--text-secondary: #a0a0a0;
|
||||
--accent: #e94560;
|
||||
--success: #4ade80;
|
||||
--warning: #fbbf24;
|
||||
--info: #60a5fa;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', system-ui, sans-serif;
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.main-grid {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.main-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: var(--bg-secondary);
|
||||
border-radius: 12px;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
font-size: 1.1rem;
|
||||
margin-bottom: 15px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
/* Controls */
|
||||
.controls {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.control-row {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
input[type="text"], input[type="number"] {
|
||||
background: var(--bg-tertiary);
|
||||
border: 1px solid rgba(255,255,255,0.1);
|
||||
border-radius: 8px;
|
||||
padding: 10px 15px;
|
||||
color: var(--text-primary);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
input:focus {
|
||||
outline: none;
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
button {
|
||||
background: var(--accent);
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
padding: 12px 24px;
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
filter: brightness(1.1);
|
||||
}
|
||||
|
||||
button:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
button.secondary {
|
||||
background: var(--bg-tertiary);
|
||||
}
|
||||
|
||||
button.success {
|
||||
background: var(--success);
|
||||
color: #000;
|
||||
}
|
||||
|
||||
button.stop {
|
||||
background: #ef4444;
|
||||
}
|
||||
|
||||
/* Status */
|
||||
.status-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
padding: 10px 15px;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.status-indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: var(--text-secondary);
|
||||
}
|
||||
|
||||
.status-indicator.connected {
|
||||
background: var(--success);
|
||||
box-shadow: 0 0 10px var(--success);
|
||||
}
|
||||
|
||||
.status-indicator.recording {
|
||||
background: var(--accent);
|
||||
box-shadow: 0 0 10px var(--accent);
|
||||
animation: pulse 1s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
/* Transcription output */
|
||||
.transcription-box {
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
min-height: 300px;
|
||||
max-height: 500px;
|
||||
overflow-y: auto;
|
||||
font-family: 'Consolas', monospace;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.segment {
|
||||
margin-bottom: 15px;
|
||||
padding: 10px;
|
||||
background: rgba(255,255,255,0.05);
|
||||
border-radius: 6px;
|
||||
border-left: 3px solid var(--accent);
|
||||
}
|
||||
|
||||
.segment.partial {
|
||||
border-left-color: var(--warning);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.segment-meta {
|
||||
font-size: 0.8rem;
|
||||
color: var(--text-secondary);
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.segment-text {
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
/* Audio visualizer */
|
||||
.visualizer-container {
|
||||
height: 80px;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
#visualizer {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
/* VAD indicator */
|
||||
.vad-indicator {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
padding: 10px;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.vad-bar {
|
||||
flex: 1;
|
||||
height: 8px;
|
||||
background: rgba(255,255,255,0.1);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.vad-level {
|
||||
height: 100%;
|
||||
background: var(--success);
|
||||
width: 0%;
|
||||
transition: width 0.1s;
|
||||
}
|
||||
|
||||
.vad-level.speech {
|
||||
background: var(--accent);
|
||||
}
|
||||
|
||||
/* Stats */
|
||||
.stats {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 10px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
.stat-item {
|
||||
background: var(--bg-tertiary);
|
||||
padding: 10px;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.stat-value {
|
||||
font-size: 1.5rem;
|
||||
font-weight: bold;
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 0.75rem;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
/* Log */
|
||||
.log-box {
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
padding: 10px;
|
||||
height: 150px;
|
||||
overflow-y: auto;
|
||||
font-family: 'Consolas', monospace;
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
.log-entry {
|
||||
margin-bottom: 2px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.log-entry.error {
|
||||
color: #ef4444;
|
||||
}
|
||||
|
||||
.log-entry.success {
|
||||
color: var(--success);
|
||||
}
|
||||
|
||||
.log-entry.info {
|
||||
color: var(--info);
|
||||
}
|
||||
|
||||
/* Settings panel */
|
||||
.settings-group {
|
||||
margin-top: 15px;
|
||||
padding: 15px;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.settings-group h4 {
|
||||
margin-bottom: 10px;
|
||||
font-size: 0.9rem;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.setting-item {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.setting-item label {
|
||||
display: block;
|
||||
font-size: 0.85rem;
|
||||
margin-bottom: 5px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.setting-item input[type="range"] {
|
||||
width: 100%;
|
||||
height: 6px;
|
||||
border-radius: 3px;
|
||||
background: rgba(255,255,255,0.1);
|
||||
outline: none;
|
||||
-webkit-appearance: none;
|
||||
}
|
||||
|
||||
.setting-item input[type="range"]::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 50%;
|
||||
background: var(--accent);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.setting-value {
|
||||
font-size: 0.8rem;
|
||||
color: var(--accent);
|
||||
float: right;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>🎙️ VibeVoice Realtime ASR</h1>
|
||||
<p class="subtitle">リアルタイム音声認識デモ</p>
|
||||
</header>
|
||||
|
||||
<div class="status-bar">
|
||||
<div class="status-indicator" id="statusIndicator"></div>
|
||||
<span id="statusText">未接続</span>
|
||||
</div>
|
||||
|
||||
<div class="main-grid">
|
||||
<!-- Left column: Controls -->
|
||||
<div class="card">
|
||||
<h2 class="card-title">⚙️ 設定</h2>
|
||||
<div class="controls">
|
||||
<div class="control-row">
|
||||
<input type="text" id="serverUrl" placeholder="WebSocket URL"
|
||||
value="ws://localhost:8000/ws/asr/demo">
|
||||
</div>
|
||||
<div class="control-row">
|
||||
<button id="connectBtn" onclick="toggleConnection()">接続</button>
|
||||
<button id="recordBtn" class="success" onclick="toggleRecording()" disabled>
|
||||
🎤 録音開始
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="visualizer-container">
|
||||
<canvas id="visualizer"></canvas>
|
||||
</div>
|
||||
|
||||
<div class="vad-indicator">
|
||||
<span>VAD:</span>
|
||||
<div class="vad-bar">
|
||||
<div class="vad-level" id="vadLevel"></div>
|
||||
</div>
|
||||
<span id="vadStatus">待機中</span>
|
||||
</div>
|
||||
|
||||
<div class="settings-group">
|
||||
<h4>🎚️ VAD設定</h4>
|
||||
<div class="setting-item">
|
||||
<label>
|
||||
音声検出閾値
|
||||
<span class="setting-value" id="vadThresholdValue">0.5</span>
|
||||
</label>
|
||||
<input type="range" id="vadThreshold" min="0.1" max="0.9" step="0.1" value="0.5"
|
||||
onchange="updateConfig()">
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<label>
|
||||
最小発話時間 (ms)
|
||||
<span class="setting-value" id="minSpeechValue">250</span>
|
||||
</label>
|
||||
<input type="range" id="minSpeechDuration" min="100" max="1000" step="50" value="250"
|
||||
onchange="updateConfig()">
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<label>
|
||||
無音判定時間 (ms)
|
||||
<span class="setting-value" id="minSilenceValue">500</span>
|
||||
</label>
|
||||
<input type="range" id="minSilenceDuration" min="200" max="2000" step="100" value="500"
|
||||
onchange="updateConfig()">
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<label>
|
||||
最小音量閾値
|
||||
<span class="setting-value" id="minVolumeValue">0.01</span>
|
||||
</label>
|
||||
<input type="range" id="minVolumeThreshold" min="0.001" max="0.1" step="0.001" value="0.01"
|
||||
onchange="updateConfig()">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="stats">
|
||||
<div class="stat-item">
|
||||
<div class="stat-value" id="statDuration">0.0</div>
|
||||
<div class="stat-label">録音時間 (秒)</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value" id="statSegments">0</div>
|
||||
<div class="stat-label">セグメント数</div>
|
||||
</div>
|
||||
<div class="stat-item">
|
||||
<div class="stat-value" id="statLatency">-</div>
|
||||
<div class="stat-label">レイテンシ (ms)</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h3 class="card-title" style="margin-top: 20px;">📋 ログ</h3>
|
||||
<div class="log-box" id="logBox"></div>
|
||||
</div>
|
||||
|
||||
<!-- Right column: Transcription -->
|
||||
<div class="card">
|
||||
<h2 class="card-title">📝 認識結果</h2>
|
||||
<div class="control-row" style="margin-bottom: 15px;">
|
||||
<button class="secondary" onclick="clearTranscription()">クリア</button>
|
||||
<button class="secondary" onclick="copyTranscription()">コピー</button>
|
||||
</div>
|
||||
<div class="transcription-box" id="transcriptionBox">
|
||||
<p style="color: var(--text-secondary); text-align: center; padding: 50px;">
|
||||
接続して録音を開始すると、ここに認識結果が表示されます
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// ---- Client state ----
let websocket = null;            // active WebSocket connection, null when disconnected
let mediaStream = null;          // microphone MediaStream while recording
let audioContext = null;         // WebAudio context driving capture + analysis
let processor = null;            // ScriptProcessorNode streaming PCM to the server
let analyser = null;             // AnalyserNode feeding the frequency visualizer
let isRecording = false;         // true between startRecording() and stopRecording()
let recordingStartTime = null;   // Date.now() at recording start (for the duration stat)
let segmentCount = 0;            // number of finalized transcription segments
let currentPartialText = '';     // latest interim (non-final) recognition text

// ---- Cached DOM elements ----
const statusIndicator = document.getElementById('statusIndicator');
const statusText = document.getElementById('statusText');
const connectBtn = document.getElementById('connectBtn');
const recordBtn = document.getElementById('recordBtn');
const transcriptionBox = document.getElementById('transcriptionBox');
const logBox = document.getElementById('logBox');
const vadLevel = document.getElementById('vadLevel');
const vadStatus = document.getElementById('vadStatus');
||||
|
||||
// Logging
|
||||
// Append a timestamped entry to the on-page log and mirror it to the console.
function log(message, type = 'info') {
    const timestamp = new Date().toLocaleTimeString();
    const entry = Object.assign(document.createElement('div'), {
        className: `log-entry ${type}`,
        textContent: `[${timestamp}] ${message}`,
    });
    logBox.appendChild(entry);
    logBox.scrollTop = logBox.scrollHeight;
    console.log(`[${type}] ${message}`);
}
|
||||
|
||||
// Sync the numeric label next to each VAD slider with its current value.
function updateConfigDisplay() {
    const bindings = [
        ['vadThresholdValue', 'vadThreshold'],
        ['minSpeechValue', 'minSpeechDuration'],
        ['minSilenceValue', 'minSilenceDuration'],
        ['minVolumeValue', 'minVolumeThreshold'],
    ];
    for (const [labelId, sliderId] of bindings) {
        document.getElementById(labelId).textContent =
            document.getElementById(sliderId).value;
    }
}
|
||||
|
||||
// Push the current VAD slider values to the server (no-op unless connected).
function updateConfig() {
    updateConfigDisplay();

    if (!websocket || websocket.readyState !== WebSocket.OPEN) return;

    const read = (id) => document.getElementById(id).value;
    const config = {
        type: 'config',
        config: {
            vad_threshold: parseFloat(read('vadThreshold')),
            min_speech_duration_ms: parseInt(read('minSpeechDuration')),
            min_silence_duration_ms: parseInt(read('minSilenceDuration')),
            min_volume_threshold: parseFloat(read('minVolumeThreshold')),
        }
    };
    websocket.send(JSON.stringify(config));
    log(`VAD設定を更新: 閾値=${config.config.vad_threshold}, 最小発話=${config.config.min_speech_duration_ms}ms, 無音判定=${config.config.min_silence_duration_ms}ms, 最小音量=${config.config.min_volume_threshold}`, 'info');
}
|
||||
|
||||
// Populate the slider value labels once the DOM is ready.
document.addEventListener('DOMContentLoaded', updateConfigDisplay);
|
||||
|
||||
// Connect-button handler: open the socket if closed, otherwise close it.
function toggleConnection() {
    const isOpen = websocket && websocket.readyState === WebSocket.OPEN;
    if (isOpen) {
        disconnect();
    } else {
        connect();
    }
}
|
||||
|
||||
// Open the WebSocket to the configured URL and wire its lifecycle handlers.
function connect() {
    const url = document.getElementById('serverUrl').value;
    log(`接続中: ${url}`, 'info');

    try {
        websocket = new WebSocket(url);

        // Enable the UI once the connection is established.
        websocket.onopen = () => {
            log('接続成功', 'success');
            statusIndicator.classList.add('connected');
            statusText.textContent = '接続済み';
            connectBtn.textContent = '切断';
            recordBtn.disabled = false;
        };

        websocket.onclose = () => {
            log('切断されました', 'info');
            handleDisconnect();
        };

        websocket.onerror = (error) => log(`エラー: ${error}`, 'error');

        // All server traffic is JSON; route it through the dispatcher.
        websocket.onmessage = (event) => handleMessage(JSON.parse(event.data));
    } catch (error) {
        log(`接続エラー: ${error}`, 'error');
    }
}
|
||||
|
||||
// User-initiated disconnect: stop any recording, close the socket, reset UI.
function disconnect() {
    if (isRecording) stopRecording();
    websocket?.close();
    handleDisconnect();
}
|
||||
|
||||
// Reset connection-related UI state after the socket is gone.
function handleDisconnect() {
    websocket = null;
    statusIndicator.classList.remove('connected', 'recording');
    statusText.textContent = '未接続';
    connectBtn.textContent = '接続';
    recordBtn.disabled = true;
}
|
||||
|
||||
// Route a parsed server message to the appropriate handler by its type tag.
function handleMessage(data) {
    const handlers = {
        status: () => log(`ステータス: ${data.message}`, 'info'),
        partial_result: () => updatePartialResult(data),
        final_result: () => addFinalResult(data),
        vad_event: () => handleVADEvent(data),
        error: () => log(`サーバーエラー: ${data.error}`, 'error'),
        pong: () => {}, // heartbeat response, nothing to do
    };

    const handler = handlers[data.type];
    if (handler) {
        handler();
    } else {
        log(`不明なメッセージ: ${data.type}`, 'info');
    }
}
|
||||
|
||||
// Show interim recognition text and refresh the latency stat.
function updatePartialResult(data) {
    currentPartialText = data.text;
    updateTranscriptionDisplay();

    const latency = data.latency_ms;
    if (latency) {
        document.getElementById('statLatency').textContent = Math.round(latency);
    }
}
|
||||
|
||||
// Finalize the current utterance: clear the partial text, bump the stats,
// and append the recognized text as a permanent segment in the panel.
function addFinalResult(data) {
    currentPartialText = '';
    segmentCount++;
    document.getElementById('statSegments').textContent = segmentCount;

    if (data.latency_ms) {
        document.getElementById('statLatency').textContent =
            Math.round(data.latency_ms);
    }

    // Timing/speaker metadata from the first segment, when provided.
    let metaText = '';
    if (data.segments && data.segments.length > 0) {
        const seg = data.segments[0];
        metaText = `[${seg.start_time?.toFixed(2) || '?'}s - ${seg.end_time?.toFixed(2) || '?'}s] ${seg.speaker_id || ''}`;
    }

    // Build the segment with textContent (not innerHTML) so server-provided
    // transcription text cannot inject markup or script into the page.
    const segment = document.createElement('div');
    segment.className = 'segment';

    const meta = document.createElement('div');
    meta.className = 'segment-meta';
    meta.textContent = metaText;

    const text = document.createElement('div');
    text.className = 'segment-text';
    text.textContent = data.text;

    segment.append(meta, text);

    // Remove the "no results yet" placeholder on the first segment.
    const placeholder = transcriptionBox.querySelector('p');
    if (placeholder) placeholder.remove();

    transcriptionBox.appendChild(segment);
    transcriptionBox.scrollTop = transcriptionBox.scrollHeight;

    log(`認識完了: "${data.text.substring(0, 30)}..."`, 'success');
}
|
||||
|
||||
// Render (or remove) the single in-progress "partial" segment at the
// bottom of the transcription panel.
function updateTranscriptionDisplay() {
    let partialDiv = transcriptionBox.querySelector('.segment.partial');

    if (currentPartialText) {
        if (!partialDiv) {
            partialDiv = document.createElement('div');
            partialDiv.className = 'segment partial';
            transcriptionBox.appendChild(partialDiv);
        }

        // textContent instead of innerHTML: the partial text comes from the
        // server and must not be interpreted as markup (XSS).
        partialDiv.replaceChildren();

        const meta = document.createElement('div');
        meta.className = 'segment-meta';
        meta.textContent = '認識中...';

        const text = document.createElement('div');
        text.className = 'segment-text';
        text.textContent = currentPartialText;

        partialDiv.append(meta, text);
        transcriptionBox.scrollTop = transcriptionBox.scrollHeight;
    } else if (partialDiv) {
        partialDiv.remove();
    }
}
|
||||
|
||||
// Reflect server-side VAD speech-start/end events in the indicator UI.
function handleVADEvent(data) {
    const stamp = data.audio_timestamp_sec?.toFixed(2);
    switch (data.event) {
        case 'speech_start':
            vadLevel.classList.add('speech');
            vadStatus.textContent = '発話中';
            log(`発話開始 @ ${stamp}s`, 'info');
            break;
        case 'speech_end':
            vadLevel.classList.remove('speech');
            vadStatus.textContent = '待機中';
            log(`発話終了 @ ${stamp}s`, 'info');
            break;
    }
}
|
||||
|
||||
// Record-button handler: start microphone capture if idle, stop it otherwise.
async function toggleRecording() {
    if (!isRecording) {
        await startRecording();
    } else {
        stopRecording();
    }
}
|
||||
|
||||
// Acquire the microphone, stream 16-bit mono PCM to the server over the
// open WebSocket, and start the visualizer and duration counters.
async function startRecording() {
    try {
        log('マイクアクセスをリクエスト中...', 'info');

        // Request a mono 16 kHz stream. NOTE(review): browsers may ignore
        // the sampleRate constraint; the AudioContext below pins 16 kHz.
        mediaStream = await navigator.mediaDevices.getUserMedia({
            audio: {
                sampleRate: 16000,
                channelCount: 1,
                echoCancellation: true,
                noiseSuppression: true,
            }
        });

        audioContext = new (window.AudioContext || window.webkitAudioContext)({
            sampleRate: 16000
        });

        const source = audioContext.createMediaStreamSource(mediaStream);

        // Analyser for visualization
        analyser = audioContext.createAnalyser();
        analyser.fftSize = 256;
        source.connect(analyser);

        // ScriptProcessor for sending audio.
        // NOTE(review): ScriptProcessorNode is deprecated in favor of
        // AudioWorklet; kept here presumably for simplicity/compatibility.
        const bufferSize = 4096;
        processor = audioContext.createScriptProcessor(bufferSize, 1, 1);

        processor.onaudioprocess = (e) => {
            // Drop audio when not recording or the socket is unavailable.
            if (!isRecording || !websocket || websocket.readyState !== WebSocket.OPEN) {
                return;
            }

            const inputData = e.inputBuffer.getChannelData(0);

            // Convert float [-1, 1] samples to clamped 16-bit PCM.
            const pcmData = new Int16Array(inputData.length);
            for (let i = 0; i < inputData.length; i++) {
                pcmData[i] = Math.max(-32768, Math.min(32767, inputData[i] * 32768));
            }

            // Send as binary
            websocket.send(pcmData.buffer);
        };

        source.connect(processor);
        // ScriptProcessor must be connected to a destination to fire events.
        processor.connect(audioContext.destination);

        isRecording = true;
        recordingStartTime = Date.now();
        statusIndicator.classList.add('recording');
        recordBtn.textContent = '⏹️ 録音停止';
        recordBtn.classList.remove('success');
        recordBtn.classList.add('stop');

        // Start visualization
        visualize();
        updateDuration();

        log('録音開始', 'success');

    } catch (error) {
        log(`マイクエラー: ${error}`, 'error');
    }
}
|
||||
|
||||
// Tear down the audio pipeline, restore the UI, and tell the server that
// the audio stream has ended.
function stopRecording() {
    isRecording = false;

    // Release audio resources in reverse order of acquisition.
    processor?.disconnect();
    processor = null;

    audioContext?.close();
    audioContext = null;

    if (mediaStream) {
        for (const track of mediaStream.getTracks()) {
            track.stop();
        }
        mediaStream = null;
    }

    statusIndicator.classList.remove('recording');
    recordBtn.textContent = '🎤 録音開始';
    recordBtn.classList.remove('stop');
    recordBtn.classList.add('success');

    // Notify the server so it can flush/finalize the session.
    if (websocket && websocket.readyState === WebSocket.OPEN) {
        websocket.send(JSON.stringify({ type: 'stop' }));
    }

    log('録音停止', 'info');
}
|
||||
|
||||
// Visualization
|
||||
/**
 * Frequency-bar visualizer: repaints the analyser output onto the
 * canvas every animation frame and drives the VAD level meter.
 * Stops on its own once isRecording turns false.
 */
function visualize() {
    if (!analyser || !isRecording) return;

    const canvas = document.getElementById('visualizer');
    const ctx = canvas.getContext('2d');
    // Sync the drawing buffer with the element's CSS size once per start.
    const w = canvas.width = canvas.offsetWidth;
    const h = canvas.height = canvas.offsetHeight;

    const binCount = analyser.frequencyBinCount;
    const bins = new Uint8Array(binCount);

    const draw = () => {
        if (!isRecording) return;
        requestAnimationFrame(draw);

        analyser.getByteFrequencyData(bins);

        // Background wash.
        ctx.fillStyle = 'rgb(15, 52, 96)';
        ctx.fillRect(0, 0, w, h);

        const barWidth = (w / binCount) * 2.5;
        let x = 0;
        let total = 0; // running sum feeding the VAD meter

        for (let i = 0; i < binCount; i++) {
            const level = bins[i];
            const barHeight = (level / 255) * h;

            // Red at the base fading to green toward the tip.
            const gradient = ctx.createLinearGradient(0, h, 0, h - barHeight);
            gradient.addColorStop(0, '#e94560');
            gradient.addColorStop(1, '#4ade80');
            ctx.fillStyle = gradient;
            ctx.fillRect(x, h - barHeight, barWidth, barHeight);

            x += barWidth + 1;
            total += level;
        }

        // Average energy, expressed as a 0-100% width for the VAD bar.
        vadLevel.style.width = `${(total / binCount / 255) * 100}%`;
    };

    draw();
}
|
||||
|
||||
/**
 * Elapsed-time readout: refreshes the duration stat (in seconds, one
 * decimal) every animation frame while recording is active.
 */
function updateDuration() {
    if (!isRecording) return;

    const elapsedSec = (Date.now() - recordingStartTime) / 1000;
    document.getElementById('statDuration').textContent = elapsedSec.toFixed(1);

    requestAnimationFrame(updateDuration);
}
|
||||
|
||||
// Utility functions
|
||||
/**
 * Reset the transcription pane to its placeholder text and zero out
 * all segment bookkeeping and on-screen statistics.
 */
function clearTranscription() {
    // Placeholder shown until the next recognition result arrives.
    transcriptionBox.innerHTML = `
        <p style="color: var(--text-secondary); text-align: center; padding: 50px;">
            接続して録音を開始すると、ここに認識結果が表示されます
        </p>
    `;

    currentPartialText = '';
    segmentCount = 0;
    document.getElementById('statSegments').textContent = '0';
    document.getElementById('statLatency').textContent = '-';

    log('認識結果をクリアしました', 'info');
}
|
||||
|
||||
/**
 * Copy all finalized (non-partial) segment texts to the clipboard,
 * joined by newlines. Does nothing when there is no finalized text.
 */
function copyTranscription() {
    const nodes = transcriptionBox.querySelectorAll('.segment:not(.partial) .segment-text');
    const joined = Array.from(nodes, node => node.textContent).join('\n');

    if (!joined) return;

    navigator.clipboard.writeText(joined).then(() => {
        log('認識結果をコピーしました', 'success');
    });
}
|
||||
|
||||
// Heartbeat: ping the server every 30 s so idle proxies keep the socket open.
setInterval(() => {
    const ws = websocket;
    if (ws && ws.readyState === WebSocket.OPEN) {
        ws.send(JSON.stringify({ type: 'ping' }));
    }
}, 30000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
173
static/scripts/vibevoice-asr/test_vibevoice.py
Normal file
173
static/scripts/vibevoice-asr/test_vibevoice.py
Normal file
@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
VibeVoice-ASR Test Script for DGX Spark
|
||||
Tests basic functionality and GPU availability
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def test_imports():
    """Check that the ``vibevoice`` package is importable.

    Returns:
        bool: True if the import succeeds, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Testing VibeVoice imports...")
    print(banner)

    try:
        import vibevoice  # noqa: F401 -- the import itself is the test
    except ImportError as exc:
        print(f"[FAIL] Failed to import vibevoice: {exc}")
        return False
    print("[OK] vibevoice imported successfully")
    return True
|
||||
|
||||
|
||||
def test_torch_cuda():
    """Report PyTorch/CUDA status and run a tiny GPU sanity computation.

    Returns:
        bool: True when CUDA is available and a small matmul on the GPU
        succeeds; False when CUDA is absent or any step raises.
    """
    print("\n" + "=" * 60)
    print("Testing PyTorch CUDA...")
    print("=" * 60)

    try:
        import torch

        print(f"[INFO] PyTorch version: {torch.__version__}")
        print(f"[INFO] CUDA available: {torch.cuda.is_available()}")

        # Guard clause: without a GPU there is nothing further to test.
        if not torch.cuda.is_available():
            print("[WARN] CUDA not available")
            return False

        print(f"[INFO] CUDA version: {torch.version.cuda}")
        print(f"[INFO] GPU count: {torch.cuda.device_count()}")

        for idx in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(idx)
            print(f"[INFO] GPU {idx}: {props.name}")
            print(f"         Compute capability: {props.major}.{props.minor}")
            print(f"         Total memory: {props.total_memory / 1024**3:.1f} GB")

        # Smoke test: a small matmul forces an actual kernel launch.
        x = torch.randn(100, 100, device='cuda')
        torch.matmul(x, x)
        print(f"[OK] CUDA tensor operations working")
        return True

    except Exception as exc:  # broad on purpose: any CUDA/runtime failure = fail
        print(f"[FAIL] PyTorch CUDA test failed: {exc}")
        return False
|
||||
|
||||
|
||||
def test_flash_attention():
    """Check whether the optional ``flash_attn`` package is installed.

    Returns:
        bool: Always True -- flash attention is optional, so a missing
        package is reported as a warning, not a failure.
    """
    print("\n" + "=" * 60)
    print("Testing Flash Attention...")
    print("=" * 60)

    try:
        import flash_attn
    except ImportError:
        print("[WARN] flash_attn not installed (optional)")
        return True  # optional dependency; absence is acceptable
    print(f"[OK] flash_attn version: {flash_attn.__version__}")
    return True
|
||||
|
||||
|
||||
def test_ffmpeg():
    """Verify that the ``ffmpeg`` binary is on PATH and executes.

    Runs ``ffmpeg -version`` and reports the banner line on success.

    Returns:
        bool: True if ``ffmpeg -version`` exits with status 0,
        False when the binary is missing, cannot be executed, times
        out, or exits non-zero.
    """
    print("\n" + "=" * 60)
    print("Testing FFmpeg...")
    print("=" * 60)

    try:
        result = subprocess.run(
            ["ffmpeg", "-version"],
            capture_output=True,
            text=True,
            timeout=30,  # don't hang the whole suite on a wedged binary
        )
    except FileNotFoundError:
        print("[FAIL] FFmpeg not found")
        return False
    except (OSError, subprocess.SubprocessError) as exc:
        # e.g. PermissionError on a non-executable file, or TimeoutExpired;
        # previously these crashed the suite instead of failing this test.
        print(f"[FAIL] FFmpeg could not be executed: {exc}")
        return False

    if result.returncode == 0:
        # The first stdout line carries the version banner.
        version_line = result.stdout.split('\n')[0]
        print(f"[OK] {version_line}")
        return True
    print("[FAIL] FFmpeg returned error")
    return False
|
||||
|
||||
|
||||
def test_asr_model():
    """Attempt to load the VibeVoice ASR pipeline on the GPU.

    Returns:
        bool: Always True -- the load is skipped when no GPU is present,
        and a missing or failing ``ASRPipeline`` is treated as
        non-fatal (it may simply not exist in this VibeVoice version).
    """
    print("\n" + "=" * 60)
    print("Testing ASR Model Loading...")
    print("=" * 60)

    try:
        import torch

        if not torch.cuda.is_available():
            print("[SKIP] Skipping model test - no GPU available")
            return True

        from vibevoice import ASRPipeline

        print("[INFO] Loading ASR pipeline...")
        pipeline = ASRPipeline()  # default (smaller) model for a quick check
        print("[OK] ASR pipeline loaded successfully")

        # Release GPU memory before subsequent tests run.
        del pipeline
        torch.cuda.empty_cache()
        return True

    except ImportError as exc:
        print(f"[WARN] ASRPipeline not available: {exc}")
        print("[INFO] This may be normal depending on VibeVoice version")
        return True
    except Exception as exc:
        print(f"[WARN] ASR model test: {exc}")
        return True
|
||||
|
||||
|
||||
def main():
    """Run every check and print a pass/fail summary.

    Returns:
        int: 0 when all tests pass, 1 otherwise (suitable for sys.exit).
    """
    print("\n")
    print("*" * 60)
    print(" VibeVoice-ASR Test Suite for DGX Spark")
    print("*" * 60)

    # Run each check once; insertion order is the report order.
    results = {
        "imports": test_imports(),
        "torch_cuda": test_torch_cuda(),
        "flash_attention": test_flash_attention(),
        "ffmpeg": test_ffmpeg(),
        "asr_model": test_asr_model(),
    }

    print("\n")
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)

    for name, passed in results.items():
        marker = "[OK]" if passed else "[FAIL]"
        print(f"  {marker} {name}")

    print("=" * 60)

    if all(results.values()):
        print("\nAll tests passed!")
        return 0
    print("\nSome tests failed.")
    return 1
|
||||
|
||||
|
||||
# Script entry point: exit status mirrors the suite result
# (0 = all tests passed, 1 = at least one failure).
if __name__ == "__main__":
    sys.exit(main())
|
||||
1229
static/scripts/vibevoice-asr/vibevoice_asr_gradio_demo_patched.py
Normal file
1229
static/scripts/vibevoice-asr/vibevoice_asr_gradio_demo_patched.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user