All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
247 lines
8.3 KiB
Python
247 lines
8.3 KiB
Python
"""
|
|
Audio buffer management for real-time ASR.
|
|
|
|
Implements a ring buffer for efficient audio chunk management with overlap support.
|
|
"""
|
|
|
|
import numpy as np
|
|
from typing import Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import threading
|
|
|
|
|
|
@dataclass
class AudioChunkInfo:
    """Information about an extracted audio chunk."""

    # Audio samples as a 1-D float32 array (copied out of the ring buffer,
    # so it stays valid after later compaction)
    audio: np.ndarray
    # Absolute sample offsets since session start: start inclusive,
    # end exclusive (end_sample - start_sample == len(audio))
    start_sample: int
    end_sample: int
    # Same boundaries expressed in seconds (sample offset / sample rate)
    start_sec: float
    end_sec: float
|
|
|
|
|
|
class AudioBuffer:
    """
    Ring buffer for managing audio chunks with overlap support.

    Features:
    - Efficient memory management with fixed-size buffer
    - Overlap handling for continuous processing
    - Thread-safe operations
    - Automatic sample rate tracking

    Positions are tracked two ways: buffer-relative indices
    (``_read_pos`` / ``_write_pos``) and the absolute sample count since
    session start (``_total_samples_received``). Buffer index ``i``
    always holds absolute sample ``_total_samples_received - _write_pos + i``,
    which is how chunk timestamps are derived.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        chunk_duration_sec: float = 3.0,
        overlap_sec: float = 0.5,
        max_buffer_sec: float = 60.0,
    ):
        """
        Initialize the audio buffer.

        Args:
            sample_rate: Audio sample rate in Hz
            chunk_duration_sec: Duration of each processing chunk
            overlap_sec: Overlap between consecutive chunks
            max_buffer_sec: Maximum buffer duration (older data will be discarded)
        """
        self.sample_rate = sample_rate
        self.chunk_size = int(chunk_duration_sec * sample_rate)
        self.overlap_size = int(overlap_sec * sample_rate)
        self.max_buffer_size = int(max_buffer_sec * sample_rate)

        # Main buffer (pre-allocated once; unprocessed data is compacted
        # to the front whenever the write head would run past the end)
        self._buffer = np.zeros(self.max_buffer_size, dtype=np.float32)
        self._write_pos = 0  # Next buffer index to write
        self._read_pos = 0  # Buffer index where unprocessed data starts
        self._total_samples_received = 0  # Absolute samples since session start

        self._lock = threading.Lock()

    @property
    def samples_available(self) -> int:
        """Number of unprocessed samples in buffer."""
        with self._lock:
            return self._write_pos - self._read_pos

    @property
    def duration_available_sec(self) -> float:
        """Duration of unprocessed audio in seconds."""
        return self.samples_available / self.sample_rate

    @property
    def total_duration_sec(self) -> float:
        """Total duration of audio received since session start."""
        return self._total_samples_received / self.sample_rate

    def append(self, audio_chunk: np.ndarray) -> int:
        """
        Append audio chunk to the buffer.

        When space runs out, unprocessed data is first compacted to the
        front of the buffer; if that is still not enough, the oldest
        unprocessed samples are discarded. A chunk longer than the whole
        buffer is truncated to its newest ``max_buffer_size`` samples
        (previously this case raised a numpy broadcast ValueError on the
        slice assignment); the full chunk length still counts toward the
        absolute sample clock so later timestamps stay aligned with the
        capture stream.

        Args:
            audio_chunk: Audio data as float32 array (range: -1.0 to 1.0)

        Returns:
            Number of samples consumed from the chunk (its full length,
            even when older samples had to be discarded to make room)
        """
        if audio_chunk.dtype != np.float32:
            audio_chunk = audio_chunk.astype(np.float32)

        # Ensure 1D
        if audio_chunk.ndim > 1:
            audio_chunk = audio_chunk.flatten()

        with self._lock:
            chunk_len = len(audio_chunk)

            # An oversized chunk can never fit: keep only its newest
            # samples and drop everything currently buffered (it is all
            # older than what is about to be written).
            write_chunk = audio_chunk
            if chunk_len > self.max_buffer_size:
                write_chunk = audio_chunk[-self.max_buffer_size:]
                self._read_pos = 0
                self._write_pos = 0
            write_len = len(write_chunk)

            # Check if we need to shift buffer (running out of space)
            if self._write_pos + write_len > self.max_buffer_size:
                self._compact_buffer()

            # Still not enough space? Discard old unprocessed data.
            # write_len <= max_buffer_size guarantees overflow <= _write_pos.
            if self._write_pos + write_len > self.max_buffer_size:
                overflow = (self._write_pos + write_len) - self.max_buffer_size
                self._read_pos = min(self._read_pos + overflow, self._write_pos)
                self._compact_buffer()

            # Write to buffer
            end_pos = self._write_pos + write_len
            self._buffer[self._write_pos:end_pos] = write_chunk
            self._write_pos = end_pos
            # Count the full incoming length so absolute timestamps stay
            # correct even when samples were dropped above.
            self._total_samples_received += chunk_len

            return chunk_len

    def _compact_buffer(self) -> None:
        """Move unprocessed data to the beginning of the buffer.

        Caller must already hold ``self._lock``.
        """
        if self._read_pos > 0:
            unprocessed_len = self._write_pos - self._read_pos
            if unprocessed_len > 0:
                # Overlapping forward copy is safe: numpy handles
                # same-array slice assignment with an internal buffer.
                self._buffer[:unprocessed_len] = self._buffer[self._read_pos:self._write_pos]
            self._write_pos = unprocessed_len
            self._read_pos = 0

    def get_chunk_for_inference(self, min_duration_sec: float = 0.5) -> Optional[AudioChunkInfo]:
        """
        Get the next chunk for ASR inference.

        Returns a chunk of audio when enough data is available.
        The chunk includes overlap from the previous chunk for context
        (``mark_processed`` keeps the last ``overlap_size`` samples
        unconsumed, so they reappear at the head of the next chunk).

        Args:
            min_duration_sec: Minimum duration required to return a chunk

        Returns:
            AudioChunkInfo if enough data is available, None otherwise
        """
        min_samples = int(min_duration_sec * self.sample_rate)

        with self._lock:
            available = self._write_pos - self._read_pos

            if available < min_samples:
                return None

            # Calculate chunk boundaries (at most chunk_size samples)
            chunk_start = self._read_pos
            chunk_end = min(self._read_pos + self.chunk_size, self._write_pos)
            actual_chunk_size = chunk_end - chunk_start

            # Extract audio (copy so later compaction cannot mutate it)
            audio = self._buffer[chunk_start:chunk_end].copy()

            # Absolute position: buffer index i holds absolute sample
            # (_total_samples_received - _write_pos + i)
            base_sample = self._total_samples_received - (self._write_pos - chunk_start)
            start_sec = base_sample / self.sample_rate
            end_sec = (base_sample + actual_chunk_size) / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=base_sample + actual_chunk_size,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def mark_processed(self, samples: int) -> None:
        """
        Mark samples as processed, advancing the read position.

        Keeps overlap_size samples for context in the next chunk.

        Args:
            samples: Number of samples that were processed
        """
        with self._lock:
            # Advance read position but keep overlap for context
            advance = max(0, samples - self.overlap_size)
            self._read_pos = min(self._read_pos + advance, self._write_pos)

    def get_segment(self, start_sec: float, end_sec: float) -> Optional[np.ndarray]:
        """
        Get a specific time segment from the buffer.

        Args:
            start_sec: Start time in seconds (relative to session start)
            end_sec: End time in seconds

        Returns:
            Audio segment (a copy) if still fully in the buffer, None otherwise
        """
        start_sample = int(start_sec * self.sample_rate)
        end_sample = int(end_sec * self.sample_rate)

        with self._lock:
            # Absolute sample range currently held in the buffer
            buffer_start_sample = self._total_samples_received - self._write_pos
            buffer_end_sample = self._total_samples_received

            # Check if segment is (entirely) in buffer
            if start_sample < buffer_start_sample or end_sample > buffer_end_sample:
                return None

            # Convert absolute sample offsets to buffer indices
            buf_start = start_sample - buffer_start_sample
            buf_end = end_sample - buffer_start_sample

            return self._buffer[buf_start:buf_end].copy()

    def get_all_unprocessed(self) -> Optional[AudioChunkInfo]:
        """
        Get all unprocessed audio.

        Does not advance the read position; call ``mark_processed`` for that.

        Returns:
            AudioChunkInfo with all unprocessed audio, or None if empty
        """
        with self._lock:
            if self._write_pos <= self._read_pos:
                return None

            audio = self._buffer[self._read_pos:self._write_pos].copy()
            base_sample = self._total_samples_received - (self._write_pos - self._read_pos)
            start_sec = base_sample / self.sample_rate
            end_sec = self._total_samples_received / self.sample_rate

            return AudioChunkInfo(
                audio=audio,
                start_sample=base_sample,
                end_sample=self._total_samples_received,
                start_sec=start_sec,
                end_sec=end_sec,
            )

    def clear(self) -> None:
        """Clear the buffer and reset all positions."""
        with self._lock:
            self._buffer.fill(0)
            self._write_pos = 0
            self._read_pos = 0
            self._total_samples_received = 0

    def reset_read_position(self) -> None:
        """Reset read position to current write position (skip all unprocessed)."""
        with self._lock:
            self._read_pos = self._write_pos