#!/usr/bin/env python
"""
VibeVoice ASR Gradio Demo
(Patched for MKV and video file support)
"""
import os
import sys
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
import argparse
import time
import json
import gradio as gr
from typing import List, Dict, Tuple, Optional, Generator
import tempfile
import base64
import io
import traceback
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import TextIteratorStreamer for streaming generation
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList

try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2

    # Only apply RoPE, RMSNorm, SwiGLU patches (these affect the underlying Qwen2 layers)
    apply_liger_kernel_to_qwen2(
        rope=True,
        rms_norm=True,
        swiglu=True,
        cross_entropy=False,
    )
    print("✅ Liger Kernel applied to Qwen2 components (RoPE, RMSNorm, SwiGLU)")
except Exception as e:
    print(f"⚠️ Failed to apply Liger Kernel: {e}, you can install it with: pip install liger-kernel")

# Try to import pydub for MP3 conversion
try:
    from pydub import AudioSegment
    HAS_PYDUB = True
except ImportError:
    HAS_PYDUB = False
    print("⚠️ Warning: pydub not available, falling back to WAV format")

from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS

# Add MKV and additional video format support
MKV_EXTS = ['.mkv', '.MKV', '.Mkv', '.avi', '.AVI']
for ext in MKV_EXTS:
    if ext not in COMMON_AUDIO_EXTS:
        COMMON_AUDIO_EXTS.append(ext)
print(f"✅ MKV/Video support enabled. Total supported formats: {len(COMMON_AUDIO_EXTS)}")


class VibeVoiceASRInference:
    """Simple inference wrapper for VibeVoice ASR model."""

    def __init__(self, model_path: str, device: str = "cuda",
                 dtype: torch.dtype = torch.bfloat16,
                 attn_implementation: str = "flash_attention_2"):
        """
        Initialize the ASR inference pipeline.

        Args:
            model_path: Path to the pretrained model (HuggingFace format directory or model name)
            device: Device to run inference on ("cuda", "cpu", or "auto" for device_map)
            dtype: Data type for model weights
            attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
        """
        print(f"Loading VibeVoice ASR model from {model_path}")

        # Load processor
        self.processor = VibeVoiceASRProcessor.from_pretrained(model_path)

        # Load model
        print(f"Using attention implementation: {attn_implementation}")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device if device == "auto" else None,
            attn_implementation=attn_implementation,
            trust_remote_code=True
        )
        if device != "auto":
            self.model = self.model.to(device)
        # With device_map="auto" the actual placement is decided by HF, so
        # read it back from the parameters.
        self.device = device if device != "auto" else next(self.model.parameters()).device
        self.model.eval()

        # Print model info
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"✅ Model loaded successfully on {self.device}")
        print(f"📊 Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")

    def transcribe(
        self,
        audio_path: str = None,
        audio_array: np.ndarray = None,
        sample_rate: int = None,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = False,
        num_beams: int = 1,
        repetition_penalty: float = 1.0,
        context_info: str = None,
        streamer: Optional[TextIteratorStreamer] = None,
    ) -> dict:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            audio_array: Audio array (if not loading from file)
            sample_rate: Sample rate of audio array
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (0 for greedy)
            top_p: Top-p for nucleus sampling (1.0 for no filtering)
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 for greedy)
            repetition_penalty: Repetition penalty (1.0 for no penalty)
            context_info: Optional context information (e.g., hotwords, speaker names,
                topics) to help transcription
            streamer: Optional TextIteratorStreamer for streaming output

        Returns:
            Dictionary with keys: raw_text, segments, generation_time, input_tokens.
        """
        # Process audio.
        # FIX: `audio_array` was accepted but previously ignored (only
        # `audio_path` was ever forwarded to the processor); fall back to the
        # in-memory array when no path is given.
        # NOTE(review): assumes the processor's `audio` argument accepts raw
        # arrays as well as paths — confirm against VibeVoiceASRProcessor.
        inputs = self.processor(
            audio=audio_path if audio_path is not None else audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
            add_generation_prompt=True,
            context_info=context_info
        )

        # Move all tensors to the model device
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # Generation settings; None entries are stripped below so HF defaults apply.
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature if temperature > 0 else None,
            "top_p": top_p if do_sample else None,
            "do_sample": do_sample,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Add streamer if provided
        if streamer is not None:
            generation_config["streamer"] = streamer

        # Stopping criteria for stop-button support.  StopOnFlag is defined at
        # module level further down in this file; it polls a global stop flag
        # on every decoding step.
        generation_config["stopping_criteria"] = StoppingCriteriaList([StopOnFlag()])

        # Remove None values so generate() falls back to its own defaults
        generation_config = {k: v for k, v in generation_config.items() if v is not None}

        start_time = time.time()

        # Input token statistics, computed before generation.
        input_ids = inputs['input_ids'][0]  # shape: [seq_len]
        total_input_tokens = input_ids.shape[0]

        # Count padding tokens (tokens equal to pad_id)
        pad_id = self.processor.pad_id
        num_padding_tokens = (input_ids == pad_id).sum().item()

        # Count speech tokens: every token between speech_start_id and
        # speech_end_id, including the two delimiters themselves.
        speech_start_id = self.processor.speech_start_id
        speech_end_id = self.processor.speech_end_id
        num_speech_tokens = 0
        in_speech = False
        for token_id in input_ids.tolist():
            if token_id == speech_start_id:
                in_speech = True
                num_speech_tokens += 1  # Count speech_start token
            elif token_id == speech_end_id:
                in_speech = False
                num_speech_tokens += 1  # Count speech_end token
            elif in_speech:
                num_speech_tokens += 1

        # Text tokens = total - speech - padding
        num_text_tokens = total_input_tokens - num_speech_tokens - num_padding_tokens

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )

        generation_time = time.time() - start_time

        # Decode only the newly generated tokens (skip the prompt)
        generated_ids = output_ids[0, inputs['input_ids'].shape[1]:]
        generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        # Parse structured output (best effort; fall back to empty segments)
        try:
            transcription_segments = self.processor.post_process_transcription(generated_text)
        except Exception as e:
            print(f"Warning: Failed to parse structured output: {e}")
            transcription_segments = []

        return {
            "raw_text": generated_text,
            "segments": transcription_segments,
            "generation_time": generation_time,
            "input_tokens": {
                "total": total_input_tokens,
                "speech": num_speech_tokens,
                "text": num_text_tokens,
                "padding": num_padding_tokens,
            },
        }


def _float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert float audio in [-1.0, 1.0] to int16 samples.

    FIX: the previous code multiplied by 32768.0 without clipping, which
    overflows int16 for samples at exactly +1.0 (32768 wraps to -32768,
    producing a full-scale click). Clip first and scale by int16 max (32767).
    """
    return (np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)


def clip_and_encode_audio(
    audio_data: np.ndarray,
    sr: int,
    start_time: float,
    end_time: float,
    segment_idx: int,
    use_mp3: bool = True,
    target_sr: int = 16000,   # Downsample to 16kHz for smaller size
    mp3_bitrate: str = "32k"  # Use low bitrate for minimal transfer
) -> Tuple[int, Optional[str], Optional[str]]:
    """
    Clip audio segment and encode to base64.

    Args:
        audio_data: Full audio array
        sr: Sample rate
        start_time: Start time in seconds
        end_time: End time in seconds
        segment_idx: Segment index for identification
        use_mp3: Whether to use MP3 format (smaller size)
        target_sr: Target sample rate for downsampling (lower = smaller)
        mp3_bitrate: MP3 bitrate (lower = smaller, e.g., "24k", "32k", "48k")

    Returns:
        Tuple of (segment_idx, base64_string, error_message)
    """
    try:
        # Convert time to sample indices, clamped to the valid range
        start_sample = max(0, int(start_time * sr))
        end_sample = min(len(audio_data), int(end_time * sr))

        if start_sample >= end_sample:
            return segment_idx, None, f"Invalid time range: [{start_time:.2f}s - {end_time:.2f}s]"

        # Extract segment
        segment_data = audio_data[start_sample:end_sample]

        # Downsample if needed (reduces data size significantly).
        # Simple linear-interpolation resampling — adequate for preview clips.
        if sr != target_sr and target_sr < sr:
            duration = len(segment_data) / sr
            new_length = int(duration * target_sr)
            indices = np.linspace(0, len(segment_data) - 1, new_length)
            segment_data = np.interp(indices, np.arange(len(segment_data)), segment_data)
            sr = target_sr

        # Convert float audio to int16 for encoding (clipped; see helper)
        segment_data_int16 = _float_to_int16(segment_data)

        # Convert to MP3 if pydub is available and use_mp3 is True
        if use_mp3 and HAS_PYDUB:
            try:
                # Write to WAV in memory
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
                wav_buffer.seek(0)

                # Convert to MP3 with low bitrate
                audio_segment = AudioSegment.from_wav(wav_buffer)
                # Convert to mono if stereo (halves the size)
                if audio_segment.channels > 1:
                    audio_segment = audio_segment.set_channels(1)
                mp3_buffer = io.BytesIO()
                audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate)
                mp3_buffer.seek(0)

                # Encode to base64 data URI
                audio_base64 = base64.b64encode(mp3_buffer.read()).decode('utf-8')
                return segment_idx, f"data:audio/mp3;base64,{audio_base64}", None
            except Exception as e:
                # Fall back to WAV on error
                print(f"MP3 conversion failed for segment {segment_idx}, using WAV: {e}")

        # Fall back to WAV format (no temp file, use in-memory buffer)
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
        wav_buffer.seek(0)
        audio_base64 = base64.b64encode(wav_buffer.read()).decode('utf-8')
        return segment_idx, f"data:audio/wav;base64,{audio_base64}", None

    except Exception as e:
        error_msg = f"Error clipping segment {segment_idx}: {str(e)}"
        print(error_msg)
        return segment_idx, None, error_msg


def extract_audio_segments(audio_path: str, segments: List[Dict]) -> List[Tuple[str, str, Optional[str]]]:
    """
    Extract multiple segments from audio file efficiently with parallel processing.

    Args:
        audio_path: Path to original audio file
        segments: List of segment dictionaries with start_time, end_time, etc.

    Returns:
        List of tuples (segment_label, audio_base64_src, error_msg), sorted by
        segment index. Segments without valid timestamps are skipped.
    """
    try:
        # Read audio file once using ffmpeg for better format support
        print(f"📂 Loading audio file: {audio_path}")
        audio_data, sr = load_audio_use_ffmpeg(audio_path, resample=False)
        print(f"✅ Audio loaded: {len(audio_data)} samples, {sr} Hz")

        # Build the task list.
        # FIX: the previous code appended a placeholder tuple (i, None, ...)
        # for invalid segments and later tried to skip it with
        # `if task[0] is None` — but task[0] was the integer index, so the
        # check never fired and garbage arguments reached the worker.
        # Only valid tasks are queued now; invalid segments are simply skipped
        # (the original intent).
        use_mp3 = HAS_PYDUB  # Use MP3 if available
        tasks = []
        for i, seg in enumerate(segments):
            start_time = seg.get('start_time')
            end_time = seg.get('end_time')
            if (not isinstance(start_time, (int, float)) or
                not isinstance(end_time, (int, float)) or
                start_time >= end_time):
                continue
            tasks.append((audio_data, sr, start_time, end_time, i, use_mp3))

        # Process in parallel using ThreadPoolExecutor
        results = []
        completed_count = 0
        max_workers = os.cpu_count() or 4
        print(f"🚀 Starting parallel processing with {max_workers} threads...")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to its segment index for error reporting
            futures = {executor.submit(clip_and_encode_audio, *task): task[4]
                       for task in tasks}

            for future in as_completed(futures):
                try:
                    results.append(future.result())
                    completed_count += 1
                    # Log progress every 100 segments or at completion
                    if completed_count % 100 == 0 or completed_count == len(futures):
                        print(f"Progress: {completed_count}/{len(futures)} segments processed "
                              f"({completed_count*100//len(futures)}%)")
                except Exception as e:
                    idx = futures[future]
                    results.append((idx, None, f"Processing error: {str(e)}"))
                    completed_count += 1
                    print(f"Error on segment {idx}: {e}")

        print(f"✅ Completed processing all {len(futures)} valid segments")

        # Sort by segment index to maintain order
        results.sort(key=lambda x: x[0])

        # Build output list with labels. Times are formatted defensively so a
        # missing/odd value cannot crash the label formatting.
        audio_segments = []
        for idx, audio_src, error_msg in results:
            seg = segments[idx] if idx < len(segments) else {}
            start_time = seg.get('start_time', 'N/A')
            end_time = seg.get('end_time', 'N/A')
            speaker_id = seg.get('speaker_id', 'N/A')
            start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time)
            end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time)
            segment_label = f"Segment {idx+1}: [{start_str}s - {end_str}s] Speaker {speaker_id}"
            audio_segments.append((segment_label, audio_src, error_msg))

        return audio_segments

    except Exception as e:
        print(f"Error loading audio file: {e}")
        return []


# Global variable to store the ASR model
asr_model = None

# Global stop flag for generation (set by the UI's stop button, polled by
# StopOnFlag on every decoding step)
stop_generation_flag = False


class StopOnFlag(StoppingCriteria):
    """Custom stopping criteria that checks a global flag."""

    def __call__(self, input_ids, scores, **kwargs):
        global stop_generation_flag
        return stop_generation_flag


def parse_time_to_seconds(val: Optional[str]) -> Optional[float]:
    """Parse seconds or hh:mm:ss (or mm:ss) to float seconds.

    Returns None for empty/None input or for strings that match neither a
    plain number nor a colon-separated time.
    """
    if val is None:
        return None
    val = val.strip()
    if not val:
        return None
    # Plain number of seconds, e.g. "90" or "12.5"
    try:
        return float(val)
    except ValueError:
        pass
    if ":" in val:
        parts = val.split(":")
        # Each part must be a non-negative number (at most one decimal point)
        if not all(p.strip().replace(".", "", 1).isdigit() for p in parts):
            return None
        parts = [float(p) for p in parts]
        if len(parts) == 3:
            h, m, s = parts
        elif len(parts) == 2:
            h = 0
            m, s = parts
        else:
            return None
        return h * 3600 + m * 60 + s
    return None


def slice_audio_to_temp(
    audio_data: np.ndarray,
    sample_rate: int,
    start_sec: Optional[float],
    end_sec: Optional[float]
) -> Tuple[Optional[str], Optional[str]]:
    """Slice audio_data to [start_sec, end_sec) and write to a temp WAV file.

    Args:
        audio_data: Full audio array (float samples in [-1, 1])
        sample_rate: Sample rate of audio_data
        start_sec: Slice start in seconds (None = from the beginning)
        end_sec: Slice end in seconds (None = to the end)

    Returns:
        Tuple of (temp_file_path, error_message); exactly one is None.
        The caller is responsible for deleting the temp file.
    """
    n_samples = len(audio_data)
    full_duration = n_samples / float(sample_rate)
    start = 0.0 if start_sec is None else max(0.0, start_sec)
    end = full_duration if end_sec is None else min(full_duration, end_sec)
    if end <= start:
        return None, f"Invalid time range: start={start:.2f}s, end={end:.2f}s"

    segment = audio_data[int(start * sample_rate):int(end * sample_rate)]

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    # FIX: clip-and-scale conversion (was `* 32768.0`, which overflows int16
    # for full-scale samples)
    sf.write(temp_file.name, _float_to_int16(segment), sample_rate, subtype='PCM_16')
    return temp_file.name, None


def initialize_model(model_path: str, device: str = "cuda",
                     attn_implementation: str = "flash_attention_2"):
    """Initialize the global ASR model.

    Returns a human-readable status string for the UI (success or error).
    """
    global asr_model
    try:
        # bfloat16 on accelerators; CPU inference uses float32
        dtype = torch.bfloat16 if device != "cpu" else torch.float32
        asr_model = VibeVoiceASRInference(
            model_path=model_path,
            device=device,
            dtype=dtype,
            attn_implementation=attn_implementation
        )
        return f"✅ Model loaded successfully from {model_path}"
    except Exception as e:
        # traceback is already imported at module level (redundant local
        # import removed)
        traceback.print_exc()
        return f"❌ Error loading model: {str(e)}"

# NOTE(review): the remainder of this file (the streaming `transcribe_audio`
# generator and the Gradio UI wiring) was truncated/garbled in the reviewed
# source and is intentionally left untouched here.
Use seconds or hh:mm:ss.", "" return audio_path = None audio_array = None sample_rate = None if audio_path_input: candidate = Path(audio_path_input.strip()) if not candidate.exists(): yield f"❌ Provided path does not exist: {candidate}", "" return audio_path = str(candidate) print(f"[INFO] Using provided audio path: {audio_path}") # Get audio file path (Gradio Audio component returns tuple (sample_rate, audio_data) or file path) elif isinstance(audio_input, str): audio_path = audio_input print(f"[INFO] Using uploaded audio path: {audio_path}") elif isinstance(audio_input, tuple): # Audio from microphone: (sample_rate, audio_data) sample_rate, audio_array = audio_input print(f"[INFO] Received microphone audio with sample_rate={sample_rate}") elif audio_path is None: yield "❌ Invalid audio input format!", "" return # If slicing is requested, load and slice audio if start_sec is not None or end_sec is not None: print("[INFO] Slicing audio per requested time range") if audio_array is None or sample_rate is None: try: audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False) print("[INFO] Loaded audio for slicing via ffmpeg") except Exception as exc: yield f"❌ Failed to load audio for slicing: {exc}", "" return sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec) if err: yield f"❌ {err}", "" return audio_path = sliced_path print(f"[INFO] Sliced audio written to temp file: {audio_path}") elif audio_array is not None and sample_rate is not None: # no slicing but microphone input: write to temp file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") audio_path = temp_file.name temp_file.close() audio_data_int16 = (audio_array * 32768.0).astype(np.int16) sf.write(audio_path, audio_data_int16, sample_rate, subtype='PCM_16') print(f"[INFO] Microphone audio saved to temp file: {audio_path}") # Create streamer for real-time output streamer = TextIteratorStreamer( asr_model.processor.tokenizer, 
skip_prompt=True, skip_special_tokens=True ) # Store result in a mutable container for the thread result_container = {"result": None, "error": None} def run_transcription(): try: result_container["result"] = asr_model.transcribe( audio_path=audio_path, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, repetition_penalty=repetition_penalty, context_info=context_info if context_info and context_info.strip() else None, streamer=streamer ) except Exception as e: result_container["error"] = str(e) traceback.print_exc() # Start transcription in background thread print("[INFO] Starting model transcription (streaming mode)") start_time = time.time() transcription_thread = threading.Thread(target=run_transcription) transcription_thread.start() # Yield streaming output generated_text = "" token_count = 0 for new_text in streamer: generated_text += new_text token_count += 1 elapsed = time.time() - start_time # Show streaming output with live stats, format for readability formatted_text = generated_text.replace('},', '},\n') streaming_output = f"--- 🔴 LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}" yield streaming_output, "
🎵 Click the play button to listen to each segment directly!
" for i, (label, audio_src, error_msg) in enumerate(audio_segments): seg = segments[i] if i < len(segments) else {} start_time = seg.get('start_time', 'N/A') end_time = seg.get('end_time', 'N/A') speaker_id = seg.get('speaker_id', 'N/A') content = seg.get('text', '') # Format times nicely start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time) end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time) audio_segments_html += f"""❌ No audio segments available.
This could happen if the model output doesn't contain valid time stamps.