"""
Data models for real-time ASR WebSocket communication.
"""

from enum import Enum
from typing import Optional, List, Dict, Any
from dataclasses import dataclass, field, asdict
import time


class MessageType(str, Enum):
    """Every message type exchanged over the ASR WebSocket.

    Subclassing ``str`` makes members compare equal to their wire strings,
    so ``MessageType.START == "start"`` holds and members JSON-serialize
    cleanly.
    """

    # Messages sent by the client to the server.
    AUDIO_CHUNK = "audio_chunk"
    CONFIG = "config"
    START = "start"
    STOP = "stop"

    # Messages sent by the server back to the client.
    PARTIAL_RESULT = "partial_result"
    FINAL_RESULT = "final_result"
    VAD_EVENT = "vad_event"
    ERROR = "error"
    STATUS = "status"


class VADEventType(str, Enum):
    """Voice-activity-detection event kinds reported to the client."""

    SPEECH_START = "speech_start"  # speech onset detected
    SPEECH_END = "speech_end"      # speech offset detected


@dataclass
class SessionConfig:
    """Configuration for a real-time ASR session.

    Unknown keys in incoming payloads are ignored by :meth:`from_dict`,
    so older servers tolerate newer clients.
    """

    # Audio parameters
    sample_rate: int = 16000           # expected input sample rate in Hz
    chunk_duration_sec: float = 3.0    # length of each ASR window
    overlap_sec: float = 0.5           # overlap between consecutive windows

    # VAD parameters
    vad_threshold: float = 0.5
    min_speech_duration_ms: int = 250
    min_silence_duration_ms: int = 500
    min_volume_threshold: float = 0.01  # Minimum RMS volume (0.0-1.0) to consider as potential speech

    # ASR parameters
    max_new_tokens: int = 512
    temperature: float = 0.0
    context_info: Optional[str] = None

    # Behavior
    return_partial_results: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field into a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionConfig":
        """Build a config from *data*, silently dropping unknown keys."""
        kwargs: Dict[str, Any] = {}
        for key, value in data.items():
            if key in cls.__dataclass_fields__:
                kwargs[key] = value
        return cls(**kwargs)


@dataclass
class TranscriptionSegment:
    """One timed slice of transcript attributed to a single speaker."""

    start_time: float  # segment start, in seconds
    end_time: float    # segment end, in seconds
    speaker_id: str    # speaker label for this segment
    text: str          # transcribed text

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields into a JSON-friendly dict."""
        return {
            "start_time": self.start_time,
            "end_time": self.end_time,
            "speaker_id": self.speaker_id,
            "text": self.text,
        }


@dataclass
class TranscriptionResult:
    """Partial or final transcription result sent to the client."""

    type: MessageType                 # PARTIAL_RESULT or FINAL_RESULT per is_final — TODO confirm with caller
    text: str                         # full transcript text for this result
    is_final: bool                    # True when the result will not change further
    segments: List[TranscriptionSegment] = field(default_factory=list)
    latency_ms: float = 0.0           # processing latency for this result
    timestamp: float = field(default_factory=time.time)  # wall-clock send time

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum becomes its wire string)."""
        payload = {
            "type": self.type.value,
            "text": self.text,
            "is_final": self.is_final,
            "segments": [segment.to_dict() for segment in self.segments],
            "latency_ms": self.latency_ms,
            "timestamp": self.timestamp,
        }
        return payload


@dataclass
class VADEvent:
    """Voice-activity event (speech start/end) pushed to the client."""

    type: MessageType = MessageType.VAD_EVENT          # always VAD_EVENT on the wire
    event: VADEventType = VADEventType.SPEECH_START    # which boundary was detected
    timestamp: float = field(default_factory=time.time)  # wall-clock detection time
    audio_timestamp_sec: float = 0.0                   # position within the audio stream

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums become wire strings)."""
        payload = {
            "type": self.type.value,
            "event": self.event.value,
            "timestamp": self.timestamp,
            "audio_timestamp_sec": self.audio_timestamp_sec,
        }
        return payload


@dataclass
class StatusMessage:
    """Informational status update pushed to the client."""

    type: MessageType = MessageType.STATUS              # always STATUS on the wire
    status: str = ""                                    # short machine-readable status code
    message: str = ""                                   # human-readable detail
    timestamp: float = field(default_factory=time.time)  # wall-clock send time

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum becomes its wire string)."""
        payload = {
            "type": self.type.value,
            "status": self.status,
            "message": self.message,
            "timestamp": self.timestamp,
        }
        return payload


@dataclass
class ErrorMessage:
    """Error report pushed to the client."""

    type: MessageType = MessageType.ERROR               # always ERROR on the wire
    error: str = ""                                     # human-readable error description
    code: str = ""                                      # machine-readable error code
    timestamp: float = field(default_factory=time.time)  # wall-clock send time

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum becomes its wire string)."""
        payload = {
            "type": self.type.value,
            "error": self.error,
            "code": self.code,
            "timestamp": self.timestamp,
        }
        return payload


@dataclass
class SpeechSegment:
    """A span of detected speech produced by the VAD."""

    start_sample: int        # first sample index of the span
    end_sample: int          # last/end sample index — inclusive vs exclusive not shown here; confirm at call site
    start_sec: float         # span start in seconds
    end_sec: float           # span end in seconds
    confidence: float = 1.0  # VAD confidence for the span