All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
402 lines
14 KiB
Python
402 lines
14 KiB
Python
"""
|
|
Session manager for real-time ASR.
|
|
|
|
Manages multiple concurrent client sessions with resource isolation.
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
from typing import Dict, Optional, Callable, Any
|
|
from dataclasses import dataclass, field
|
|
import threading
|
|
|
|
from .models import (
|
|
SessionConfig,
|
|
TranscriptionResult,
|
|
VADEvent,
|
|
StatusMessage,
|
|
ErrorMessage,
|
|
MessageType,
|
|
SpeechSegment,
|
|
)
|
|
from .audio_buffer import AudioBuffer
|
|
from .vad_processor import VADProcessor
|
|
from .asr_worker import ASRWorker
|
|
|
|
|
|
@dataclass
class SessionStats:
    """Running counters and timestamps kept for one session."""

    # Wall-clock time (time.time()) at which the stats object was created.
    created_at: float = field(default_factory=time.time)
    # Wall-clock time of the most recent activity recorded for the session.
    last_activity: float = field(default_factory=time.time)
    # Accumulated duration of received audio, in seconds.
    audio_received_sec: float = 0.0
    # Number of audio chunks received.
    chunks_received: int = 0
    # Number of speech segments transcribed.
    segments_transcribed: int = 0
    # Sum of per-segment latencies, in milliseconds.
    total_latency_ms: float = 0.0
|
|
|
|
|
|
class RealtimeSession:
    """
    A single real-time ASR session.

    Manages audio buffering, VAD, and ASR for one client connection.
    Incoming PCM chunks are appended to an audio buffer and fed to the VAD
    processor; every speech segment the VAD reports is transcribed through
    the shared ASR worker and delivered via the ``on_result`` callback.
    """

    def __init__(
        self,
        session_id: str,
        asr_worker: "ASRWorker",
        config: Optional["SessionConfig"] = None,
        on_result: Optional[Callable[["TranscriptionResult"], Any]] = None,
        on_vad_event: Optional[Callable[["VADEvent"], Any]] = None,
    ):
        """
        Initialize a session.

        Args:
            session_id: Unique session identifier
            asr_worker: Shared ASR worker instance
            config: Session configuration (a default ``SessionConfig()`` is
                used when omitted)
            on_result: Callback for transcription results (sync or async)
            on_vad_event: Callback for VAD events (sync or async)
        """
        self.session_id = session_id
        self.asr_worker = asr_worker
        self.config = config or SessionConfig()
        self.on_result = on_result
        self.on_vad_event = on_vad_event

        # Components
        self.audio_buffer = AudioBuffer(
            sample_rate=self.config.sample_rate,
            chunk_duration_sec=self.config.chunk_duration_sec,
            overlap_sec=self.config.overlap_sec,
        )
        self.vad_processor = self._build_vad_processor()

        # State
        self.is_active = True
        self.stats = SessionStats()
        self._processing_lock = asyncio.Lock()
        self._pending_tasks: list = []
        # Trailing byte of an odd-length chunk, held until the next chunk
        # arrives so a 16-bit sample split across chunks is reassembled.
        self._pending_bytes = b""

    def _build_vad_processor(self) -> "VADProcessor":
        """Create a VAD processor from the current session configuration."""
        return VADProcessor(
            sample_rate=self.config.sample_rate,
            threshold=self.config.vad_threshold,
            min_speech_duration_ms=self.config.min_speech_duration_ms,
            min_silence_duration_ms=self.config.min_silence_duration_ms,
            min_volume_threshold=self.config.min_volume_threshold,
        )

    async def process_audio_chunk(self, audio_data: bytes) -> None:
        """
        Process an incoming audio chunk.

        The chunk is converted to float32 samples, appended to the audio
        buffer and run through the VAD. VAD events are forwarded to
        ``on_vad_event`` and each completed speech segment is transcribed.
        Transcription errors are printed and do not propagate.

        Args:
            audio_data: Raw PCM audio data (16-bit, 16kHz, mono)
        """
        if not self.is_active:
            return

        self.stats.last_activity = time.time()
        self.stats.chunks_received += 1

        import numpy as np

        # np.frombuffer requires a whole number of int16 samples; prepend any
        # byte carried over from the previous chunk and hold back a new
        # trailing odd byte instead of raising ValueError.
        raw = audio_data
        if self._pending_bytes:
            raw = self._pending_bytes + bytes(raw)
            self._pending_bytes = b""
        if len(raw) % 2:
            self._pending_bytes = bytes(raw[-1:])
            raw = raw[:-1]
        if len(raw) == 0:
            # No complete sample yet; nothing to buffer or analyse.
            return

        # Convert bytes to float32 array scaled by 1/32768
        audio_int16 = np.frombuffer(raw, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate

        # Add to buffer
        self.audio_buffer.append(audio_float)

        # Process with VAD
        segments, events = self.vad_processor.process(audio_float)

        # Send VAD events
        if self.on_vad_event:
            for event in events:
                await self._send_callback(self.on_vad_event, event)

        # Process completed speech segments; a failure on one segment must
        # not stop the remaining segments or the audio stream.
        for segment in segments:
            try:
                await self._transcribe_segment(segment)
            except Exception as e:
                print(f"[Session {self.session_id}] Transcription error: {e}")
                import traceback
                traceback.print_exc()

    async def _transcribe_segment(self, segment: "SpeechSegment") -> None:
        """
        Transcribe a detected speech segment.

        Fetches the segment's audio from the buffer, runs it through the ASR
        worker and forwards the final result to ``on_result``. Partial
        results are forwarded too when ``config.return_partial_results`` is
        set. Exceptions raised by the ASR worker propagate to the caller.
        """
        # Get audio for segment from buffer
        audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec)

        if audio is None or len(audio) == 0:
            print(f"[Session {self.session_id}] Could not retrieve audio for segment")
            return

        async def on_partial(result: "TranscriptionResult"):
            if self.on_result:
                await self._send_callback(self.on_result, result)

        # Run transcription
        result = await self.asr_worker.transcribe_segment(
            audio=audio,
            sample_rate=self.config.sample_rate,
            context_info=self.config.context_info,
            segment_start_sec=segment.start_sec,
            segment_end_sec=segment.end_sec,
            config=self.config,
            on_partial=on_partial if self.config.return_partial_results else None,
        )

        # Count the segment only once transcription succeeded, so failed
        # attempts do not dilute avg_latency_ms in get_stats().
        self.stats.segments_transcribed += 1
        self.stats.total_latency_ms += result.latency_ms

        # Send final result
        if self.on_result:
            await self._send_callback(self.on_result, result)

    async def _send_callback(self, callback: Callable, data: Any) -> None:
        """
        Send data via callback, handling both sync and async.

        The callback is invoked and, if what it returns is awaitable, that
        result is awaited. This also covers sync callables that return a
        coroutine. Exceptions raised by the callback are printed and
        swallowed so a faulty consumer cannot break audio processing.
        """
        import inspect

        try:
            outcome = callback(data)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            print(f"[Session {self.session_id}] Callback error: {e}")

    async def flush(self) -> None:
        """
        Flush any remaining audio and force transcription.

        Called when session ends to process any remaining speech.
        """
        # Force end any active speech
        segment = self.vad_processor.force_end_speech()
        if segment:
            await self._transcribe_segment(segment)

        # Also check for any unprocessed audio in buffer
        # NOTE(review): whether this can overlap the segment transcribed just
        # above depends on AudioBuffer.get_all_unprocessed() - confirm.
        chunk_info = self.audio_buffer.get_all_unprocessed()
        if chunk_info and len(chunk_info.audio) > self.config.sample_rate * 0.5:
            # More than 0.5 seconds of unprocessed audio
            forced_segment = SpeechSegment(
                start_sample=chunk_info.start_sample,
                end_sample=chunk_info.end_sample,
                start_sec=chunk_info.start_sec,
                end_sec=chunk_info.end_sec,
            )
            await self._transcribe_segment(forced_segment)

    def update_config(self, new_config: "SessionConfig") -> None:
        """
        Update session configuration (partial update supported).

        A field of ``new_config`` is applied only when it differs from the
        sentinel value checked below; ``context_info`` is applied when it is
        not None. The VAD processor is then rebuilt from the merged config.
        """
        # Merge with existing config - only update non-default values
        # NOTE(review): the literals below are presumably SessionConfig's
        # defaults; a caller cannot set a field back to one of these values
        # through this method - confirm against SessionConfig.
        if new_config.vad_threshold != 0.5:
            self.config.vad_threshold = new_config.vad_threshold
        if new_config.min_speech_duration_ms != 250:
            self.config.min_speech_duration_ms = new_config.min_speech_duration_ms
        if new_config.min_silence_duration_ms != 500:
            self.config.min_silence_duration_ms = new_config.min_silence_duration_ms
        if new_config.min_volume_threshold != 0.01:
            self.config.min_volume_threshold = new_config.min_volume_threshold
        if new_config.context_info is not None:
            self.config.context_info = new_config.context_info

        # Recreate VAD processor with new parameters
        self.vad_processor = self._build_vad_processor()
        print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, "
              f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, "
              f"min_volume={self.config.min_volume_threshold}")

    def close(self) -> None:
        """Close the session and release resources."""
        self.is_active = False
        self._pending_bytes = b""
        self.audio_buffer.clear()
        self.vad_processor.reset()

    def get_stats(self) -> Dict:
        """
        Get session statistics.

        Returns:
            Dict with the session id, timestamps, counters, the average
            latency per transcribed segment (0 when none were transcribed),
            and the current activity flags.
        """
        transcribed = self.stats.segments_transcribed
        return {
            "session_id": self.session_id,
            "created_at": self.stats.created_at,
            "last_activity": self.stats.last_activity,
            "duration_sec": time.time() - self.stats.created_at,
            "audio_received_sec": self.stats.audio_received_sec,
            "chunks_received": self.stats.chunks_received,
            "segments_transcribed": transcribed,
            "avg_latency_ms": (
                self.stats.total_latency_ms / transcribed
                if transcribed > 0 else 0
            ),
            "is_active": self.is_active,
            "vad_speech_active": self.vad_processor.is_speech_active,
        }
|
|
|
|
|
|
class SessionManager:
    """
    Manages multiple concurrent ASR sessions.

    Features:
    - Session creation and cleanup
    - Resource limiting (max concurrent sessions)
    - Idle session timeout
    - Shared ASR worker management
    """

    # Seconds between two idle-session sweeps of the background task.
    _CLEANUP_INTERVAL_SEC = 60

    def __init__(
        self,
        asr_worker: "ASRWorker",
        max_concurrent_sessions: int = 10,
        session_timeout_sec: float = 300.0,
    ):
        """
        Initialize the session manager.

        Args:
            asr_worker: Shared ASR worker
            max_concurrent_sessions: Maximum number of concurrent sessions
            session_timeout_sec: Timeout for idle sessions
        """
        self.asr_worker = asr_worker
        self.max_sessions = max_concurrent_sessions
        self.session_timeout = session_timeout_sec

        self._sessions: Dict[str, "RealtimeSession"] = {}
        self._lock = asyncio.Lock()

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the session manager (no-op if the cleanup task is running)."""
        # A second start() must not orphan the task created by the first.
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Stop the session manager and close all sessions."""
        task, self._cleanup_task = self._cleanup_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        async with self._lock:
            for session in self._sessions.values():
                session.close()
            self._sessions.clear()

    async def create_session(
        self,
        session_id: Optional[str] = None,
        config: Optional["SessionConfig"] = None,
        on_result: Optional[Callable[["TranscriptionResult"], Any]] = None,
        on_vad_event: Optional[Callable[["VADEvent"], Any]] = None,
    ) -> Optional["RealtimeSession"]:
        """
        Create a new session.

        If ``session_id`` names a session that already exists, that session
        is returned unchanged (its config and callbacks are not replaced).

        Args:
            session_id: Optional session ID (generated if not provided)
            config: Session configuration
            on_result: Callback for results
            on_vad_event: Callback for VAD events

        Returns:
            Created (or already existing) session, or None if limit reached
        """
        async with self._lock:
            # Check for duplicate first: handing back an existing session
            # creates nothing new, so the limit does not apply to it.
            if session_id is not None and session_id in self._sessions:
                return self._sessions[session_id]

            # Check session limit
            if len(self._sessions) >= self.max_sessions:
                return None

            # Generate session ID if not provided; retry on collision so a
            # new client is never handed another client's session.
            if session_id is None:
                session_id = str(uuid.uuid4())[:8]
                while session_id in self._sessions:
                    session_id = str(uuid.uuid4())[:8]

            # Create session
            session = RealtimeSession(
                session_id=session_id,
                asr_worker=self.asr_worker,
                config=config,
                on_result=on_result,
                on_vad_event=on_vad_event,
            )

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional["RealtimeSession"]:
        """Get a session by ID, or None if it does not exist."""
        async with self._lock:
            return self._sessions.get(session_id)

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        The session is removed from the registry under the lock, then
        flushed and closed outside of it so a slow final transcription does
        not block other sessions. ``close()`` runs even if ``flush()``
        raises; the exception still propagates to the caller.

        Args:
            session_id: Session to close

        Returns:
            True if session was found and closed
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)

        if session is None:
            return False

        try:
            await session.flush()
        finally:
            session.close()
        return True

    async def _cleanup_loop(self) -> None:
        """Background task to clean up idle sessions."""
        while True:
            try:
                await asyncio.sleep(self._CLEANUP_INTERVAL_SEC)

                current_time = time.time()
                async with self._lock:
                    sessions_to_close = [
                        session_id
                        for session_id, session in self._sessions.items()
                        if current_time - session.stats.last_activity > self.session_timeout
                    ]

                for session_id in sessions_to_close:
                    print(f"Closing idle session: {session_id}")
                    # Guard each close so one failing session does not stop
                    # the remaining idle sessions from being closed.
                    try:
                        await self.close_session(session_id)
                    except Exception as e:
                        print(f"Cleanup error: {e}")

            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Cleanup error: {e}")

    def get_stats(self) -> Dict:
        """Get manager statistics, including per-session statistics."""
        return {
            "active_sessions": len(self._sessions),
            "max_sessions": self.max_sessions,
            "session_timeout_sec": self.session_timeout,
            "sessions": {
                sid: session.get_stats()
                for sid, session in self._sessions.items()
            },
        }
|