"""
Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version).

Detects speech segments in real-time audio streams.
Uses ONNX runtime to avoid torchaudio dependency issues.
"""

import os
import threading
import urllib.request
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np

from .models import SpeechSegment, VADEvent, VADEventType, MessageType


@dataclass
class VADState:
    """Internal state of the VAD processor."""

    # True once a run of speech windows has reached the minimum speech duration.
    is_speech_active: bool = False
    # Absolute sample index at which the active speech segment began.
    speech_start_sample: int = 0
    # Absolute sample index at which the current silence run began while
    # speech is active; 0 means "no silence run is being timed".
    silence_start_sample: int = 0
    # Probability assigned to the most recently analysed window.
    last_speech_prob: float = 0.0
    # Total number of samples handed to ``process`` since the last reset.
    total_samples_processed: int = 0


class VADProcessor:
    """
    Voice Activity Detection using Silero-VAD (ONNX version).

    Features:
    - Real-time speech detection
    - Configurable thresholds for speech/silence duration
    - Event generation for speech start/end
    - Thread-safe operations (state is guarded by an internal lock)
    - No torchaudio dependency (uses ONNX runtime)
    """

    # Silero VAD ONNX model URL.
    # NOTE(review): this tracks the ``master`` branch, so the downloaded model
    # version is not pinned.
    ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"

    # Shape of the recurrent 'state' tensor fed to / returned by the model.
    # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128).
    _STATE_SHAPE = (2, 1, 128)

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        window_size_samples: int = 512,
        min_volume_threshold: float = 0.01,
    ):
        """
        Initialize the VAD processor and load the ONNX model.

        Args:
            sample_rate: Audio sample rate (must be 16000 for Silero-VAD)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger speech_start
            min_silence_duration_ms: Minimum silence duration to trigger speech_end
            window_size_samples: VAD window size (512 for 16kHz = 32ms)
            min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech

        Raises:
            ValueError: If ``sample_rate`` is not 16000 or
                ``window_size_samples`` is not positive.
            RuntimeError: If the ONNX model cannot be downloaded or loaded.
        """
        if sample_rate != 16000:
            raise ValueError("Silero-VAD requires 16kHz sample rate")
        # A non-positive window would make windowing in ``process`` divide by
        # zero or silently analyse nothing; fail early instead.
        if window_size_samples <= 0:
            raise ValueError("window_size_samples must be a positive integer")

        self.sample_rate = sample_rate
        self.threshold = threshold
        self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000)
        self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)
        self.window_size = window_size_samples
        self.min_volume_threshold = min_volume_threshold

        # Load ONNX model
        self._session = None
        self._load_model()

        # Recurrent model state, carried from one inference call to the next.
        self._state_tensor = np.zeros(self._STATE_SHAPE, dtype=np.float32)

        # Detection state
        self._state = VADState()
        self._lock = threading.Lock()

        # Start of a candidate speech run that has not yet reached the
        # minimum speech duration (None when no candidate is open).
        self._pending_segment_start: Optional[int] = None

        # Samples left over from the previous chunk that did not fill a whole
        # window; they are prepended to the next chunk so no audio is skipped.
        self._residual: np.ndarray = np.zeros(0, dtype=np.float32)

    def _get_model_path(self) -> str:
        """
        Get path to ONNX model, downloading if necessary.

        The model is cached under ``~/.cache/silero-vad``. The download is
        written to a temporary file and moved into place only once complete,
        so an interrupted download never leaves a truncated file at the
        cached path.

        Returns:
            Filesystem path of the cached ONNX model.
        """
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad")
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, "silero_vad.onnx")

        if not os.path.exists(model_path):
            print(f"Downloading Silero-VAD ONNX model to {model_path}...")
            # Per-process temp name so concurrent first runs do not write to
            # the same partial file.
            partial_path = f"{model_path}.{os.getpid()}.part"
            try:
                urllib.request.urlretrieve(self.ONNX_MODEL_URL, partial_path)
                os.replace(partial_path, model_path)
            finally:
                # Only present if the download or the rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print("Download complete.")

        return model_path

    def _load_model(self) -> None:
        """
        Load Silero-VAD ONNX model into a CPU inference session.

        Raises:
            RuntimeError: If onnxruntime is unavailable, the download fails,
                or the session cannot be created. The original exception is
                chained as ``__cause__``.
        """
        try:
            # Imported lazily so a missing onnxruntime surfaces as the same
            # RuntimeError as any other load failure.
            import onnxruntime as ort

            model_path = self._get_model_path()
            self._session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Silero-VAD ONNX model loaded from {model_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}") from e

    def _run_inference(self, audio_window: np.ndarray) -> float:
        """
        Run VAD inference on a single window.

        Updates the recurrent state tensor as a side effect, so the caller
        must hold the lock.

        Args:
            audio_window: One window of audio samples.

        Returns:
            Speech probability reported by the model for this window.
        """
        # Prepare input: a batch of one window.
        audio_input = audio_window.reshape(1, -1).astype(np.float32)
        sr_input = np.array([self.sample_rate], dtype=np.int64)

        # NOTE(review): the upstream silero-vad reference ONNX wrapper prepends
        # a short context from the previous window to each input for recent
        # models - confirm whether the downloaded model version expects that.
        # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c'
        outputs = self._session.run(
            ['output', 'stateN'],
            {
                'input': audio_input,
                'sr': sr_input,
                'state': self._state_tensor,
            }
        )

        # Carry the recurrent state forward to the next window.
        speech_prob = outputs[0][0][0]
        self._state_tensor = outputs[1]

        return float(speech_prob)

    def reset(self) -> None:
        """Reset VAD state for a new session."""
        with self._lock:
            self._state = VADState()
            self._pending_segment_start = None
            # Drop buffered audio and recurrent model state from the old session.
            self._residual = np.zeros(0, dtype=np.float32)
            self._state_tensor = np.zeros(self._STATE_SHAPE, dtype=np.float32)

    def process(
        self,
        audio_chunk: np.ndarray,
        return_events: bool = True,
    ) -> Tuple[List[SpeechSegment], List[VADEvent]]:
        """
        Process an audio chunk and detect speech segments.

        The chunk is analysed in windows of ``window_size`` samples. Samples
        that do not fill a whole window are kept and analysed together with
        the next chunk, so chunk sizes need not be a multiple of the window.

        Args:
            audio_chunk: Audio data as float32 array (other dtypes are cast)
            return_events: Whether to return VAD events

        Returns:
            Tuple of (completed_segments, events)
        """
        audio_chunk = np.asarray(audio_chunk, dtype=np.float32)

        completed_segments: List[SpeechSegment] = []
        events: List[VADEvent] = []

        with self._lock:
            carried = len(self._residual)
            if carried:
                audio = np.concatenate((self._residual, audio_chunk))
            else:
                audio = audio_chunk

            # Absolute sample index of audio[0]; the carried-over samples were
            # already counted in total_samples_processed.
            base_sample = self._state.total_samples_processed - carried
            window_size = self.window_size
            num_windows = len(audio) // window_size

            for i in range(num_windows):
                offset = i * window_size
                window = audio[offset:offset + window_size]

                speech_prob = self._speech_probability(window)
                self._state.last_speech_prob = speech_prob

                window_start_sample = base_sample + offset
                window_end_sample = window_start_sample + window_size

                if speech_prob >= self.threshold:
                    self._handle_speech_window(
                        window_start_sample, window_end_sample, events, return_events
                    )
                else:
                    segment = self._handle_silence_window(
                        window_start_sample, window_end_sample, events, return_events
                    )
                    if segment is not None:
                        completed_segments.append(segment)

            # Keep the unanalysed tail for the next call (copy so the caller's
            # buffer is not kept alive or aliased).
            self._residual = audio[num_windows * window_size:].copy()

            # Update total samples received
            self._state.total_samples_processed += len(audio_chunk)

        return completed_segments, events

    def _speech_probability(self, window: np.ndarray) -> float:
        """Return the speech probability of one window, gating on RMS volume."""
        rms = np.sqrt(np.mean(window ** 2))
        if rms < self.min_volume_threshold:
            # Volume too low: treat as silence without running the model.
            return 0.0
        return self._run_inference(window)

    def _handle_speech_window(
        self,
        window_start_sample: int,
        window_end_sample: int,
        events: List[VADEvent],
        return_events: bool,
    ) -> None:
        """Advance the state machine for a window classified as speech."""
        state = self._state

        if state.is_speech_active:
            # Speech continues: cancel any silence run being timed.
            state.silence_start_sample = 0
            return

        # Potential speech start: open a candidate at this window's start.
        if self._pending_segment_start is None:
            self._pending_segment_start = window_start_sample

        # Not yet long enough to count as speech.
        if window_end_sample - self._pending_segment_start < self.min_speech_samples:
            return

        state.is_speech_active = True
        state.speech_start_sample = self._pending_segment_start

        if return_events:
            events.append(VADEvent(
                type=MessageType.VAD_EVENT,
                event=VADEventType.SPEECH_START,
                audio_timestamp_sec=self._pending_segment_start / self.sample_rate,
            ))

    def _handle_silence_window(
        self,
        window_start_sample: int,
        window_end_sample: int,
        events: List[VADEvent],
        return_events: bool,
    ) -> Optional[SpeechSegment]:
        """Advance the state machine for a non-speech window; return a segment if speech just ended."""
        state = self._state

        if not state.is_speech_active:
            # No speech: discard any candidate that was too short.
            self._pending_segment_start = None
            return None

        # Potential speech end: time the silence from the START of the first
        # silent window, mirroring how speech start is measured. Speech can
        # only be active after an earlier speech window, so this index is
        # always > 0 and never collides with the 0 "not timing" sentinel.
        if state.silence_start_sample == 0:
            state.silence_start_sample = window_start_sample

        # Silence not yet long enough to close the segment.
        if window_end_sample - state.silence_start_sample < self.min_silence_samples:
            return None

        # Speech ended - the segment stops where the silence began.
        segment = SpeechSegment(
            start_sample=state.speech_start_sample,
            end_sample=state.silence_start_sample,
            start_sec=state.speech_start_sample / self.sample_rate,
            end_sec=state.silence_start_sample / self.sample_rate,
        )

        if return_events:
            events.append(VADEvent(
                type=MessageType.VAD_EVENT,
                event=VADEventType.SPEECH_END,
                audio_timestamp_sec=state.silence_start_sample / self.sample_rate,
            ))

        # Reset state for the next segment
        state.is_speech_active = False
        state.speech_start_sample = 0
        state.silence_start_sample = 0
        self._pending_segment_start = None

        return segment

    def force_end_speech(self) -> Optional[SpeechSegment]:
        """
        Force end of current speech segment (e.g., when session ends).

        The segment is closed at the total number of samples received so far.

        Returns:
            Completed speech segment if speech was active, None otherwise
        """
        with self._lock:
            state = self._state
            if not state.is_speech_active:
                return None

            segment = SpeechSegment(
                start_sample=state.speech_start_sample,
                end_sample=state.total_samples_processed,
                start_sec=state.speech_start_sample / self.sample_rate,
                end_sec=state.total_samples_processed / self.sample_rate,
            )

            state.is_speech_active = False
            state.speech_start_sample = 0
            state.silence_start_sample = 0
            self._pending_segment_start = None

            return segment

    @property
    def is_speech_active(self) -> bool:
        """Check if speech is currently active."""
        with self._lock:
            return self._state.is_speech_active

    @property
    def last_speech_probability(self) -> float:
        """Get the last computed speech probability."""
        with self._lock:
            return self._state.last_speech_prob

    @property
    def current_speech_duration_sec(self) -> float:
        """Get duration of current speech segment (if active), else 0.0."""
        with self._lock:
            if not self._state.is_speech_active:
                return 0.0
            elapsed = self._state.total_samples_processed - self._state.speech_start_sample
            return elapsed / self.sample_rate