All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
359 lines
12 KiB
Python
359 lines
12 KiB
Python
"""
|
|
ASR Worker for real-time transcription.
|
|
|
|
Wraps the existing VibeVoiceASRInference for async/streaming operation.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
import queue
|
|
from typing import AsyncGenerator, Optional, List, Callable
|
|
from dataclasses import dataclass
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Add parent directory and demo directory to path for importing existing code
|
|
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, _parent_dir)
|
|
sys.path.insert(0, os.path.join(_parent_dir, "demo"))
|
|
|
|
from .models import (
|
|
TranscriptionResult,
|
|
TranscriptionSegment,
|
|
MessageType,
|
|
SessionConfig,
|
|
)
|
|
|
|
|
|
@dataclass
class InferenceRequest:
    """Request for ASR inference."""

    # Audio samples for the segment; presumably mono float32 — the worker's
    # temp-file writer scales by 32768 before int16 conversion. TODO confirm.
    audio: np.ndarray
    # Sample rate of `audio` in Hz.
    sample_rate: int
    # Optional free-text context forwarded to the transcription call.
    context_info: Optional[str]
    # Wall-clock timestamp when the request was created (time.time()),
    # used for latency accounting.
    request_time: float
    # Segment boundaries in seconds, relative to the session start.
    segment_start_sec: float
    segment_end_sec: float
|
|
|
|
|
|
class ASRWorker:
|
|
"""
|
|
ASR Worker that wraps VibeVoiceASRInference for real-time use.
|
|
|
|
Features:
|
|
- Async interface for WebSocket integration
|
|
- Streaming output via TextIteratorStreamer
|
|
- Request queuing for handling concurrent segments
|
|
- Graceful model loading and error handling
|
|
"""
|
|
|
|
def __init__(
    self,
    model_path: str = "microsoft/VibeVoice-ASR",
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    attn_implementation: str = "flash_attention_2",
):
    """
    Initialize the ASR worker.

    The model is NOT loaded here; loading is deferred to load_model()
    (called lazily on the first transcription request).

    Args:
        model_path: Path to VibeVoice ASR model
        device: Device to run inference on
        dtype: Model data type
        attn_implementation: Attention implementation
    """
    self.model_path = model_path
    self.device = device
    self.dtype = dtype
    self.attn_implementation = attn_implementation

    # Backend instance created by load_model(); None until then.
    self._inference = None
    self._is_loaded = False
    # Serializes load_model() so concurrent callers can't load twice.
    self._load_lock = threading.Lock()

    # Inference queue for serializing requests: at most one segment
    # runs inference at a time (single model / single GPU).
    self._inference_semaphore = asyncio.Semaphore(1)
|
|
|
|
def load_model(self) -> bool:
    """
    Load the ASR model.

    Thread-safe and idempotent: a lock serializes loading, and a model
    that is already loaded short-circuits to True.

    Returns:
        True if model loaded successfully
    """

    def _resolve_inference_class():
        # Import here to avoid circular imports and allow lazy loading.
        # In Docker, the file is copied as vibevoice_asr_gradio_demo.py.
        try:
            from vibevoice_asr_gradio_demo import VibeVoiceASRInference
        except ImportError:
            from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference
        return VibeVoiceASRInference

    with self._load_lock:
        # Another caller may have finished loading while we waited on the lock.
        if self._is_loaded:
            return True

        try:
            inference_cls = _resolve_inference_class()
            print(f"Loading VibeVoice ASR model from {self.model_path}...")
            self._inference = inference_cls(
                model_path=self.model_path,
                device=self.device,
                dtype=self.dtype,
                attn_implementation=self.attn_implementation,
            )
        except Exception as e:
            print(f"Failed to load ASR model: {e}")
            import traceback
            traceback.print_exc()
            return False

        self._is_loaded = True
        print("ASR model loaded successfully")
        return True
|
|
|
|
@property
def is_loaded(self) -> bool:
    """Check if model is loaded (i.e. load_model() has succeeded)."""
    return self._is_loaded
|
|
|
|
async def transcribe_segment(
    self,
    audio: np.ndarray,
    sample_rate: int = 16000,
    context_info: Optional[str] = None,
    segment_start_sec: float = 0.0,
    segment_end_sec: float = 0.0,
    config: Optional[SessionConfig] = None,
    on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
) -> TranscriptionResult:
    """
    Transcribe an audio segment asynchronously.

    Loads the model lazily on first use; a failure to load is reported
    as an ERROR-typed result rather than an exception.

    Args:
        audio: Audio data as float32 array
        sample_rate: Audio sample rate
        context_info: Optional context for transcription
        segment_start_sec: Start time of segment in session
        segment_end_sec: End time of segment in session
        config: Session configuration
        on_partial: Callback for partial results

    Returns:
        Final transcription result
    """
    if not self._is_loaded:
        # Model loading is slow and fully synchronous; run it in the
        # default executor so the event loop is not blocked for the
        # duration of the (multi-second) load.
        loaded = await asyncio.get_running_loop().run_in_executor(
            None, self.load_model
        )
        if not loaded:
            return TranscriptionResult(
                type=MessageType.ERROR,
                text="Failed to load ASR model",
                is_final=True,
                latency_ms=0,
            )

    config = config or SessionConfig()
    # Latency is measured from here, so it includes any queueing delay
    # spent waiting on the semaphore below.
    request_time = time.time()

    # Serialize inference requests: only one segment on the model at a time.
    async with self._inference_semaphore:
        return await self._run_inference(
            audio=audio,
            sample_rate=sample_rate,
            context_info=context_info,
            segment_start_sec=segment_start_sec,
            segment_end_sec=segment_end_sec,
            config=config,
            request_time=request_time,
            on_partial=on_partial,
        )
|
|
|
|
async def _run_inference(
    self,
    audio: np.ndarray,
    sample_rate: int,
    context_info: Optional[str],
    segment_start_sec: float,
    segment_end_sec: float,
    config: SessionConfig,
    request_time: float,
    on_partial: Optional[Callable[[TranscriptionResult], None]],
) -> TranscriptionResult:
    """Run the actual inference in a thread pool.

    The blocking model call runs on a dedicated thread while this
    coroutine consumes the token streamer (when partial results are
    enabled) and forwards partials via ``on_partial``.

    Args:
        audio: Float audio samples (scaled to int16 PCM for the temp file).
        sample_rate: Sample rate of ``audio`` in Hz.
        context_info: Optional context forwarded to the model.
        segment_start_sec: Absolute session start time of this segment.
        segment_end_sec: Absolute session end time of this segment.
        config: Decoding parameters (max tokens, temperature, streaming flag).
        request_time: time.time() at request creation; used for latency.
        on_partial: Sync or async callback for partial results.

    Returns:
        FINAL_RESULT on success, ERROR result if inference raised.
    """
    from transformers import TextIteratorStreamer

    # Create streamer for partial results only when both the config asks
    # for them and a callback exists to receive them.
    streamer = None
    if config.return_partial_results and on_partial:
        streamer = TextIteratorStreamer(
            self._inference.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

    # Result container shared with the worker thread; single-item dict
    # writes are atomic under the GIL, and we only read after join().
    result_container = {"result": None, "error": None}

    def run_inference():
        try:
            # Save audio to temp file (required by current implementation)
            import tempfile
            import soundfile as sf

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                temp_path = f.name

            # Write audio as 16-bit PCM; clip before the cast so values at
            # full scale don't wrap around on int16 overflow.
            audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
            sf.write(temp_path, audio_int16, sample_rate, subtype='PCM_16')

            try:
                result = self._inference.transcribe(
                    audio_path=temp_path,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    context_info=context_info,
                    streamer=streamer,
                )
                result_container["result"] = result
            finally:
                # Clean up temp file; catch only OS errors so real bugs
                # (and KeyboardInterrupt) are not silently swallowed.
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

        except Exception as e:
            result_container["error"] = str(e)
            import traceback
            traceback.print_exc()

    # Start inference in background thread
    inference_thread = threading.Thread(target=run_inference)
    inference_thread.start()

    # Stream partial results if enabled.
    # NOTE(review): iterating the streamer blocks this coroutine between
    # tokens (TextIteratorStreamer is a blocking iterator) — acceptable
    # here only because partial delivery is this coroutine's sole job;
    # confirm if other tasks share the loop under load.
    partial_text = ""
    if streamer and on_partial:
        try:
            for new_text in streamer:
                partial_text += new_text
                partial_result = TranscriptionResult(
                    type=MessageType.PARTIAL_RESULT,
                    text=partial_text,
                    is_final=False,
                    latency_ms=(time.time() - request_time) * 1000,
                )
                # Call callback (may be async)
                if asyncio.iscoroutinefunction(on_partial):
                    await on_partial(partial_result)
                else:
                    on_partial(partial_result)
        except Exception as e:
            print(f"Error during streaming: {e}")

    # Wait for completion. join() is blocking, so hand it to the default
    # executor — otherwise, with streaming disabled, the event loop would
    # stall for the entire inference.
    await asyncio.get_running_loop().run_in_executor(
        None, inference_thread.join
    )

    latency_ms = (time.time() - request_time) * 1000

    if result_container["error"]:
        return TranscriptionResult(
            type=MessageType.ERROR,
            text=f"Error: {result_container['error']}",
            is_final=True,
            latency_ms=latency_ms,
        )

    result = result_container["result"]

    # Convert segments to our format
    segments = []
    for seg in result.get("segments", []):
        # Adjust timestamps relative to session
        seg_start = seg.get("start_time", 0)
        seg_end = seg.get("end_time", 0)

        # If segment has relative timestamps, adjust to absolute
        if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
            adjusted_start = segment_start_sec + seg_start
            adjusted_end = segment_start_sec + seg_end
        else:
            # Non-numeric timestamps: fall back to the segment's own bounds.
            adjusted_start = segment_start_sec
            adjusted_end = segment_end_sec

        segments.append(TranscriptionSegment(
            start_time=adjusted_start,
            end_time=adjusted_end,
            speaker_id=seg.get("speaker_id", "SPEAKER_00"),
            text=seg.get("text", ""),
        ))

    return TranscriptionResult(
        type=MessageType.FINAL_RESULT,
        text=result.get("raw_text", ""),
        is_final=True,
        segments=segments,
        latency_ms=latency_ms,
    )
|
|
|
|
async def transcribe_stream(
    self,
    audio: np.ndarray,
    sample_rate: int = 16000,
    context_info: Optional[str] = None,
    segment_start_sec: float = 0.0,
    segment_end_sec: float = 0.0,
    config: Optional[SessionConfig] = None,
) -> AsyncGenerator[TranscriptionResult, None]:
    """
    Transcribe an audio segment with streaming output.

    Thin async-generator facade over transcribe_segment(): partial
    results are funneled through a queue and yielded as they arrive,
    followed by the final result.

    Args:
        audio: Audio data
        sample_rate: Sample rate
        context_info: Optional context
        segment_start_sec: Segment start time
        segment_end_sec: Segment end time
        config: Session config

    Yields:
        TranscriptionResult objects (partial and final)
    """
    pending: asyncio.Queue = asyncio.Queue()

    async def enqueue_partial(result: TranscriptionResult):
        await pending.put(result)

    # Kick off the transcription; partials flow into `pending`.
    task = asyncio.create_task(
        self.transcribe_segment(
            audio=audio,
            sample_rate=sample_rate,
            context_info=context_info,
            segment_start_sec=segment_start_sec,
            segment_end_sec=segment_end_sec,
            config=config,
            on_partial=enqueue_partial,
        )
    )

    # Forward partials while inference runs; the short timeout lets us
    # notice task completion promptly.
    while not task.done():
        try:
            yield await asyncio.wait_for(pending.get(), timeout=0.1)
        except asyncio.TimeoutError:
            pass

    # Flush any partials that landed after the last poll.
    while not pending.empty():
        yield pending.get_nowait()

    # The task has finished; awaiting it returns the final result
    # (or re-raises whatever the transcription raised).
    yield await task
|