#!/usr/bin/env python """ VibeVoice ASR Gradio Demo (Patched for MKV and video file support) """ import os import sys import torch import numpy as np import soundfile as sf from pathlib import Path import argparse import time import json import gradio as gr from typing import List, Dict, Tuple, Optional, Generator import tempfile import base64 import io import traceback import threading from concurrent.futures import ThreadPoolExecutor, as_completed # Import TextIteratorStreamer for streaming generation from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList try: from liger_kernel.transformers import apply_liger_kernel_to_qwen2 # Only apply RoPE, RMSNorm, SwiGLU patches (these affect the underlying Qwen2 layers) apply_liger_kernel_to_qwen2( rope=True, rms_norm=True, swiglu=True, cross_entropy=False, ) print("✅ Liger Kernel applied to Qwen2 components (RoPE, RMSNorm, SwiGLU)") except Exception as e: print(f"⚠️ Failed to apply Liger Kernel: {e}, you can install it with: pip install liger-kernel") # Try to import pydub for MP3 conversion try: from pydub import AudioSegment HAS_PYDUB = True except ImportError: HAS_PYDUB = False print("⚠️ Warning: pydub not available, falling back to WAV format") from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS # Add MKV and additional video format support MKV_EXTS = ['.mkv', '.MKV', '.Mkv', '.avi', '.AVI'] for ext in MKV_EXTS: if ext not in COMMON_AUDIO_EXTS: COMMON_AUDIO_EXTS.append(ext) print(f"✅ MKV/Video support enabled. Total supported formats: {len(COMMON_AUDIO_EXTS)}") class VibeVoiceASRInference: """Simple inference wrapper for VibeVoice ASR model.""" def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2"): """ Initialize the ASR inference pipeline. Args: model_path: Path to the pretrained model (HuggingFace format directory or model name) device: Device to run inference on dtype: Data type for model weights attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') """ print(f"Loading VibeVoice ASR model from {model_path}") # Load processor self.processor = VibeVoiceASRProcessor.from_pretrained(model_path) # Load model print(f"Using attention implementation: {attn_implementation}") self.model = VibeVoiceASRForConditionalGeneration.from_pretrained( model_path, dtype=dtype, device_map=device if device == "auto" else None, attn_implementation=attn_implementation, trust_remote_code=True ) if device != "auto": self.model = self.model.to(device) self.device = device if device != "auto" else next(self.model.parameters()).device self.model.eval() # Print model info total_params = sum(p.numel() for p in self.model.parameters()) print(f"✅ Model loaded successfully on {self.device}") print(f"📊 Total parameters: {total_params:,} ({total_params/1e9:.2f}B)") def transcribe( self, audio_path: str = None, audio_array: np.ndarray = None, sample_rate: int = None, max_new_tokens: int = 512, temperature: float = 0.0, top_p: float = 1.0, do_sample: bool = False, num_beams: int = 1, repetition_penalty: float = 1.0, context_info: str = None, streamer: Optional[TextIteratorStreamer] = None, ) -> dict: """ Transcribe audio to text. 
Args: audio_path: Path to audio file audio_array: Audio array (if not loading from file) sample_rate: Sample rate of audio array max_new_tokens: Maximum tokens to generate temperature: Temperature for sampling (0 for greedy) top_p: Top-p for nucleus sampling (1.0 for no filtering) do_sample: Whether to use sampling num_beams: Number of beams for beam search (1 for greedy) repetition_penalty: Repetition penalty (1.0 for no penalty) context_info: Optional context information (e.g., hotwords, speaker names, topics) to help transcription streamer: Optional TextIteratorStreamer for streaming output Returns: Dictionary with transcription results """ # Process audio inputs = self.processor( audio=audio_path, sampling_rate=sample_rate, return_tensors="pt", add_generation_prompt=True, context_info=context_info ) # Move to device inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} # Generate generation_config = { "max_new_tokens": max_new_tokens, "temperature": temperature if temperature > 0 else None, "top_p": top_p if do_sample else None, "do_sample": do_sample, "num_beams": num_beams, "repetition_penalty": repetition_penalty, "pad_token_id": self.processor.pad_id, "eos_token_id": self.processor.tokenizer.eos_token_id, } # Add streamer if provided if streamer is not None: generation_config["streamer"] = streamer # Add stopping criteria for stop button support generation_config["stopping_criteria"] = StoppingCriteriaList([StopOnFlag()]) # Remove None values generation_config = {k: v for k, v in generation_config.items() if v is not None} start_time = time.time() # Calculate input token statistics before generation input_ids = inputs['input_ids'][0] # Shape: [seq_len] total_input_tokens = input_ids.shape[0] # Count padding tokens (tokens equal to pad_id) pad_id = self.processor.pad_id padding_mask = (input_ids == pad_id) num_padding_tokens = padding_mask.sum().item() # Count speech tokens (tokens between speech_start_id and speech_end_id) speech_start_id = self.processor.speech_start_id speech_end_id = self.processor.speech_end_id # Find speech regions input_ids_list = input_ids.tolist() num_speech_tokens = 0 in_speech = False for token_id in input_ids_list: if token_id == speech_start_id: in_speech = True num_speech_tokens += 1 # Count speech_start token elif token_id == speech_end_id: in_speech = False num_speech_tokens += 1 # Count speech_end token elif in_speech: num_speech_tokens += 1 # Text tokens = total - speech - padding num_text_tokens = total_input_tokens - num_speech_tokens - num_padding_tokens with torch.no_grad(): output_ids = self.model.generate( **inputs, **generation_config ) generation_time = time.time() - start_time # Decode output generated_ids = output_ids[0, inputs['input_ids'].shape[1]:] generated_text = self.processor.decode(generated_ids, skip_special_tokens=True) # Parse structured output try: transcription_segments = self.processor.post_process_transcription(generated_text) except Exception as e: print(f"Warning: Failed to parse structured output: {e}") transcription_segments = [] return { "raw_text": generated_text, "segments": transcription_segments, "generation_time": generation_time, "input_tokens": { "total": total_input_tokens, "speech": num_speech_tokens, "text": num_text_tokens, "padding": num_padding_tokens, }, } def clip_and_encode_audio( audio_data: np.ndarray, sr: int, start_time: float, end_time: float, segment_idx: int, use_mp3: bool = True, target_sr: int = 16000, # Downsample to 16kHz for smaller size 
mp3_bitrate: str = "32k" # Use low bitrate for minimal transfer ) -> Tuple[int, Optional[str], Optional[str]]: """ Clip audio segment and encode to base64. Args: audio_data: Full audio array sr: Sample rate start_time: Start time in seconds end_time: End time in seconds segment_idx: Segment index for identification use_mp3: Whether to use MP3 format (smaller size) target_sr: Target sample rate for downsampling (lower = smaller) mp3_bitrate: MP3 bitrate (lower = smaller, e.g., "24k", "32k", "48k") Returns: Tuple of (segment_idx, base64_string, error_message) """ try: # Convert time to sample indices start_sample = int(start_time * sr) end_sample = int(end_time * sr) # Ensure indices are within bounds start_sample = max(0, start_sample) end_sample = min(len(audio_data), end_sample) if start_sample >= end_sample: return segment_idx, None, f"Invalid time range: [{start_time:.2f}s - {end_time:.2f}s]" # Extract segment segment_data = audio_data[start_sample:end_sample] # Downsample if needed (reduces data size significantly) if sr != target_sr and target_sr < sr: # Simple downsampling using linear interpolation duration = len(segment_data) / sr new_length = int(duration * target_sr) indices = np.linspace(0, len(segment_data) - 1, new_length) segment_data = np.interp(indices, np.arange(len(segment_data)), segment_data) sr = target_sr # Convert float32 audio to int16 for encoding segment_data_int16 = (segment_data * 32768.0).astype(np.int16) # Convert to MP3 if pydub is available and use_mp3 is True if use_mp3 and HAS_PYDUB: try: # Write to WAV in memory wav_buffer = io.BytesIO() sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16') wav_buffer.seek(0) # Convert to MP3 with low bitrate audio_segment = AudioSegment.from_wav(wav_buffer) # Convert to mono if stereo (halves the size) if audio_segment.channels > 1: audio_segment = audio_segment.set_channels(1) mp3_buffer = io.BytesIO() audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate) mp3_buffer.seek(0) # Encode to base64 audio_bytes = mp3_buffer.read() audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') audio_src = f"data:audio/mp3;base64,{audio_base64}" return segment_idx, audio_src, None except Exception as e: # Fall back to WAV on error print(f"MP3 conversion failed for segment {segment_idx}, using WAV: {e}") # Fall back to WAV format (no temp file, use in-memory buffer) wav_buffer = io.BytesIO() sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16') wav_buffer.seek(0) audio_bytes = wav_buffer.read() audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') audio_src = f"data:audio/wav;base64,{audio_base64}" return segment_idx, audio_src, None except Exception as e: error_msg = f"Error clipping segment {segment_idx}: {str(e)}" print(error_msg) return segment_idx, None, error_msg def extract_audio_segments(audio_path: str, segments: List[Dict]) -> List[Tuple[str, str, Optional[str]]]: """ Extract multiple segments from audio file efficiently with parallel processing. Args: audio_path: Path to original audio file segments: List of segment dictionaries with start_time, end_time, etc. 
    Returns:
        List of tuples (segment_label, audio_base64_src, error_msg)
    """
    try:
        # Read audio file once using ffmpeg for better format support
        print(f"📂 Loading audio file: {audio_path}")
        audio_data, sr = load_audio_use_ffmpeg(audio_path, resample=False)
        print(f"✅ Audio loaded: {len(audio_data)} samples, {sr} Hz")

        # Prepare tasks
        tasks = []
        use_mp3 = HAS_PYDUB  # Use MP3 if available
        for i, seg in enumerate(segments):
            start_time = seg.get('start_time')
            end_time = seg.get('end_time')
            # Skip if times are not available or invalid
            if (not isinstance(start_time, (int, float))
                    or not isinstance(end_time, (int, float))
                    or start_time >= end_time):
                # Sentinel with audio_data=None so it is filtered out before submission
                tasks.append((None, None, None, None, i, None))
                continue
            tasks.append((audio_data, sr, start_time, end_time, i, use_mp3))

        # Process in parallel using ThreadPoolExecutor
        results = []
        total_segments = len(tasks)
        completed_count = 0
        # Use CPU count for max workers
        max_workers = os.cpu_count() or 4
        print(f"🚀 Starting parallel processing with {max_workers} threads...")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {}
            for task in tasks:
                if task[0] is None:  # Skip invalid tasks
                    continue
                future = executor.submit(clip_and_encode_audio, *task)
                futures[future] = task[4]  # segment_idx

            for future in as_completed(futures):
                try:
                    result = future.result()
                    results.append(result)
                    completed_count += 1
                    # Log progress every 100 segments or at completion
                    if completed_count % 100 == 0 or completed_count == len(futures):
                        print(f"Progress: {completed_count}/{len(futures)} segments processed ({completed_count*100//len(futures)}%)")
                except Exception as e:
                    idx = futures[future]
                    results.append((idx, None, f"Processing error: {str(e)}"))
                    completed_count += 1
                    print(f"Error on segment {idx}: {e}")

        print(f"✅ Completed processing all {len(futures)} valid segments")

        # Sort by segment index to maintain order
        results.sort(key=lambda x: x[0])

        # Build output list with labels
        audio_segments = []
        for idx, audio_src, error_msg in results:
            seg = segments[idx] if idx < len(segments) else {}
            start_time = seg.get('start_time', 'N/A')
            end_time = seg.get('end_time', 'N/A')
            speaker_id = seg.get('speaker_id', 'N/A')
            segment_label = f"Segment {idx+1}: [{start_time:.2f}s - {end_time:.2f}s] Speaker {speaker_id}"
            audio_segments.append((segment_label, audio_src, error_msg))

        return audio_segments

    except Exception as e:
        print(f"Error loading audio file: {e}")
        return []


# Global variable to store the ASR model
asr_model = None

# Global stop flag for generation
stop_generation_flag = False


class StopOnFlag(StoppingCriteria):
    """Custom stopping criteria that checks a global flag."""

    def __call__(self, input_ids, scores, **kwargs):
        global stop_generation_flag
        return stop_generation_flag


def parse_time_to_seconds(val: Optional[str]) -> Optional[float]:
    """Parse seconds or hh:mm:ss to float seconds."""
    if val is None:
        return None
    val = val.strip()
    if not val:
        return None
    try:
        return float(val)
    except ValueError:
        pass
    if ":" in val:
        parts = val.split(":")
        if not all(p.strip().replace(".", "", 1).isdigit() for p in parts):
            return None
        parts = [float(p) for p in parts]
        if len(parts) == 3:
            h, m, s = parts
        elif len(parts) == 2:
            h = 0
            m, s = parts
        else:
            return None
        return h * 3600 + m * 60 + s
    return None


def slice_audio_to_temp(
    audio_data: np.ndarray,
    sample_rate: int,
    start_sec: Optional[float],
    end_sec: Optional[float]
) -> Tuple[Optional[str], Optional[str]]:
    """Slice audio_data to [start_sec, end_sec) and write to a temp WAV file."""
    n_samples =
len(audio_data) full_duration = n_samples / float(sample_rate) start = 0.0 if start_sec is None else max(0.0, start_sec) end = full_duration if end_sec is None else min(full_duration, end_sec) if end <= start: return None, f"Invalid time range: start={start:.2f}s, end={end:.2f}s" start_idx = int(start * sample_rate) end_idx = int(end * sample_rate) segment = audio_data[start_idx:end_idx] temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") temp_file.close() segment_int16 = (segment * 32768.0).astype(np.int16) sf.write(temp_file.name, segment_int16, sample_rate, subtype='PCM_16') return temp_file.name, None def initialize_model(model_path: str, device: str = "cuda", attn_implementation: str = "flash_attention_2"): """Initialize the ASR model.""" global asr_model try: dtype = torch.bfloat16 if device != "cpu" else torch.float32 asr_model = VibeVoiceASRInference( model_path=model_path, device=device, dtype=dtype, attn_implementation=attn_implementation ) return f"✅ Model loaded successfully from {model_path}" except Exception as e: import traceback traceback.print_exc() return f"❌ Error loading model: {str(e)}" def transcribe_audio( audio_input, audio_path_input: str, start_time_input: str, end_time_input: str, max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, repetition_penalty: float = 1.0, context_info: str = "" ) -> Generator[Tuple[str, str], None, None]: """ Transcribe audio and return results with audio segments (streaming version). Args: audio_input: Audio file path or tuple (sample_rate, audio_data) max_new_tokens: Maximum tokens to generate temperature: Temperature for sampling (0 for greedy) top_p: Top-p for nucleus sampling do_sample: Whether to use sampling context_info: Optional context information (e.g., hotwords, speaker names, topics) Yields: Tuple of (raw_text, audio_segments_html) """ if asr_model is None: yield "❌ Please load a model first!", "" return if not audio_path_input and audio_input is None: yield "❌ Please provide audio input!", "" return try: print("[INFO] Transcription requested") start_sec = parse_time_to_seconds(start_time_input) end_sec = parse_time_to_seconds(end_time_input) print(f"[INFO] Parsed time range: start={start_sec}, end={end_sec}") if (start_time_input and start_sec is None) or (end_time_input and end_sec is None): yield "❌ Invalid time format. 
Use seconds or hh:mm:ss.", "" return audio_path = None audio_array = None sample_rate = None if audio_path_input: candidate = Path(audio_path_input.strip()) if not candidate.exists(): yield f"❌ Provided path does not exist: {candidate}", "" return audio_path = str(candidate) print(f"[INFO] Using provided audio path: {audio_path}") # Get audio file path (Gradio Audio component returns tuple (sample_rate, audio_data) or file path) elif isinstance(audio_input, str): audio_path = audio_input print(f"[INFO] Using uploaded audio path: {audio_path}") elif isinstance(audio_input, tuple): # Audio from microphone: (sample_rate, audio_data) sample_rate, audio_array = audio_input print(f"[INFO] Received microphone audio with sample_rate={sample_rate}") elif audio_path is None: yield "❌ Invalid audio input format!", "" return # If slicing is requested, load and slice audio if start_sec is not None or end_sec is not None: print("[INFO] Slicing audio per requested time range") if audio_array is None or sample_rate is None: try: audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False) print("[INFO] Loaded audio for slicing via ffmpeg") except Exception as exc: yield f"❌ Failed to load audio for slicing: {exc}", "" return sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec) if err: yield f"❌ {err}", "" return audio_path = sliced_path print(f"[INFO] Sliced audio written to temp file: {audio_path}") elif audio_array is not None and sample_rate is not None: # no slicing but microphone input: write to temp file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") audio_path = temp_file.name temp_file.close() audio_data_int16 = (audio_array * 32768.0).astype(np.int16) sf.write(audio_path, audio_data_int16, sample_rate, subtype='PCM_16') print(f"[INFO] Microphone audio saved to temp file: {audio_path}") # Create streamer for real-time output streamer = TextIteratorStreamer( asr_model.processor.tokenizer, skip_prompt=True, skip_special_tokens=True ) # Store result in a mutable container for the thread result_container = {"result": None, "error": None} def run_transcription(): try: result_container["result"] = asr_model.transcribe( audio_path=audio_path, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, repetition_penalty=repetition_penalty, context_info=context_info if context_info and context_info.strip() else None, streamer=streamer ) except Exception as e: result_container["error"] = str(e) traceback.print_exc() # Start transcription in background thread print("[INFO] Starting model transcription (streaming mode)") start_time = time.time() transcription_thread = threading.Thread(target=run_transcription) transcription_thread.start() # Yield streaming output generated_text = "" token_count = 0 for new_text in streamer: generated_text += new_text token_count += 1 elapsed = time.time() - start_time # Show streaming output with live stats, format for readability formatted_text = generated_text.replace('},', '},\n') streaming_output = f"--- 🔴 LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}" yield streaming_output, "
⏳ Generating transcription... Audio segments will appear after completion."

        # Wait for thread to complete
        transcription_thread.join()

        if result_container["error"]:
            yield f"❌ Error during transcription: {result_container['error']}", ""
            return

        result = result_container["result"]
        generation_time = time.time() - start_time

        # Get input token statistics
        input_tokens = result.get('input_tokens', {})
        speech_tokens = input_tokens.get('speech', 0)
        text_tokens = input_tokens.get('text', 0)
        padding_tokens = input_tokens.get('padding', 0)
        total_input = input_tokens.get('total', 0)

        # Format final raw output with input/output token stats
        raw_output = "--- ✅ Raw Output ---\n"
        raw_output += f"📥 Input: {total_input} tokens (🎤 speech: {speech_tokens}, 📝 text: {text_tokens}, ⬜ pad: {padding_tokens})\n"
        raw_output += f"📤 Output: {token_count} tokens | ⏱️ Time: {generation_time:.2f}s\n"
        raw_output += "---\n"
        # Format raw text for better readability: add a newline after each dict (},)
        formatted_raw_text = result['raw_text'].replace('},', '},\n')
        raw_output += formatted_raw_text

        # Debug: print raw output to console
        print("[DEBUG] Raw model output:")
        print(f"[DEBUG] {result['raw_text']}")
        print(f"[DEBUG] Found {len(result['segments'])} segments")

        # Create audio segments with server-side encoding (low quality for minimal transfer)
        # Using: 16kHz mono MP3 @ 32kbps = ~4KB per second of audio
        audio_segments_html = ""
        segments = result['segments']

        if segments:
            num_segments = len(segments)
            print(f"[INFO] Creating per-segment audio clips ({num_segments} segments, 16kHz mono MP3 @ 32kbps)")
            # Extract all audio segments efficiently (load audio only once)
            audio_segments = extract_audio_segments(audio_path, segments)
            print("[INFO] Completed creating audio clips")

            # Calculate approximate total size
            total_duration = sum(
                (seg.get('end_time', 0) - seg.get('start_time', 0))
                for seg in segments
                if isinstance(seg.get('start_time'), (int, float)) and isinstance(seg.get('end_time'), (int, float))
            )
            approx_size_kb = total_duration * 4  # ~4KB per second at 32kbps

            # Add CSS for theme-aware styling
            theme_css = """ """
            audio_segments_html = theme_css
            # Open a container for all segment blocks
            audio_segments_html += "<div>"

            # Add header with segment count, approximate payload size, and format info
            format_info = "MP3 32kbps 16kHz mono" if HAS_PYDUB else "WAV 16kHz"
            audio_segments_html += f"<h3>🔊 Audio Segments ({num_segments} segments)</h3>"
            audio_segments_html += f"<p>📦 ~{approx_size_kb:.0f}KB ({format_info})</p>"
            audio_segments_html += "<p>🎵 Click the play button to listen to each segment directly!</p>"

            for i, (label, audio_src, error_msg) in enumerate(audio_segments):
                seg = segments[i] if i < len(segments) else {}
                start_time = seg.get('start_time', 'N/A')
                end_time = seg.get('end_time', 'N/A')
                speaker_id = seg.get('speaker_id', 'N/A')
                content = seg.get('text', '')

                # Format times nicely
                start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time)
                end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time)

                # Per-segment block: header, time range, speaker, and transcribed text
                audio_segments_html += f"""
                <div>
                    <strong>Segment {i+1}</strong><br>
                    Time: [{start_str}s - {end_str}s] | Speaker: {speaker_id}<br>
                    {content}
                """
                if audio_src:
                    # Detect format from data URI
                    audio_type = 'audio/mp3' if 'audio/mp3' in audio_src else 'audio/wav'
                    audio_segments_html += f"""
                    <audio controls preload="none">
                        <source src="{audio_src}" type="{audio_type}">
                    </audio>
                    """
                elif error_msg:
                    audio_segments_html += f"""
                    <p>❌ {error_msg}</p>
                    """
                else:
                    audio_segments_html += """
                    <p>Audio playback unavailable for this segment</p>
                    """
                # Close the per-segment block
                audio_segments_html += "</div>"
            # Close the container
            audio_segments_html += "</div>"
        else:
            audio_segments_html = """
            <div>
            ❌ No audio segments available.<br>
            This could happen if the model output doesn't contain valid time stamps.
            </div>
""" # Final yield with complete results yield raw_output, audio_segments_html except Exception as e: print(f"Error during transcription: {e}") print(traceback.format_exc()) yield f"❌ Error during transcription: {str(e)}", "" def create_gradio_interface(model_path: str, default_max_tokens: int = 8192, attn_implementation: str = "flash_attention_2"): """Create and launch Gradio interface. Args: model_path: Path to the model (HuggingFace format directory or model name) default_max_tokens: Default value for max_new_tokens slider attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') """ # Initialize model at startup device = "cuda" if torch.cuda.is_available() else "cpu" model_status = initialize_model(model_path, device, attn_implementation) print(model_status) # Exit if model loading failed if model_status.startswith("❌"): print("\n" + "="*80) print("💥 FATAL ERROR: Model loading failed!") print("="*80) print("Cannot start demo without a valid model. Please check:") print(" 1. Model path is correct") print(" 2. Model files are not corrupted") print(" 3. You have enough GPU memory") print(" 4. CUDA is properly installed (if using GPU)") print("="*80) sys.exit(1) # Custom CSS for Stop button styling custom_css = """ #stop-btn { background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important; border: none !important; color: white !important; } #stop-btn:hover { background: linear-gradient(135deg, #dc2626 0%, #b91c1c 100%) !important; } """ # Gradio 6.0+ moved theme/css to launch() with gr.Blocks(title="VibeVoice ASR Demo") as demo: gr.Markdown("# 🎙️ VibeVoice ASR Demo") gr.Markdown("Upload audio files or record from microphone to get speech-to-text transcription with speaker diarization.") gr.Markdown(f"**Model loaded from:** `{model_path}`") with gr.Row(): with gr.Column(scale=1): # Generation parameters gr.Markdown("## ⚙️ Generation Parameters") max_tokens_slider = gr.Slider( minimum=4096, maximum=65536, value=default_max_tokens, step=4096, label="Max New Tokens" ) # Sampling parameters gr.Markdown("### 🎲 Sampling") do_sample_checkbox = gr.Checkbox( value=False, label="Enable Sampling", info="Enable random sampling instead of deterministic decoding" ) with gr.Column(visible=False) as sampling_params: temperature_slider = gr.Slider( minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Temperature", info="0 = greedy, higher = more random" ) top_p_slider = gr.Slider( minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p (Nucleus Sampling)", info="1.0 = no filtering" ) # Repetition penalty (works with both greedy and sampling) repetition_penalty_slider = gr.Slider( minimum=1.0, maximum=1.2, value=1.0, step=0.01, label="Repetition Penalty", info="1.0 = no penalty, higher = less repetition (works with greedy & sampling)" ) # Context information section gr.Markdown("## 📋 Context Info (Optional)") context_info_input = gr.Textbox( label="Context Information", placeholder="Enter hotwords, speaker names, topics, or other context to help transcription...\nExample:\nJohn Smith\nMachine Learning\nOpenAI", lines=4, max_lines=8, interactive=True, info="Provide context like proper nouns, technical terms, or speaker names to improve accuracy" ) with gr.Column(scale=2): # Audio input section gr.Markdown("## 🎵 Audio Input") audio_input = gr.Audio( label="Upload Audio File or Record from Microphone", sources=["upload", "microphone"], type="filepath", interactive=True, buttons=["download"] ) # Video file upload (MKV, MP4, AVI, etc.) 
- audio will be extracted video_file_input = gr.File( label="📹 Or Upload Video File (MKV, MP4, AVI, etc.)", file_types=[".mkv", ".mp4", ".avi", ".mov", ".webm", ".flv", ".m4v", ".mpeg", ".3gp"], type="filepath" ) with gr.Accordion("📂 Advanced: Remote Path & Time Slicing", open=False): audio_path_input = gr.Textbox( label="Audio path (optional)", placeholder="Enter remote audio file path", lines=1 ) with gr.Row(): start_time_input = gr.Textbox( label="Start time", placeholder="e.g., 0 or 00:00:00", lines=1, info="Leave empty to start from the beginning" ) end_time_input = gr.Textbox( label="End time", placeholder="e.g., 30.5 or 00:00:30.5", lines=1, info="Leave empty to use full length" ) with gr.Row(): transcribe_button = gr.Button("🎯 Transcribe", variant="primary", size="lg", scale=3) stop_button = gr.Button("⏹️ Stop", variant="secondary", size="lg", scale=1, elem_id="stop-btn") # Results section gr.Markdown("## 📝 Results") with gr.Tabs(): with gr.TabItem("Raw Output"): raw_output = gr.Textbox( label="Raw Transcription Output", lines=8, max_lines=20, interactive=False ) with gr.TabItem("Audio Segments"): audio_segments_output = gr.HTML( label="Play individual segments to verify accuracy" ) # Event handlers do_sample_checkbox.change( fn=lambda x: gr.update(visible=x), inputs=[do_sample_checkbox], outputs=[sampling_params] ) def reset_stop_flag(): """Reset stop flag before starting transcription.""" global stop_generation_flag stop_generation_flag = False def transcribe_with_video_support( audio_input, video_file_input, audio_path_input, start_time_input, end_time_input, max_new_tokens, temperature, top_p, do_sample, repetition_penalty, context_info ): """Wrapper that handles both audio and video file inputs.""" # If video file is provided, use it as the audio path effective_audio_path = audio_path_input if video_file_input is not None and not audio_path_input: effective_audio_path = video_file_input print(f"[INFO] Using video file: {effective_audio_path}") # Call original transcribe function yield from transcribe_audio( audio_input, effective_audio_path, start_time_input, end_time_input, max_new_tokens, temperature, top_p, do_sample, repetition_penalty, context_info ) def set_stop_flag(): """Set stop flag to interrupt generation.""" global stop_generation_flag stop_generation_flag = True return "⏹️ Stop requested..." transcribe_button.click( fn=reset_stop_flag, inputs=[], outputs=[], queue=False ).then( fn=transcribe_with_video_support, inputs=[ audio_input, video_file_input, audio_path_input, start_time_input, end_time_input, max_tokens_slider, temperature_slider, top_p_slider, do_sample_checkbox, repetition_penalty_slider, context_info_input ], outputs=[raw_output, audio_segments_output] ) stop_button.click( fn=set_stop_flag, inputs=[], outputs=[raw_output], queue=False ) # Add examples gr.Markdown("## 📋 Instructions") gr.Markdown(f""" 1. **Upload Audio**: Use the audio component to upload a file or record from microphone - **Supported formats**: {', '.join(sorted(set([ext.lower() for ext in COMMON_AUDIO_EXTS])))} - Optionally set **Start/End time** (seconds or hh:mm:ss) to clip before transcription 2. **Context Info (Optional)**: Provide context to improve transcription accuracy - Add hotwords, proper nouns, speaker names, or technical terms - One item per line or comma-separated - Examples: "John Smith", "OpenAI", "machine learning" 3. **Adjust Parameters**: Configure generation parameters as needed 4. **Transcribe**: Click "Transcribe" to get results 5. 
**Review Results**: - **Raw Output**: View the model's original output - **Audio Segments**: Play individual segments directly to verify accuracy **Audio Segments**: Each segment shows the time range, speaker ID, transcribed content, and an embedded audio player for immediate verification. """) return demo, custom_css def main(): parser = argparse.ArgumentParser(description="VibeVoice ASR Gradio Demo") parser.add_argument( "--model_path", type=str, default="microsoft/VibeVoice-ASR", help="Path to the model (HuggingFace format directory or model name)" ) parser.add_argument( "--attn_implementation", type=str, default="flash_attention_2", help="Attention implementation to use (default: flash_attention_2)" ) parser.add_argument( "--max_new_tokens", type=int, default=32768, help="Default max new tokens for generation (default: 32768)" ) parser.add_argument( "--host", type=str, default="0.0.0.0", help="Host to bind the server to" ) parser.add_argument( "--port", type=int, default=7860, help="Port to bind the server to" ) parser.add_argument( "--share", action="store_true", help="Create a public link" ) args = parser.parse_args() # Create and launch interface demo, custom_css = create_gradio_interface( model_path=args.model_path, default_max_tokens=args.max_new_tokens, attn_implementation=args.attn_implementation ) print(f"🚀 Starting VibeVoice ASR Demo...") print(f"📍 Server will be available at: http://{args.host}:{args.port}") # Gradio 6.0+ moved theme/css to launch() launch_kwargs = { "server_name": args.host, "server_port": args.port, "share": args.share, "show_error": True, "theme": gr.themes.Soft(), "css": custom_css, } # Enable queue for concurrent request handling demo.queue(default_concurrency_limit=3) demo.launch(**launch_kwargs) if __name__ == "__main__": main()
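# Example CLI usage (the script filename below is illustrative; the flags match the
# argparse options defined in main() above, and microsoft/VibeVoice-ASR is the default model):
#   python demo_asr_gradio.py --model_path microsoft/VibeVoice-ASR --port 7860
#   python demo_asr_gradio.py --model_path /path/to/checkpoint --attn_implementation sdpa --share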