"""
Voice Activity Detection (VAD) processor using Silero-VAD (ONNX version).

Detects speech segments in real-time audio streams.
Uses ONNX runtime to avoid torchaudio dependency issues.
"""
|
|
|
|
import numpy as np
|
|
from typing import List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import threading
|
|
import os
|
|
import urllib.request
|
|
|
|
from .models import SpeechSegment, VADEvent, VADEventType, MessageType
|
|
|
|
|
|
@dataclass
class VADState:
    """Internal state of the VAD processor."""

    # True while a confirmed speech segment is in progress.
    is_speech_active: bool = False
    # Absolute sample index where the current speech segment started.
    speech_start_sample: int = 0
    # Absolute sample index where the current run of silence started;
    # 0 doubles as the "no silence pending" sentinel.
    silence_start_sample: int = 0
    # Most recent speech probability produced for a window (0.0-1.0).
    last_speech_prob: float = 0.0
    # Running total of samples seen across all process() calls.
    total_samples_processed: int = 0
|
|
|
|
|
|
class VADProcessor:
    """
    Voice Activity Detection using Silero-VAD (ONNX version).

    Features:
    - Real-time speech detection
    - Configurable thresholds for speech/silence duration
    - Event generation for speech start/end
    - Thread-safe operations
    - No torchaudio dependency (uses ONNX runtime)
    """

    # Silero VAD ONNX model URL
    ONNX_MODEL_URL = "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx"

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        window_size_samples: int = 512,
        min_volume_threshold: float = 0.01,
    ):
        """
        Initialize the VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 16000 for Silero-VAD)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger speech_start
            min_silence_duration_ms: Minimum silence duration to trigger speech_end
            window_size_samples: VAD window size (512 for 16kHz = 32ms)
            min_volume_threshold: Minimum RMS volume (0.0-1.0) to consider as potential speech

        Raises:
            ValueError: If sample_rate is not 16000.
            RuntimeError: If the ONNX model cannot be downloaded or loaded.
        """
        if sample_rate != 16000:
            raise ValueError("Silero-VAD requires 16kHz sample rate")

        self.sample_rate = sample_rate
        self.threshold = threshold
        # Convert millisecond thresholds to sample counts once, up front.
        self.min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000)
        self.min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)
        self.window_size = window_size_samples
        self.min_volume_threshold = min_volume_threshold

        # Load ONNX model
        self._session = None
        self._load_model()

        # ONNX model state - single state tensor (size depends on model version)
        # Silero VAD v5 uses a single 'state' tensor of shape (2, 1, 128)
        self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

        # Detection state, guarded by _lock for thread-safe access.
        self._state = VADState()
        self._lock = threading.Lock()

        # Pending speech segment (being accumulated, not yet confirmed)
        self._pending_segment_start: Optional[int] = None

    def _get_model_path(self) -> str:
        """Get the path to the ONNX model on disk, downloading it on first use.

        Downloads to a temporary ".part" file and atomically renames it into
        place, so an interrupted download never leaves a truncated model that
        later loads would pick up.
        """
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "silero-vad")
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, "silero_vad.onnx")

        if not os.path.exists(model_path):
            print(f"Downloading Silero-VAD ONNX model to {model_path}...")
            tmp_path = model_path + ".part"
            try:
                urllib.request.urlretrieve(self.ONNX_MODEL_URL, tmp_path)
                # os.replace is atomic on both POSIX and Windows.
                os.replace(tmp_path, model_path)
            finally:
                # Remove a partial download if anything above failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            print("Download complete.")

        return model_path

    def _load_model(self) -> None:
        """Load the Silero-VAD ONNX model into an inference session.

        Raises:
            RuntimeError: If onnxruntime is unavailable or loading fails.
        """
        try:
            import onnxruntime as ort

            model_path = self._get_model_path()
            self._session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Silero-VAD ONNX model loaded from {model_path}")

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load Silero-VAD ONNX model: {e}") from e

    def _run_inference(self, audio_window: np.ndarray) -> float:
        """Run VAD inference on a single window.

        Args:
            audio_window: float32 samples; expected length is ``window_size``.

        Returns:
            Speech probability for this window as a Python float.
        """
        # Prepare input: (batch=1, samples) float32 audio plus int64 sample rate.
        audio_input = audio_window.reshape(1, -1).astype(np.float32)
        sr_input = np.array([self.sample_rate], dtype=np.int64)

        # Run inference - Silero VAD v5 uses 'state' instead of 'h'/'c'
        outputs = self._session.run(
            ['output', 'stateN'],
            {
                'input': audio_input,
                'sr': sr_input,
                'state': self._state_tensor,
            }
        )

        # Carry the recurrent state forward for the next window.
        speech_prob = outputs[0][0][0]
        self._state_tensor = outputs[1]

        return float(speech_prob)

    def reset(self) -> None:
        """Reset VAD state for a new session."""
        with self._lock:
            self._state = VADState()
            self._pending_segment_start = None
            # Reset the recurrent model state tensor as well.
            self._state_tensor = np.zeros((2, 1, 128), dtype=np.float32)

    def process(
        self,
        audio_chunk: np.ndarray,
        return_events: bool = True,
    ) -> Tuple[List[SpeechSegment], List[VADEvent]]:
        """
        Process an audio chunk and detect speech segments.

        Note: only full windows are analyzed; trailing samples shorter than
        ``window_size`` are skipped by the VAD but still counted toward
        ``total_samples_processed``.

        Args:
            audio_chunk: Audio data as float32 array
            return_events: Whether to return VAD events

        Returns:
            Tuple of (completed_segments, events)
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)

        completed_segments: List[SpeechSegment] = []
        events: List[VADEvent] = []

        with self._lock:
            # Absolute sample position of the start of this chunk.
            chunk_start_sample = self._state.total_samples_processed
            num_windows = len(audio_chunk) // self.window_size

            for i in range(num_windows):
                window_start = i * self.window_size
                window_end = window_start + self.window_size
                window = audio_chunk[window_start:window_end]

                # Cheap RMS volume gate: skip model inference on near-silence.
                rms = np.sqrt(np.mean(window ** 2))
                if rms < self.min_volume_threshold:
                    # Volume too low, treat as silence
                    speech_prob = 0.0
                else:
                    # Get speech probability from VAD model
                    speech_prob = self._run_inference(window)

                self._state.last_speech_prob = speech_prob

                current_sample = chunk_start_sample + window_end
                is_speech = speech_prob >= self.threshold

                # State machine for speech detection
                if is_speech:
                    if not self._state.is_speech_active:
                        # Potential speech start; remember where it began.
                        if self._pending_segment_start is None:
                            self._pending_segment_start = current_sample - self.window_size

                        # Confirm speech once it persists past the minimum duration.
                        speech_duration = current_sample - self._pending_segment_start
                        if speech_duration >= self.min_speech_samples:
                            self._state.is_speech_active = True
                            self._state.speech_start_sample = self._pending_segment_start

                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_START,
                                    audio_timestamp_sec=self._pending_segment_start / self.sample_rate,
                                ))
                    else:
                        # Continue speech, reset silence counter (0 = sentinel).
                        self._state.silence_start_sample = 0
                else:
                    if self._state.is_speech_active:
                        # Potential speech end; remember where silence began.
                        if self._state.silence_start_sample == 0:
                            self._state.silence_start_sample = current_sample

                        # Confirm end once silence persists past the minimum duration.
                        silence_duration = current_sample - self._state.silence_start_sample
                        if silence_duration >= self.min_silence_samples:
                            # Speech ended - segment runs up to the start of silence.
                            segment = SpeechSegment(
                                start_sample=self._state.speech_start_sample,
                                end_sample=self._state.silence_start_sample,
                                start_sec=self._state.speech_start_sample / self.sample_rate,
                                end_sec=self._state.silence_start_sample / self.sample_rate,
                            )
                            completed_segments.append(segment)

                            if return_events:
                                events.append(VADEvent(
                                    type=MessageType.VAD_EVENT,
                                    event=VADEventType.SPEECH_END,
                                    audio_timestamp_sec=self._state.silence_start_sample / self.sample_rate,
                                ))

                            # Reset state for the next segment.
                            self._state.is_speech_active = False
                            self._state.speech_start_sample = 0
                            self._state.silence_start_sample = 0
                            self._pending_segment_start = None
                    else:
                        # No speech and none pending - discard any partial start.
                        self._pending_segment_start = None

            # Count the whole chunk, including any unanalyzed tail.
            self._state.total_samples_processed += len(audio_chunk)

        return completed_segments, events

    def force_end_speech(self) -> Optional[SpeechSegment]:
        """
        Force end of current speech segment (e.g., when session ends).

        Returns:
            Completed speech segment if speech was active, None otherwise
        """
        with self._lock:
            if self._state.is_speech_active:
                # Close the segment at the last processed sample.
                segment = SpeechSegment(
                    start_sample=self._state.speech_start_sample,
                    end_sample=self._state.total_samples_processed,
                    start_sec=self._state.speech_start_sample / self.sample_rate,
                    end_sec=self._state.total_samples_processed / self.sample_rate,
                )

                self._state.is_speech_active = False
                self._state.speech_start_sample = 0
                self._state.silence_start_sample = 0
                self._pending_segment_start = None

                return segment

            return None

    @property
    def is_speech_active(self) -> bool:
        """Check if speech is currently active."""
        with self._lock:
            return self._state.is_speech_active

    @property
    def last_speech_probability(self) -> float:
        """Get the last computed speech probability."""
        with self._lock:
            return self._state.last_speech_prob

    @property
    def current_speech_duration_sec(self) -> float:
        """Get duration of current speech segment (if active)."""
        with self._lock:
            if not self._state.is_speech_active:
                return 0.0
            return (self._state.total_samples_processed - self._state.speech_start_sample) / self.sample_rate