"""
ASR Worker for real-time transcription.

Wraps the existing VibeVoiceASRInference for async/streaming operation.
"""

import sys
import os
import asyncio
import inspect
import threading
import time
import queue
from typing import AsyncGenerator, Optional, List, Callable
from dataclasses import dataclass

import numpy as np
import torch

# Add parent directory and demo directory to path for importing existing code
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _parent_dir)
sys.path.insert(0, os.path.join(_parent_dir, "demo"))

from .models import (
    TranscriptionResult,
    TranscriptionSegment,
    MessageType,
    SessionConfig,
)

# Default passed to next() when the token streamer is polled on a worker
# thread: StopIteration cannot be propagated through an asyncio Future, so
# exhaustion is signalled by returning this sentinel instead.
_STREAM_END = object()


@dataclass
class InferenceRequest:
    """Request for ASR inference."""

    audio: np.ndarray
    sample_rate: int
    context_info: Optional[str]
    request_time: float
    segment_start_sec: float
    segment_end_sec: float


class ASRWorker:
    """
    ASR Worker that wraps VibeVoiceASRInference for real-time use.

    Features:
    - Async interface for WebSocket integration
    - Streaming output via TextIteratorStreamer
    - Request queuing for handling concurrent segments
    - Graceful model loading and error handling

    All blocking work (model loading, the model call itself, and waiting for
    the next streamed chunk of text) is dispatched to the event loop's
    default executor, so the asyncio event loop is never blocked while a
    segment is being transcribed.
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-ASR",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the ASR worker.

        The model is not loaded here; call load_model() explicitly, or let
        transcribe_segment() load it lazily on first use.

        Args:
            model_path: Path to VibeVoice ASR model
            device: Device to run inference on
            dtype: Model data type
            attn_implementation: Attention implementation
        """
        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.attn_implementation = attn_implementation

        self._inference = None
        self._is_loaded = False
        self._load_lock = threading.Lock()

        # Inference queue for serializing requests
        self._inference_semaphore = asyncio.Semaphore(1)

    def load_model(self) -> bool:
        """
        Load the ASR model (blocking; safe to call from several threads).

        Returns:
            True if model loaded successfully (or was already loaded),
            False if loading failed (the error is printed).
        """
        with self._load_lock:
            if self._is_loaded:
                return True

            try:
                # Import here to avoid circular imports and allow lazy loading
                # In Docker, the file is copied as vibevoice_asr_gradio_demo.py
                try:
                    from vibevoice_asr_gradio_demo import VibeVoiceASRInference
                except ImportError:
                    from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference

                print(f"Loading VibeVoice ASR model from {self.model_path}...")
                self._inference = VibeVoiceASRInference(
                    model_path=self.model_path,
                    device=self.device,
                    dtype=self.dtype,
                    attn_implementation=self.attn_implementation,
                )
                self._is_loaded = True
                print("ASR model loaded successfully")
                return True

            except Exception as e:
                print(f"Failed to load ASR model: {e}")
                import traceback
                traceback.print_exc()
                return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    async def transcribe_segment(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
        on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
    ) -> TranscriptionResult:
        """
        Transcribe an audio segment asynchronously.

        Args:
            audio: Audio data as float32 array
            sample_rate: Audio sample rate
            context_info: Optional context for transcription
            segment_start_sec: Start time of segment in session
            segment_end_sec: End time of segment in session
            config: Session configuration (defaults to SessionConfig())
            on_partial: Callback for partial results; may be a plain function
                or return an awaitable (e.g. an ``async def``), which is awaited

        Returns:
            Final transcription result, or an ERROR result if the model could
            not be loaded or inference failed.
        """
        if not self._is_loaded:
            # Loading can take a long time; keep it off the event loop so
            # other sessions continue to be served meanwhile.
            loop = asyncio.get_running_loop()
            loaded = await loop.run_in_executor(None, self.load_model)
            if not loaded:
                return TranscriptionResult(
                    type=MessageType.ERROR,
                    text="",
                    is_final=True,
                    latency_ms=0,
                )

        config = config or SessionConfig()
        request_time = time.time()

        # Serialize inference requests
        async with self._inference_semaphore:
            return await self._run_inference(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                request_time=request_time,
                on_partial=on_partial,
            )

    async def _run_inference(
        self,
        audio: np.ndarray,
        sample_rate: int,
        context_info: Optional[str],
        segment_start_sec: float,
        segment_end_sec: float,
        config: SessionConfig,
        request_time: float,
        on_partial: Optional[Callable[[TranscriptionResult], None]],
    ) -> TranscriptionResult:
        """
        Run the actual inference in a thread pool, forwarding partial text.

        The model call runs in the loop's default executor. If partial results
        are enabled, streamed text is forwarded to ``on_partial`` while the
        model runs; each wait for the next chunk also happens in the executor,
        so the event loop is never blocked.
        """
        from transformers import TextIteratorStreamer

        loop = asyncio.get_running_loop()

        # Create streamer for partial results
        streamer = None
        if config.return_partial_results and on_partial:
            streamer = TextIteratorStreamer(
                self._inference.processor.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )

        inference_future = loop.run_in_executor(
            None,
            self._transcribe_blocking,
            audio,
            sample_rate,
            context_info,
            config,
            streamer,
        )

        # Stream partial results if enabled
        if streamer is not None:
            try:
                await self._stream_partials(streamer, on_partial, request_time)
            except Exception as e:
                # Best effort: a failing callback must not abort the final result.
                print(f"Error during streaming: {e}")

        # Wait for completion without blocking the event loop
        try:
            result = await inference_future
        except Exception as e:
            import traceback
            traceback.print_exc()
            return TranscriptionResult(
                type=MessageType.ERROR,
                text=f"Error: {e}",
                is_final=True,
                latency_ms=(time.time() - request_time) * 1000,
            )

        latency_ms = (time.time() - request_time) * 1000
        return self._build_final_result(
            result or {}, segment_start_sec, segment_end_sec, latency_ms
        )

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        context_info: Optional[str],
        config: SessionConfig,
        streamer,
    ):
        """
        Write ``audio`` to a temporary WAV file and run the blocking model call.

        Intended to run on a worker thread. Returns whatever
        ``self._inference.transcribe`` returns. The temporary file is always
        removed. The streamer (if any) is always ended, so that a consumer
        iterating it cannot hang when transcription fails early.
        """
        try:
            # Imported inside the try so an ImportError still ends the streamer.
            import tempfile
            import soundfile as sf

            # Save audio to temp file (required by current implementation)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                temp_path = f.name

            try:
                # Scale float samples to 16-bit PCM; out-of-range values are clipped.
                audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
                sf.write(temp_path, audio_int16, sample_rate, subtype="PCM_16")

                return self._inference.transcribe(
                    audio_path=temp_path,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    context_info=context_info,
                    streamer=streamer,
                )
            finally:
                # Clean up temp file (also when writing it failed)
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
        finally:
            if streamer is not None:
                # end() flushes any cached text and enqueues the stop signal.
                # If generation already ended the stream, the extra signal is
                # queued behind the first one and never read.
                try:
                    streamer.end()
                except Exception:
                    import traceback
                    traceback.print_exc()

    async def _stream_partials(
        self,
        streamer,
        on_partial: Callable[[TranscriptionResult], None],
        request_time: float,
    ) -> None:
        """Forward cumulative streamed text to ``on_partial`` until the stream ends."""
        loop = asyncio.get_running_loop()
        token_iter = iter(streamer)
        partial_text = ""

        while True:
            # next() blocks on the streamer's internal queue; run it off-loop.
            new_text = await loop.run_in_executor(None, next, token_iter, _STREAM_END)
            if new_text is _STREAM_END:
                return

            partial_text += new_text
            partial_result = TranscriptionResult(
                type=MessageType.PARTIAL_RESULT,
                text=partial_text,
                is_final=False,
                latency_ms=(time.time() - request_time) * 1000,
            )

            # Call callback; await it if it produced an awaitable (async callback)
            outcome = on_partial(partial_result)
            if inspect.isawaitable(outcome):
                await outcome

    @staticmethod
    def _build_final_result(
        result: dict,
        segment_start_sec: float,
        segment_end_sec: float,
        latency_ms: float,
    ) -> TranscriptionResult:
        """Convert the raw inference dict into a FINAL_RESULT with session-absolute times."""
        segments = []
        for seg in result.get("segments") or []:
            seg_start = seg.get("start_time", 0)
            seg_end = seg.get("end_time", 0)

            # Numeric timestamps are relative to the segment: shift them to
            # session time. Otherwise fall back to the whole segment's span.
            if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
                adjusted_start = segment_start_sec + seg_start
                adjusted_end = segment_start_sec + seg_end
            else:
                adjusted_start = segment_start_sec
                adjusted_end = segment_end_sec

            segments.append(TranscriptionSegment(
                start_time=adjusted_start,
                end_time=adjusted_end,
                speaker_id=seg.get("speaker_id", "SPEAKER_00"),
                text=seg.get("text", ""),
            ))

        return TranscriptionResult(
            type=MessageType.FINAL_RESULT,
            text=result.get("raw_text", ""),
            is_final=True,
            segments=segments,
            latency_ms=latency_ms,
        )

    async def transcribe_stream(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        context_info: Optional[str] = None,
        segment_start_sec: float = 0.0,
        segment_end_sec: float = 0.0,
        config: Optional[SessionConfig] = None,
    ) -> AsyncGenerator[TranscriptionResult, None]:
        """
        Transcribe an audio segment with streaming output.

        Yields partial results followed by final result.

        Args:
            audio: Audio data
            sample_rate: Sample rate
            context_info: Optional context
            segment_start_sec: Segment start time
            segment_end_sec: Segment end time
            config: Session config

        Yields:
            TranscriptionResult objects (partial and final)
        """
        result_queue: asyncio.Queue = asyncio.Queue()

        async def on_partial(result: TranscriptionResult):
            await result_queue.put(result)

        # Start transcription task
        transcribe_task = asyncio.create_task(
            self.transcribe_segment(
                audio=audio,
                sample_rate=sample_rate,
                context_info=context_info,
                segment_start_sec=segment_start_sec,
                segment_end_sec=segment_end_sec,
                config=config,
                on_partial=on_partial,
            )
        )

        # Yield partial results as they come
        while not transcribe_task.done():
            try:
                result = await asyncio.wait_for(result_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue

        # Drain any remaining partial results
        while not result_queue.empty():
            yield await result_queue.get()

        # Yield final result
        final_result = await transcribe_task
        yield final_result