koide 1fb76254e9
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
Add: VibeVoice ASR セットアップスクリプト一式
2026-02-24 01:21:33 +00:00

359 lines
12 KiB
Python

"""
ASR Worker for real-time transcription.
Wraps the existing VibeVoiceASRInference for async/streaming operation.
"""
import sys
import os
import asyncio
import threading
import time
import queue
from typing import AsyncGenerator, Optional, List, Callable
from dataclasses import dataclass
import numpy as np
import torch
# Add parent directory and demo directory to path for importing existing code
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _parent_dir)
sys.path.insert(0, os.path.join(_parent_dir, "demo"))
from .models import (
TranscriptionResult,
TranscriptionSegment,
MessageType,
SessionConfig,
)
@dataclass
class InferenceRequest:
    """A single queued ASR inference request.

    Bundles the raw audio with the timing metadata needed to place the
    resulting transcript segments on the session timeline.
    """

    # Mono audio samples (float32 expected by the downstream writer).
    audio: np.ndarray
    # Sample rate of `audio` in Hz.
    sample_rate: int
    # Optional free-form context string passed through to the model prompt.
    context_info: Optional[str]
    # Wall-clock time (time.time()) when the request was created.
    request_time: float
    # Segment start offset within the session, in seconds.
    segment_start_sec: float
    # Segment end offset within the session, in seconds.
    segment_end_sec: float
class ASRWorker:
"""
ASR Worker that wraps VibeVoiceASRInference for real-time use.
Features:
- Async interface for WebSocket integration
- Streaming output via TextIteratorStreamer
- Request queuing for handling concurrent segments
- Graceful model loading and error handling
"""
def __init__(
self,
model_path: str = "microsoft/VibeVoice-ASR",
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
attn_implementation: str = "flash_attention_2",
):
"""
Initialize the ASR worker.
Args:
model_path: Path to VibeVoice ASR model
device: Device to run inference on
dtype: Model data type
attn_implementation: Attention implementation
"""
self.model_path = model_path
self.device = device
self.dtype = dtype
self.attn_implementation = attn_implementation
self._inference = None
self._is_loaded = False
self._load_lock = threading.Lock()
# Inference queue for serializing requests
self._inference_semaphore = asyncio.Semaphore(1)
def load_model(self) -> bool:
"""
Load the ASR model.
Returns:
True if model loaded successfully
"""
with self._load_lock:
if self._is_loaded:
return True
try:
# Import here to avoid circular imports and allow lazy loading
# In Docker, the file is copied as vibevoice_asr_gradio_demo.py
try:
from vibevoice_asr_gradio_demo import VibeVoiceASRInference
except ImportError:
from vibevoice_asr_gradio_demo_patched import VibeVoiceASRInference
print(f"Loading VibeVoice ASR model from {self.model_path}...")
self._inference = VibeVoiceASRInference(
model_path=self.model_path,
device=self.device,
dtype=self.dtype,
attn_implementation=self.attn_implementation,
)
self._is_loaded = True
print("ASR model loaded successfully")
return True
except Exception as e:
print(f"Failed to load ASR model: {e}")
import traceback
traceback.print_exc()
return False
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._is_loaded
async def transcribe_segment(
self,
audio: np.ndarray,
sample_rate: int = 16000,
context_info: Optional[str] = None,
segment_start_sec: float = 0.0,
segment_end_sec: float = 0.0,
config: Optional[SessionConfig] = None,
on_partial: Optional[Callable[[TranscriptionResult], None]] = None,
) -> TranscriptionResult:
"""
Transcribe an audio segment asynchronously.
Args:
audio: Audio data as float32 array
sample_rate: Audio sample rate
context_info: Optional context for transcription
segment_start_sec: Start time of segment in session
segment_end_sec: End time of segment in session
config: Session configuration
on_partial: Callback for partial results
Returns:
Final transcription result
"""
if not self._is_loaded:
if not self.load_model():
return TranscriptionResult(
type=MessageType.ERROR,
text="",
is_final=True,
latency_ms=0,
)
config = config or SessionConfig()
request_time = time.time()
# Serialize inference requests
async with self._inference_semaphore:
return await self._run_inference(
audio=audio,
sample_rate=sample_rate,
context_info=context_info,
segment_start_sec=segment_start_sec,
segment_end_sec=segment_end_sec,
config=config,
request_time=request_time,
on_partial=on_partial,
)
async def _run_inference(
self,
audio: np.ndarray,
sample_rate: int,
context_info: Optional[str],
segment_start_sec: float,
segment_end_sec: float,
config: SessionConfig,
request_time: float,
on_partial: Optional[Callable[[TranscriptionResult], None]],
) -> TranscriptionResult:
"""Run the actual inference in a thread pool."""
from transformers import TextIteratorStreamer
# Create streamer for partial results
streamer = None
if config.return_partial_results and on_partial:
streamer = TextIteratorStreamer(
self._inference.processor.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
# Result container for thread
result_container = {"result": None, "error": None}
def run_inference():
try:
# Save audio to temp file (required by current implementation)
import tempfile
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
temp_path = f.name
# Write audio
audio_int16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
sf.write(temp_path, audio_int16, sample_rate, subtype='PCM_16')
try:
result = self._inference.transcribe(
audio_path=temp_path,
max_new_tokens=config.max_new_tokens,
temperature=config.temperature,
context_info=context_info,
streamer=streamer,
)
result_container["result"] = result
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except:
pass
except Exception as e:
result_container["error"] = str(e)
import traceback
traceback.print_exc()
# Start inference in background thread
inference_thread = threading.Thread(target=run_inference)
inference_thread.start()
# Stream partial results if enabled
partial_text = ""
if streamer and on_partial:
try:
for new_text in streamer:
partial_text += new_text
partial_result = TranscriptionResult(
type=MessageType.PARTIAL_RESULT,
text=partial_text,
is_final=False,
latency_ms=(time.time() - request_time) * 1000,
)
# Call callback (may be async)
if asyncio.iscoroutinefunction(on_partial):
await on_partial(partial_result)
else:
on_partial(partial_result)
except Exception as e:
print(f"Error during streaming: {e}")
# Wait for completion
inference_thread.join()
latency_ms = (time.time() - request_time) * 1000
if result_container["error"]:
return TranscriptionResult(
type=MessageType.ERROR,
text=f"Error: {result_container['error']}",
is_final=True,
latency_ms=latency_ms,
)
result = result_container["result"]
# Convert segments to our format
segments = []
for seg in result.get("segments", []):
# Adjust timestamps relative to session
seg_start = seg.get("start_time", 0)
seg_end = seg.get("end_time", 0)
# If segment has relative timestamps, adjust to absolute
if isinstance(seg_start, (int, float)) and isinstance(seg_end, (int, float)):
adjusted_start = segment_start_sec + seg_start
adjusted_end = segment_start_sec + seg_end
else:
adjusted_start = segment_start_sec
adjusted_end = segment_end_sec
segments.append(TranscriptionSegment(
start_time=adjusted_start,
end_time=adjusted_end,
speaker_id=seg.get("speaker_id", "SPEAKER_00"),
text=seg.get("text", ""),
))
return TranscriptionResult(
type=MessageType.FINAL_RESULT,
text=result.get("raw_text", ""),
is_final=True,
segments=segments,
latency_ms=latency_ms,
)
async def transcribe_stream(
self,
audio: np.ndarray,
sample_rate: int = 16000,
context_info: Optional[str] = None,
segment_start_sec: float = 0.0,
segment_end_sec: float = 0.0,
config: Optional[SessionConfig] = None,
) -> AsyncGenerator[TranscriptionResult, None]:
"""
Transcribe an audio segment with streaming output.
Yields partial results followed by final result.
Args:
audio: Audio data
sample_rate: Sample rate
context_info: Optional context
segment_start_sec: Segment start time
segment_end_sec: Segment end time
config: Session config
Yields:
TranscriptionResult objects (partial and final)
"""
result_queue: asyncio.Queue = asyncio.Queue()
async def on_partial(result: TranscriptionResult):
await result_queue.put(result)
# Start transcription task
transcribe_task = asyncio.create_task(
self.transcribe_segment(
audio=audio,
sample_rate=sample_rate,
context_info=context_info,
segment_start_sec=segment_start_sec,
segment_end_sec=segment_end_sec,
config=config,
on_partial=on_partial,
)
)
# Yield partial results as they come
while not transcribe_task.done():
try:
result = await asyncio.wait_for(result_queue.get(), timeout=0.1)
yield result
except asyncio.TimeoutError:
continue
# Drain any remaining partial results
while not result_queue.empty():
yield await result_queue.get()
# Yield final result
final_result = await transcribe_task
yield final_result