""" Session manager for real-time ASR. Manages multiple concurrent client sessions with resource isolation. """ import asyncio import time import uuid from typing import Dict, Optional, Callable, Any from dataclasses import dataclass, field import threading from .models import ( SessionConfig, TranscriptionResult, VADEvent, StatusMessage, ErrorMessage, MessageType, SpeechSegment, ) from .audio_buffer import AudioBuffer from .vad_processor import VADProcessor from .asr_worker import ASRWorker @dataclass class SessionStats: """Statistics for a session.""" created_at: float = field(default_factory=time.time) last_activity: float = field(default_factory=time.time) audio_received_sec: float = 0.0 chunks_received: int = 0 segments_transcribed: int = 0 total_latency_ms: float = 0.0 class RealtimeSession: """ A single real-time ASR session. Manages audio buffering, VAD, and ASR for one client connection. """ def __init__( self, session_id: str, asr_worker: ASRWorker, config: Optional[SessionConfig] = None, on_result: Optional[Callable[[TranscriptionResult], Any]] = None, on_vad_event: Optional[Callable[[VADEvent], Any]] = None, ): """ Initialize a session. Args: session_id: Unique session identifier asr_worker: Shared ASR worker instance config: Session configuration on_result: Callback for transcription results on_vad_event: Callback for VAD events """ self.session_id = session_id self.asr_worker = asr_worker self.config = config or SessionConfig() self.on_result = on_result self.on_vad_event = on_vad_event # Components self.audio_buffer = AudioBuffer( sample_rate=self.config.sample_rate, chunk_duration_sec=self.config.chunk_duration_sec, overlap_sec=self.config.overlap_sec, ) self.vad_processor = VADProcessor( sample_rate=self.config.sample_rate, threshold=self.config.vad_threshold, min_speech_duration_ms=self.config.min_speech_duration_ms, min_silence_duration_ms=self.config.min_silence_duration_ms, min_volume_threshold=self.config.min_volume_threshold, ) # State self.is_active = True self.stats = SessionStats() self._processing_lock = asyncio.Lock() self._pending_tasks: list = [] async def process_audio_chunk(self, audio_data: bytes) -> None: """ Process an incoming audio chunk. Args: audio_data: Raw PCM audio data (16-bit, 16kHz, mono) """ if not self.is_active: return self.stats.last_activity = time.time() self.stats.chunks_received += 1 # Convert bytes to float32 array import numpy as np audio_int16 = np.frombuffer(audio_data, dtype=np.int16) audio_float = audio_int16.astype(np.float32) / 32768.0 self.stats.audio_received_sec += len(audio_float) / self.config.sample_rate # Add to buffer self.audio_buffer.append(audio_float) # Process with VAD segments, events = self.vad_processor.process(audio_float) # Send VAD events if self.on_vad_event: for event in events: await self._send_callback(self.on_vad_event, event) # Process completed speech segments for segment in segments: try: await self._transcribe_segment(segment) except Exception as e: print(f"[Session {self.session_id}] Transcription error: {e}") import traceback traceback.print_exc() async def _transcribe_segment(self, segment: SpeechSegment) -> None: """Transcribe a detected speech segment.""" # Get audio for segment from buffer audio = self.audio_buffer.get_segment(segment.start_sec, segment.end_sec) if audio is None or len(audio) == 0: print(f"[Session {self.session_id}] Could not retrieve audio for segment") return self.stats.segments_transcribed += 1 async def on_partial(result: TranscriptionResult): if self.on_result: await self._send_callback(self.on_result, result) # Run transcription result = await self.asr_worker.transcribe_segment( audio=audio, sample_rate=self.config.sample_rate, context_info=self.config.context_info, segment_start_sec=segment.start_sec, segment_end_sec=segment.end_sec, config=self.config, on_partial=on_partial if self.config.return_partial_results else None, ) self.stats.total_latency_ms += result.latency_ms # Send final result if self.on_result: await self._send_callback(self.on_result, result) async def _send_callback(self, callback: Callable, data: Any) -> None: """Send data via callback, handling both sync and async.""" try: if asyncio.iscoroutinefunction(callback): await callback(data) else: callback(data) except Exception as e: print(f"[Session {self.session_id}] Callback error: {e}") async def flush(self) -> None: """ Flush any remaining audio and force transcription. Called when session ends to process any remaining speech. """ # Force end any active speech segment = self.vad_processor.force_end_speech() if segment: await self._transcribe_segment(segment) # Also check for any unprocessed audio in buffer chunk_info = self.audio_buffer.get_all_unprocessed() if chunk_info and len(chunk_info.audio) > self.config.sample_rate * 0.5: # More than 0.5 seconds of unprocessed audio forced_segment = SpeechSegment( start_sample=chunk_info.start_sample, end_sample=chunk_info.end_sample, start_sec=chunk_info.start_sec, end_sec=chunk_info.end_sec, ) await self._transcribe_segment(forced_segment) def update_config(self, new_config: SessionConfig) -> None: """Update session configuration (partial update supported).""" # Merge with existing config - only update non-default values if new_config.vad_threshold != 0.5: self.config.vad_threshold = new_config.vad_threshold if new_config.min_speech_duration_ms != 250: self.config.min_speech_duration_ms = new_config.min_speech_duration_ms if new_config.min_silence_duration_ms != 500: self.config.min_silence_duration_ms = new_config.min_silence_duration_ms if new_config.min_volume_threshold != 0.01: self.config.min_volume_threshold = new_config.min_volume_threshold if new_config.context_info is not None: self.config.context_info = new_config.context_info # Recreate VAD processor with new parameters self.vad_processor = VADProcessor( sample_rate=self.config.sample_rate, threshold=self.config.vad_threshold, min_speech_duration_ms=self.config.min_speech_duration_ms, min_silence_duration_ms=self.config.min_silence_duration_ms, min_volume_threshold=self.config.min_volume_threshold, ) print(f"[Session {self.session_id}] Config updated: vad_threshold={self.config.vad_threshold}, " f"min_speech={self.config.min_speech_duration_ms}ms, min_silence={self.config.min_silence_duration_ms}ms, " f"min_volume={self.config.min_volume_threshold}") def close(self) -> None: """Close the session and release resources.""" self.is_active = False self.audio_buffer.clear() self.vad_processor.reset() def get_stats(self) -> Dict: """Get session statistics.""" return { "session_id": self.session_id, "created_at": self.stats.created_at, "last_activity": self.stats.last_activity, "duration_sec": time.time() - self.stats.created_at, "audio_received_sec": self.stats.audio_received_sec, "chunks_received": self.stats.chunks_received, "segments_transcribed": self.stats.segments_transcribed, "avg_latency_ms": ( self.stats.total_latency_ms / self.stats.segments_transcribed if self.stats.segments_transcribed > 0 else 0 ), "is_active": self.is_active, "vad_speech_active": self.vad_processor.is_speech_active, } class SessionManager: """ Manages multiple concurrent ASR sessions. Features: - Session creation and cleanup - Resource limiting (max concurrent sessions) - Idle session timeout - Shared ASR worker management """ def __init__( self, asr_worker: ASRWorker, max_concurrent_sessions: int = 10, session_timeout_sec: float = 300.0, ): """ Initialize the session manager. Args: asr_worker: Shared ASR worker max_concurrent_sessions: Maximum number of concurrent sessions session_timeout_sec: Timeout for idle sessions """ self.asr_worker = asr_worker self.max_sessions = max_concurrent_sessions self.session_timeout = session_timeout_sec self._sessions: Dict[str, RealtimeSession] = {} self._lock = asyncio.Lock() # Cleanup task self._cleanup_task: Optional[asyncio.Task] = None async def start(self) -> None: """Start the session manager.""" self._cleanup_task = asyncio.create_task(self._cleanup_loop()) async def stop(self) -> None: """Stop the session manager and close all sessions.""" if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass async with self._lock: for session in self._sessions.values(): session.close() self._sessions.clear() async def create_session( self, session_id: Optional[str] = None, config: Optional[SessionConfig] = None, on_result: Optional[Callable[[TranscriptionResult], Any]] = None, on_vad_event: Optional[Callable[[VADEvent], Any]] = None, ) -> Optional[RealtimeSession]: """ Create a new session. Args: session_id: Optional session ID (generated if not provided) config: Session configuration on_result: Callback for results on_vad_event: Callback for VAD events Returns: Created session, or None if limit reached """ async with self._lock: # Check session limit if len(self._sessions) >= self.max_sessions: return None # Generate session ID if not provided if session_id is None: session_id = str(uuid.uuid4())[:8] # Check for duplicate if session_id in self._sessions: return self._sessions[session_id] # Create session session = RealtimeSession( session_id=session_id, asr_worker=self.asr_worker, config=config, on_result=on_result, on_vad_event=on_vad_event, ) self._sessions[session_id] = session return session async def get_session(self, session_id: str) -> Optional[RealtimeSession]: """Get a session by ID.""" async with self._lock: return self._sessions.get(session_id) async def close_session(self, session_id: str) -> bool: """ Close and remove a session. Args: session_id: Session to close Returns: True if session was found and closed """ async with self._lock: session = self._sessions.pop(session_id, None) if session: await session.flush() session.close() return True return False async def _cleanup_loop(self) -> None: """Background task to clean up idle sessions.""" while True: try: await asyncio.sleep(60) # Check every minute current_time = time.time() sessions_to_close = [] async with self._lock: for session_id, session in self._sessions.items(): idle_time = current_time - session.stats.last_activity if idle_time > self.session_timeout: sessions_to_close.append(session_id) for session_id in sessions_to_close: print(f"Closing idle session: {session_id}") await self.close_session(session_id) except asyncio.CancelledError: break except Exception as e: print(f"Cleanup error: {e}") def get_stats(self) -> Dict: """Get manager statistics.""" return { "active_sessions": len(self._sessions), "max_sessions": self.max_sessions, "session_timeout_sec": self.session_timeout, "sessions": { sid: session.get_stats() for sid, session in self._sessions.items() }, }