"""Data models for real-time ASR WebSocket communication."""

from enum import Enum
from typing import Optional, List, Dict, Any
from dataclasses import dataclass, field, fields, asdict
import time


class MessageType(str, Enum):
    """WebSocket message types.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values (``MessageType.STOP == "stop"``).
    """

    # Client -> Server
    AUDIO_CHUNK = "audio_chunk"
    CONFIG = "config"
    START = "start"
    STOP = "stop"

    # Server -> Client
    PARTIAL_RESULT = "partial_result"
    FINAL_RESULT = "final_result"
    VAD_EVENT = "vad_event"
    ERROR = "error"
    STATUS = "status"


class VADEventType(str, Enum):
    """VAD event types."""

    SPEECH_START = "speech_start"
    SPEECH_END = "speech_end"


@dataclass
class SessionConfig:
    """Configuration for a real-time ASR session."""

    # Audio parameters
    sample_rate: int = 16000
    chunk_duration_sec: float = 3.0
    overlap_sec: float = 0.5

    # VAD parameters
    vad_threshold: float = 0.5
    min_speech_duration_ms: int = 250
    min_silence_duration_ms: int = 500
    min_volume_threshold: float = 0.01  # Minimum RMS volume (0.0-1.0) to consider as potential speech

    # ASR parameters
    max_new_tokens: int = 512
    temperature: float = 0.0
    context_info: Optional[str] = None

    # Behavior
    return_partial_results: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return all configuration fields as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> "SessionConfig":
        """Build a config from a (possibly partial) dict.

        Keys that are not constructor fields of this dataclass are ignored,
        and any field not present keeps its default value.

        Args:
            data: Mapping of field names to values. ``None`` or an empty
                mapping yields a config with all defaults.

        Returns:
            A new ``SessionConfig`` instance.
        """
        if not data:
            return cls()
        # Use the public dataclasses API and only accept fields that the
        # generated __init__ takes, so stray keys can never reach cls(...).
        accepted = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in data.items() if k in accepted})


@dataclass
class TranscriptionSegment:
    """A single transcription segment with metadata."""

    start_time: float
    end_time: float
    speaker_id: str
    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the segment's fields as a plain dict."""
        return asdict(self)


@dataclass
class TranscriptionResult:
    """Transcription result message."""

    type: MessageType
    text: str
    is_final: bool
    segments: List[TranscriptionSegment] = field(default_factory=list)
    latency_ms: float = 0.0
    timestamp: float = field(default_factory=time.time)  # wall-clock time at creation

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict.

        The ``type`` enum is emitted as its string value and each segment is
        converted via ``TranscriptionSegment.to_dict``.
        """
        return {
            "type": self.type.value,
            "text": self.text,
            "is_final": self.is_final,
            "segments": [s.to_dict() for s in self.segments],
            "latency_ms": self.latency_ms,
            "timestamp": self.timestamp,
        }


@dataclass
class VADEvent:
    """VAD event message."""

    type: MessageType = MessageType.VAD_EVENT
    event: VADEventType = VADEventType.SPEECH_START
    timestamp: float = field(default_factory=time.time)  # wall-clock time at creation
    audio_timestamp_sec: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict with enums as string values."""
        return {
            "type": self.type.value,
            "event": self.event.value,
            "timestamp": self.timestamp,
            "audio_timestamp_sec": self.audio_timestamp_sec,
        }


@dataclass
class StatusMessage:
    """Status message."""

    type: MessageType = MessageType.STATUS
    status: str = ""
    message: str = ""
    timestamp: float = field(default_factory=time.time)  # wall-clock time at creation

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict with ``type`` as a string value."""
        return {
            "type": self.type.value,
            "status": self.status,
            "message": self.message,
            "timestamp": self.timestamp,
        }


@dataclass
class ErrorMessage:
    """Error message."""

    type: MessageType = MessageType.ERROR
    error: str = ""
    code: str = ""
    timestamp: float = field(default_factory=time.time)  # wall-clock time at creation

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict with ``type`` as a string value."""
        return {
            "type": self.type.value,
            "error": self.error,
            "code": self.code,
            "timestamp": self.timestamp,
        }


@dataclass
class SpeechSegment:
    """Detected speech segment from VAD.

    Stores the same span both as sample indices and as seconds.
    """

    start_sample: int
    end_sample: int
    start_sec: float
    end_sec: float
    confidence: float = 1.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the segment's fields as a plain dict (consistent with the other models)."""
        return asdict(self)