note-articles/static/scripts/vibevoice-asr/vibevoice_asr_gradio_demo_patched.py
koide 1fb76254e9
All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
Add: VibeVoice ASR セットアップスクリプト一式
2026-02-24 01:21:33 +00:00

1230 lines
47 KiB
Python

#!/usr/bin/env python
"""
VibeVoice ASR Gradio Demo
(Patched for MKV and video file support)
"""
import os
import sys
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
import argparse
import time
import json
import gradio as gr
from typing import List, Dict, Tuple, Optional, Generator
import tempfile
import base64
import io
import traceback
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
# Import TextIteratorStreamer for streaming generation
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
# Optional acceleration: try to patch the Qwen2 building blocks with Liger
# kernels at import time. Any failure (including the package not being
# installed) is reported on stdout and startup continues without the patch.
try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2
    # Only apply RoPE, RMSNorm, SwiGLU patches (these affect the underlying Qwen2 layers)
    apply_liger_kernel_to_qwen2(
        rope=True,
        rms_norm=True,
        swiglu=True,
        cross_entropy=False,
    )
    print("✅ Liger Kernel applied to Qwen2 components (RoPE, RMSNorm, SwiGLU)")
except Exception as e:
    print(f"⚠️ Failed to apply Liger Kernel: {e}, you can install it with: pip install liger-kernel")
# Try to import pydub for MP3 conversion.
# HAS_PYDUB is read later to decide between MP3 and WAV encoding of the
# per-segment audio clips.
try:
    from pydub import AudioSegment
    HAS_PYDUB = True
except ImportError:
    HAS_PYDUB = False
    print("⚠️ Warning: pydub not available, falling back to WAV format")
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS
# Add MKV and additional video format support.
# COMMON_AUDIO_EXTS is the list imported from vibevoice.processor.audio_utils;
# it is extended in place, so the extra extensions are visible to every user of
# that same list object.
MKV_EXTS = ['.mkv', '.MKV', '.Mkv', '.avi', '.AVI']
for ext in MKV_EXTS:
    # Skip extensions that are already registered so the list has no duplicates.
    if ext not in COMMON_AUDIO_EXTS:
        COMMON_AUDIO_EXTS.append(ext)
print(f"✅ MKV/Video support enabled. Total supported formats: {len(COMMON_AUDIO_EXTS)}")
class VibeVoiceASRInference:
    """Simple inference wrapper that pairs the VibeVoice ASR model with its processor."""

    def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2"):
        """
        Initialize the ASR inference pipeline.

        Args:
            model_path: Path to the pretrained model (HuggingFace format directory or model name)
            device: Device to run inference on; "auto" delegates placement to ``device_map``
            dtype: Data type for model weights
            attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
        """
        print(f"Loading VibeVoice ASR model from {model_path}")

        # Load processor
        self.processor = VibeVoiceASRProcessor.from_pretrained(model_path)

        # Load model
        print(f"Using attention implementation: {attn_implementation}")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device if device == "auto" else None,
            attn_implementation=attn_implementation,
            trust_remote_code=True
        )

        # With an explicit device the model is moved by hand; with "auto" the
        # effective device is read back from the first parameter.
        if device != "auto":
            self.model = self.model.to(device)
        self.device = device if device != "auto" else next(self.model.parameters()).device
        self.model.eval()

        # Print model info
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"✅ Model loaded successfully on {self.device}")
        print(f"📊 Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")

    def transcribe(
        self,
        audio_path: str = None,
        audio_array: np.ndarray = None,
        sample_rate: int = None,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = False,
        num_beams: int = 1,
        repetition_penalty: float = 1.0,
        context_info: str = None,
        streamer: Optional[TextIteratorStreamer] = None,
    ) -> dict:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file (takes precedence over ``audio_array``)
            audio_array: Audio array (used only when ``audio_path`` is None)
            sample_rate: Sample rate of audio array
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (0 for greedy)
            top_p: Top-p for nucleus sampling (1.0 for no filtering)
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 for greedy)
            repetition_penalty: Repetition penalty (1.0 for no penalty)
            context_info: Optional context information (e.g., hotwords, speaker names, topics) to help transcription
            streamer: Optional TextIteratorStreamer for streaming output

        Returns:
            Dictionary with keys ``raw_text``, ``segments``, ``generation_time``
            (seconds spent inside ``generate``) and ``input_tokens`` (a dict with
            ``total``, ``speech``, ``text`` and ``padding`` counts).
        """
        # Prefer the file path; fall back to the in-memory array so the
        # ``audio_array`` argument is no longer silently ignored.
        # NOTE(review): assumes the processor's ``audio`` argument also accepts an
        # in-memory array — confirm against VibeVoiceASRProcessor.
        audio_source = audio_path if audio_path is not None else audio_array

        # Process audio
        inputs = self.processor(
            audio=audio_source,
            sampling_rate=sample_rate,
            return_tensors="pt",
            add_generation_prompt=True,
            context_info=context_info
        )

        # Move to device
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # A temperature of 0 is documented as greedy decoding, so sampling is
        # only enabled when it was requested AND the temperature is positive.
        # Sampling-only knobs are dropped otherwise.
        use_sampling = bool(do_sample) and temperature is not None and temperature > 0

        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature if use_sampling else None,
            "top_p": top_p if use_sampling else None,
            "do_sample": use_sampling,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Add streamer if provided
        if streamer is not None:
            generation_config["streamer"] = streamer

        # Add stopping criteria for stop button support
        generation_config["stopping_criteria"] = StoppingCriteriaList([StopOnFlag()])

        # Remove None values
        generation_config = {k: v for k, v in generation_config.items() if v is not None}

        # Calculate input token statistics before generation
        input_ids = inputs['input_ids'][0]  # first (only) sequence of the batch
        total_input_tokens = input_ids.shape[0]

        # Count padding tokens (tokens equal to pad_id)
        num_padding_tokens = (input_ids == self.processor.pad_id).sum().item()

        # Count speech tokens: the start marker, the end marker and everything
        # that lies between a start and the following end marker.
        speech_start_id = self.processor.speech_start_id
        speech_end_id = self.processor.speech_end_id
        num_speech_tokens = 0
        in_speech = False
        for token_id in input_ids.tolist():
            if token_id == speech_start_id:
                in_speech = True
                num_speech_tokens += 1
            elif token_id == speech_end_id:
                in_speech = False
                num_speech_tokens += 1
            elif in_speech:
                num_speech_tokens += 1

        # Text tokens = total - speech - padding
        num_text_tokens = total_input_tokens - num_speech_tokens - num_padding_tokens

        # Start the clock right before generate() so the reported time does not
        # include the Python-side token statistics above.
        start_time = time.time()
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )
        generation_time = time.time() - start_time

        # Decode output: keep only the tokens that follow the prompt
        generated_ids = output_ids[0, inputs['input_ids'].shape[1]:]
        generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        # Parse structured output; a parse failure degrades to "no segments"
        # instead of failing the whole transcription.
        try:
            transcription_segments = self.processor.post_process_transcription(generated_text)
        except Exception as e:
            print(f"Warning: Failed to parse structured output: {e}")
            transcription_segments = []

        return {
            "raw_text": generated_text,
            "segments": transcription_segments,
            "generation_time": generation_time,
            "input_tokens": {
                "total": total_input_tokens,
                "speech": num_speech_tokens,
                "text": num_text_tokens,
                "padding": num_padding_tokens,
            },
        }
def clip_and_encode_audio(
    audio_data: np.ndarray,
    sr: int,
    start_time: float,
    end_time: float,
    segment_idx: int,
    use_mp3: bool = True,
    target_sr: int = 16000,  # Downsample to 16kHz for smaller size
    mp3_bitrate: str = "32k"  # Use low bitrate for minimal transfer
) -> Tuple[int, Optional[str], Optional[str]]:
    """
    Clip audio segment and encode to base64.

    Args:
        audio_data: Full audio array
        sr: Sample rate
        start_time: Start time in seconds
        end_time: End time in seconds
        segment_idx: Segment index for identification
        use_mp3: Whether to use MP3 format (smaller size)
        target_sr: Target sample rate for downsampling (lower = smaller)
        mp3_bitrate: MP3 bitrate (lower = smaller, e.g., "24k", "32k", "48k")

    Returns:
        Tuple of (segment_idx, base64_string, error_message). On success the
        second item is a ``data:`` URI and the third is None; on failure the
        second item is None and the third describes the problem. This function
        never raises: every exception is turned into an error tuple.
    """
    try:
        # Convert time to sample indices and keep them within bounds
        start_sample = max(0, int(start_time * sr))
        end_sample = min(len(audio_data), int(end_time * sr))
        if start_sample >= end_sample:
            return segment_idx, None, f"Invalid time range: [{start_time:.2f}s - {end_time:.2f}s]"

        # Extract segment
        segment_data = audio_data[start_sample:end_sample]

        # Downsample if needed (reduces data size significantly)
        if target_sr < sr:
            # Simple downsampling using linear interpolation. Keep at least one
            # output sample so a very short clip does not become an empty file.
            duration = len(segment_data) / sr
            new_length = max(1, int(duration * target_sr))
            indices = np.linspace(0, len(segment_data) - 1, new_length)
            segment_data = np.interp(indices, np.arange(len(segment_data)), segment_data)
            sr = target_sr

        # Convert float audio to int16 for encoding. Clip first: a full-scale
        # sample of 1.0 scales to 32768, which does not fit in int16 and would
        # otherwise wrap around to -32768 (an audible click).
        segment_data_int16 = np.clip(segment_data * 32768.0, -32768, 32767).astype(np.int16)

        # Convert to MP3 if pydub is available and use_mp3 is True
        if use_mp3 and HAS_PYDUB:
            try:
                # Write to WAV in memory
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
                wav_buffer.seek(0)

                # Convert to MP3 with low bitrate
                audio_segment = AudioSegment.from_wav(wav_buffer)
                # Convert to mono if stereo (halves the size)
                if audio_segment.channels > 1:
                    audio_segment = audio_segment.set_channels(1)
                mp3_buffer = io.BytesIO()
                audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate)

                # Encode to base64. The "audio/mp3" marker is what the HTML
                # renderer in this file looks for, so it is kept as is.
                audio_base64 = base64.b64encode(mp3_buffer.getvalue()).decode('utf-8')
                return segment_idx, f"data:audio/mp3;base64,{audio_base64}", None
            except Exception as e:
                # Fall back to WAV on error
                print(f"MP3 conversion failed for segment {segment_idx}, using WAV: {e}")

        # Fall back to WAV format (no temp file, use in-memory buffer)
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
        audio_base64 = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')
        return segment_idx, f"data:audio/wav;base64,{audio_base64}", None
    except Exception as e:
        error_msg = f"Error clipping segment {segment_idx}: {str(e)}"
        print(error_msg)
        return segment_idx, None, error_msg
def extract_audio_segments(audio_path: str, segments: List[Dict]) -> List[Tuple[str, str, Optional[str]]]:
    """
    Extract multiple segments from audio file efficiently with parallel processing.

    The audio file is decoded once and every segment with a usable time range
    is clipped and encoded on a thread pool.

    Args:
        audio_path: Path to original audio file
        segments: List of segment dictionaries with start_time, end_time, etc.

    Returns:
        List of tuples (segment_label, audio_base64_src, error_msg) with exactly
        one entry per input segment, in input order. Segments whose time range
        is missing or invalid get ``audio_base64_src=None`` and an error
        message, so one bad segment no longer discards the clips of all others.
        An empty list is returned when there are no segments or when the audio
        file cannot be loaded.
    """
    if not segments:
        return []

    def _fmt_time(value) -> str:
        # Times may be missing or non-numeric in malformed model output.
        return f"{value:.2f}" if isinstance(value, (int, float)) else str(value)

    try:
        # Read audio file once using ffmpeg for better format support
        print(f"📂 Loading audio file: {audio_path}")
        audio_data, sr = load_audio_use_ffmpeg(audio_path, resample=False)
        print(f"✅ Audio loaded: {len(audio_data)} samples, {sr} Hz")

        use_mp3 = HAS_PYDUB  # Use MP3 if available

        # segment index -> (audio_src, error_msg)
        outcomes = {}
        valid_jobs = []
        for idx, seg in enumerate(segments):
            start_time = seg.get('start_time')
            end_time = seg.get('end_time')
            times_ok = (
                isinstance(start_time, (int, float))
                and isinstance(end_time, (int, float))
                and start_time < end_time
            )
            if times_ok:
                valid_jobs.append((idx, start_time, end_time))
            else:
                outcomes[idx] = (None, "Invalid or missing time range")

        if valid_jobs:
            # No point in starting more threads than there are clips to cut.
            max_workers = min(os.cpu_count() or 4, len(valid_jobs))
            print(f"🚀 Starting parallel processing with {max_workers} threads...")
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = {
                    executor.submit(clip_and_encode_audio, audio_data, sr, start, end, idx, use_mp3): idx
                    for idx, start, end in valid_jobs
                }
                total = len(futures)
                for completed_count, future in enumerate(as_completed(futures), start=1):
                    idx = futures[future]
                    try:
                        _, audio_src, error_msg = future.result()
                        outcomes[idx] = (audio_src, error_msg)
                    except Exception as e:
                        outcomes[idx] = (None, f"Processing error: {str(e)}")
                        print(f"Error on segment {idx}: {e}")
                    # Log progress every 100 segments or at completion
                    if completed_count % 100 == 0 or completed_count == total:
                        print(f"Progress: {completed_count}/{total} segments processed ({completed_count*100//total}%)")
        print(f"✅ Completed processing all {len(valid_jobs)} valid segments")

        # Build output list with labels, in the original segment order
        audio_segments = []
        for idx, seg in enumerate(segments):
            audio_src, error_msg = outcomes[idx]
            start_label = _fmt_time(seg.get('start_time', 'N/A'))
            end_label = _fmt_time(seg.get('end_time', 'N/A'))
            speaker_id = seg.get('speaker_id', 'N/A')
            segment_label = f"Segment {idx+1}: [{start_label}s - {end_label}s] Speaker {speaker_id}"
            audio_segments.append((segment_label, audio_src, error_msg))
        return audio_segments
    except Exception as e:
        print(f"Error loading audio file: {e}")
        return []
# Global variable to store the ASR model.
# Assigned by initialize_model(); transcribe_audio() refuses to run while it is None.
asr_model = None
# Global stop flag for generation.
# Read by StopOnFlag while the model generates; written by the stop/reset
# handlers registered in create_gradio_interface().
stop_generation_flag = False
class StopOnFlag(StoppingCriteria):
    """Stopping criterion driven by the module-level ``stop_generation_flag``."""

    def __call__(self, input_ids, scores, **kwargs):
        # The arguments are ignored; the decision is purely the current value
        # of the module-level flag (reading a global needs no declaration).
        return stop_generation_flag
def parse_time_to_seconds(val: Optional[str]) -> Optional[float]:
    """Parse seconds or hh:mm:ss to float seconds.

    Accepted forms are a plain number of seconds (``"30.5"``), ``mm:ss`` and
    ``hh:mm:ss`` (each field a non-negative decimal number).

    Args:
        val: User-supplied text, or None.

    Returns:
        The time in seconds, or None when ``val`` is None, blank or not a
        valid, finite time. This function never raises for string input.
    """
    import math

    if val is None:
        return None
    val = val.strip()
    if not val:
        return None

    # Plain number of seconds.
    try:
        seconds = float(val)
    except ValueError:
        seconds = None
    if seconds is not None:
        # float() also accepts "nan" and "inf"; neither is a usable time.
        return seconds if math.isfinite(seconds) else None

    # Colon-separated clock notation.
    if ":" not in val:
        return None
    fields = [p.strip() for p in val.split(":")]
    if len(fields) not in (2, 3):
        return None
    if not all(p.replace(".", "", 1).isdigit() for p in fields):
        return None
    try:
        # str.isdigit() accepts characters such as superscript digits that
        # float() rejects, so the conversion itself still has to be guarded.
        numbers = [float(p) for p in fields]
    except ValueError:
        return None

    if len(numbers) == 2:
        numbers.insert(0, 0.0)  # mm:ss -> 0:mm:ss
    h, m, s = numbers
    return h * 3600 + m * 60 + s
def slice_audio_to_temp(
    audio_data: np.ndarray,
    sample_rate: int,
    start_sec: Optional[float],
    end_sec: Optional[float]
) -> Tuple[Optional[str], Optional[str]]:
    """Slice audio_data to [start_sec, end_sec) and write to a temp WAV file.

    Args:
        audio_data: Full audio array; scaled by 32768 to 16-bit PCM when written.
        sample_rate: Sample rate of ``audio_data`` in Hz.
        start_sec: Start of the slice in seconds; None means the beginning.
            Negative values are clamped to 0.
        end_sec: End of the slice in seconds; None means the end of the audio.
            Values past the end are clamped to the full duration.

    Returns:
        ``(path, None)`` on success, where ``path`` is a temporary ``.wav`` file
        that the caller is responsible for deleting, or ``(None, message)``
        when the requested range is empty.
    """
    n_samples = len(audio_data)
    full_duration = n_samples / float(sample_rate)
    start = 0.0 if start_sec is None else max(0.0, start_sec)
    end = full_duration if end_sec is None else min(full_duration, end_sec)
    range_error = f"Invalid time range: start={start:.2f}s, end={end:.2f}s"
    if end <= start:
        return None, range_error

    start_idx = int(start * sample_rate)
    end_idx = int(end * sample_rate)
    segment = audio_data[start_idx:end_idx]
    if len(segment) == 0:
        # A range shorter than one sample would produce an empty WAV file.
        return None, range_error

    # Clip before the cast: a full-scale sample of 1.0 scales to 32768, which
    # does not fit in int16 and would wrap around to -32768.
    segment_int16 = np.clip(segment * 32768.0, -32768, 32767).astype(np.int16)

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    try:
        sf.write(temp_file.name, segment_int16, sample_rate, subtype='PCM_16')
    except Exception:
        # Do not leave a stray temp file behind when the write fails.
        try:
            os.unlink(temp_file.name)
        except OSError:
            pass
        raise
    return temp_file.name, None
def initialize_model(model_path: str, device: str = "cuda", attn_implementation: str = "flash_attention_2"):
    """Load the ASR model into the module-level ``asr_model``.

    Args:
        model_path: Model directory or model name passed to the wrapper.
        device: Target device; "cpu" selects float32 weights, anything else bfloat16.
        attn_implementation: Attention implementation forwarded to the wrapper.

    Returns:
        A human-readable status string starting with "✅" on success or "❌"
        on failure. Exceptions are printed and reported, never raised.
    """
    global asr_model
    weights_dtype = torch.float32 if device == "cpu" else torch.bfloat16
    try:
        asr_model = VibeVoiceASRInference(
            model_path=model_path,
            device=device,
            dtype=weights_dtype,
            attn_implementation=attn_implementation,
        )
    except Exception as exc:
        traceback.print_exc()
        return f"❌ Error loading model: {str(exc)}"
    return f"✅ Model loaded successfully from {model_path}"
# Theme-aware styling for the per-segment result cards.
_SEGMENTS_THEME_CSS = """
<style>
:root {
--segment-bg: #f8f9fa;
--segment-border: #e1e5e9;
--segment-text: #495057;
--segment-meta: #6c757d;
--content-bg: white;
--content-border: #007bff;
--warning-bg: #fff3cd;
--warning-border: #ffc107;
--warning-text: #856404;
}
@media (prefers-color-scheme: dark) {
:root {
--segment-bg: #2d3748;
--segment-border: #4a5568;
--segment-text: #e2e8f0;
--segment-meta: #a0aec0;
--content-bg: #1a202c;
--content-border: #4299e1;
--warning-bg: #744210;
--warning-border: #d69e2e;
--warning-text: #faf089;
}
}
.audio-segments-container {
max-height: 600px;
overflow-y: auto;
padding: 10px;
}
.audio-segment {
margin-bottom: 15px;
padding: 15px;
border: 2px solid var(--segment-border);
border-radius: 8px;
background-color: var(--segment-bg);
transition: all 0.3s ease;
}
.audio-segment:hover {
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.segment-header {
margin-bottom: 10px;
}
.segment-title {
margin: 0;
color: var(--segment-text);
font-size: 16px;
font-weight: 600;
}
.segment-meta {
margin-top: 5px;
font-size: 14px;
color: var(--segment-meta);
}
.segment-content {
margin-bottom: 10px;
padding: 12px;
background-color: var(--content-bg);
border-radius: 6px;
border-left: 4px solid var(--content-border);
color: var(--segment-text);
line-height: 1.5;
}
.segment-audio {
width: 100%;
margin-top: 10px;
border-radius: 4px;
}
.segment-warning {
margin-top: 10px;
padding: 10px;
background-color: var(--warning-bg);
border-radius: 4px;
border-left: 4px solid var(--warning-border);
color: var(--warning-text);
font-size: 13px;
}
.segments-title {
color: var(--segment-text);
margin-bottom: 10px;
}
.segments-description {
color: var(--segment-meta);
margin-bottom: 20px;
}
.size-badge {
display: inline-block;
background: linear-gradient(135deg, #6c757d, #495057);
color: white;
padding: 4px 10px;
border-radius: 12px;
font-size: 12px;
margin-left: 10px;
}
</style>
"""

# Shown instead of the segment list when the model output had no segments.
_NO_SEGMENTS_HTML = """
<style>
:root {
--no-segments-text: #6c757d;
}
@media (prefers-color-scheme: dark) {
:root {
--no-segments-text: #a0aec0;
}
}
.no-segments-container {
padding: 20px;
text-align: center;
color: var(--no-segments-text);
line-height: 1.6;
}
</style>
<div class='no-segments-container'>
<p>❌ No audio segments available.</p>
<p>This could happen if the model output doesn't contain valid time stamps.</p>
</div>
"""

# Per-segment HTML fragments; every placeholder receives pre-escaped text.
_SEGMENT_CARD_TEMPLATE = """
<div class='audio-segment'>
<div class='segment-header'>
<h4 class='segment-title'>Segment {number}</h4>
<div class='segment-meta'>
<strong>Time:</strong> [{start}s - {end}s] |
<strong>Speaker:</strong> {speaker}
</div>
</div>
<div class='segment-content'>
{content}
</div>
"""

_SEGMENT_AUDIO_TEMPLATE = """
<audio controls class='segment-audio' preload='none'>
<source src='{src}' type='{mime}'>
Your browser does not support the audio element.
</audio>
"""

_SEGMENT_ERROR_TEMPLATE = """
<div class='segment-warning'>
<small>❌ {message}</small>
</div>
"""

_SEGMENT_NO_AUDIO_HTML = """
<div class='segment-warning'>
<small>Audio playback unavailable for this segment</small>
</div>
"""


def _mic_audio_to_float(audio_array) -> np.ndarray:
    """Return in-memory audio as floats; integer PCM is rescaled to [-1, 1)."""
    samples = np.asarray(audio_array)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM must be rescaled, otherwise the later "* 32768" step
        # used when writing 16-bit WAV files would overflow.
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        return samples.astype(np.float32) / full_scale
    return samples


def _remove_temp_files(paths) -> None:
    """Best-effort removal of the temporary WAV files created for one request."""
    for path in paths:
        try:
            os.unlink(path)
        except OSError:
            pass


def _render_segments_html(segments: List[Dict], audio_segments: List[Tuple]) -> str:
    """Build the HTML for the "Audio Segments" tab.

    One card is rendered per entry of ``segments``; ``audio_segments`` is looked
    up by position for the matching ``(label, audio_src, error_msg)`` tuple, so
    the transcript text is still shown when no clip could be produced.
    Text taken from the model output is HTML-escaped before insertion.
    """
    from html import escape

    num_segments = len(segments)

    # Calculate approximate total size
    total_duration = sum(
        seg.get('end_time', 0) - seg.get('start_time', 0)
        for seg in segments
        if isinstance(seg.get('start_time'), (int, float)) and isinstance(seg.get('end_time'), (int, float))
    )
    approx_size_kb = total_duration * 4  # ~4KB per second at 32kbps

    # Add format info
    format_info = "MP3 32kbps 16kHz mono" if HAS_PYDUB else "WAV 16kHz"

    parts = [
        _SEGMENTS_THEME_CSS,
        "<div class='audio-segments-container'>",
        f"<h3 class='segments-title'>🔊 Audio Segments ({num_segments} segments)",
        f"<span class='size-badge'>📦 ~{approx_size_kb:.0f}KB ({format_info})</span></h3>",
        "<p class='segments-description'>🎵 Click the play button to listen to each segment directly!</p>",
    ]

    for i, seg in enumerate(segments):
        if i < len(audio_segments):
            _, audio_src, error_msg = audio_segments[i]
        else:
            audio_src, error_msg = None, None

        seg_start = seg.get('start_time', 'N/A')
        seg_end = seg.get('end_time', 'N/A')

        # Format times nicely
        start_str = f"{seg_start:.2f}" if isinstance(seg_start, (int, float)) else str(seg_start)
        end_str = f"{seg_end:.2f}" if isinstance(seg_end, (int, float)) else str(seg_end)

        parts.append(_SEGMENT_CARD_TEMPLATE.format(
            number=i + 1,
            start=escape(start_str),
            end=escape(end_str),
            speaker=escape(str(seg.get('speaker_id', 'N/A'))),
            content=escape(str(seg.get('text', ''))),
        ))

        if audio_src:
            # Detect format from data URI
            audio_type = 'audio/mp3' if 'audio/mp3' in audio_src else 'audio/wav'
            parts.append(_SEGMENT_AUDIO_TEMPLATE.format(src=audio_src, mime=audio_type))
        elif error_msg:
            parts.append(_SEGMENT_ERROR_TEMPLATE.format(message=escape(str(error_msg))))
        else:
            parts.append(_SEGMENT_NO_AUDIO_HTML)
        parts.append("</div>")

    parts.append("</div>")
    return "".join(parts)


def transcribe_audio(
    audio_input,
    audio_path_input: str,
    start_time_input: str,
    end_time_input: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
    repetition_penalty: float = 1.0,
    context_info: str = ""
) -> Generator[Tuple[str, str], None, None]:
    """
    Transcribe audio and return results with audio segments (streaming version).

    Args:
        audio_input: Audio file path or tuple (sample_rate, audio_data)
        audio_path_input: Optional explicit file path; takes precedence over ``audio_input``
        start_time_input: Optional start of the range to transcribe (seconds or hh:mm:ss)
        end_time_input: Optional end of the range to transcribe (seconds or hh:mm:ss)
        max_new_tokens: Maximum tokens to generate
        temperature: Temperature for sampling (0 for greedy)
        top_p: Top-p for nucleus sampling
        do_sample: Whether to use sampling
        repetition_penalty: Repetition penalty (1.0 for no penalty)
        context_info: Optional context information (e.g., hotwords, speaker names, topics)

    Yields:
        Tuple of (raw_text, audio_segments_html). Intermediate yields carry the
        live streaming text; the last yield carries the final output or an
        error message. Errors are reported through the yielded text, not raised.
    """
    if asr_model is None:
        yield "❌ Please load a model first!", ""
        return

    provided_path = (audio_path_input or "").strip()
    if not provided_path and audio_input is None:
        yield "❌ Please provide audio input!", ""
        return

    # Temporary WAV files created for this request; removed in ``finally``.
    temp_paths = []
    try:
        print("[INFO] Transcription requested")
        start_sec = parse_time_to_seconds(start_time_input)
        end_sec = parse_time_to_seconds(end_time_input)
        print(f"[INFO] Parsed time range: start={start_sec}, end={end_sec}")

        # Only non-blank input that failed to parse is an error.
        start_given = bool(start_time_input and start_time_input.strip())
        end_given = bool(end_time_input and end_time_input.strip())
        if (start_given and start_sec is None) or (end_given and end_sec is None):
            yield "❌ Invalid time format. Use seconds or hh:mm:ss.", ""
            return

        audio_path = None
        audio_array = None
        sample_rate = None
        if provided_path:
            candidate = Path(provided_path)
            if not candidate.exists():
                yield f"❌ Provided path does not exist: {candidate}", ""
                return
            audio_path = str(candidate)
            print(f"[INFO] Using provided audio path: {audio_path}")
        # Get audio file path (Gradio Audio component returns tuple (sample_rate, audio_data) or file path)
        elif isinstance(audio_input, str):
            audio_path = audio_input
            print(f"[INFO] Using uploaded audio path: {audio_path}")
        elif isinstance(audio_input, tuple):
            # Audio from microphone: (sample_rate, audio_data)
            sample_rate, audio_array = audio_input
            audio_array = _mic_audio_to_float(audio_array)
            print(f"[INFO] Received microphone audio with sample_rate={sample_rate}")
        else:
            yield "❌ Invalid audio input format!", ""
            return

        # If slicing is requested, load and slice audio
        if start_sec is not None or end_sec is not None:
            print("[INFO] Slicing audio per requested time range")
            if audio_array is None or sample_rate is None:
                try:
                    audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False)
                    print("[INFO] Loaded audio for slicing via ffmpeg")
                except Exception as exc:
                    yield f"❌ Failed to load audio for slicing: {exc}", ""
                    return
            sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec)
            if err:
                yield f"{err}", ""
                return
            audio_path = sliced_path
            temp_paths.append(audio_path)
            print(f"[INFO] Sliced audio written to temp file: {audio_path}")
        elif audio_array is not None and sample_rate is not None:
            # no slicing but microphone input: write to temp file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            audio_path = temp_file.name
            temp_file.close()
            temp_paths.append(audio_path)
            # Clip before the cast so full-scale samples do not wrap around.
            audio_data_int16 = np.clip(audio_array * 32768.0, -32768, 32767).astype(np.int16)
            sf.write(audio_path, audio_data_int16, sample_rate, subtype='PCM_16')
            print(f"[INFO] Microphone audio saved to temp file: {audio_path}")

        # Create streamer for real-time output
        streamer = TextIteratorStreamer(
            asr_model.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # Store result in a mutable container for the thread
        result_container = {"result": None, "error": None}

        def run_transcription():
            try:
                result_container["result"] = asr_model.transcribe(
                    audio_path=audio_path,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    context_info=context_info if context_info and context_info.strip() else None,
                    streamer=streamer
                )
            except Exception as e:
                result_container["error"] = str(e) or repr(e)
                traceback.print_exc()
                # If transcribe() failed before generation finished, nobody
                # ever signals end-of-stream and the consumer loop below would
                # block forever. Close the stream explicitly.
                streamer.end()

        # Start transcription in background thread
        print("[INFO] Starting model transcription (streaming mode)")
        run_started = time.time()
        transcription_thread = threading.Thread(target=run_transcription)
        transcription_thread.start()

        # Yield streaming output
        generated_text = ""
        token_count = 0
        for new_text in streamer:
            generated_text += new_text
            token_count += 1
            elapsed = time.time() - run_started
            # Show streaming output with live stats, format for readability
            formatted_text = generated_text.replace('},', '},\n')
            streaming_output = f"--- 🔴 LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}"
            yield streaming_output, "<div style='padding: 20px; text-align: center; color: #6c757d;'>⏳ Generating transcription... Audio segments will appear after completion.</div>"

        # Wait for thread to complete
        transcription_thread.join()

        if result_container["error"] is not None:
            yield f"❌ Error during transcription: {result_container['error']}", ""
            return

        result = result_container["result"]
        generation_time = time.time() - run_started

        # Get input token statistics
        input_tokens = result.get('input_tokens', {})
        speech_tokens = input_tokens.get('speech', 0)
        text_tokens = input_tokens.get('text', 0)
        padding_tokens = input_tokens.get('padding', 0)
        total_input = input_tokens.get('total', 0)

        # Format final raw output with input/output token stats
        raw_output = f"--- ✅ Raw Output ---\n"
        raw_output += f"📥 Input: {total_input} tokens (🎤 speech: {speech_tokens}, 📝 text: {text_tokens}, ⬜ pad: {padding_tokens})\n"
        raw_output += f"📤 Output: {token_count} tokens | ⏱️ Time: {generation_time:.2f}s\n"
        raw_output += f"---\n"
        # Format raw text for better readability: add newline after each dict (},)
        raw_output += result['raw_text'].replace('},', '},\n')

        # Debug: print raw output to console
        print(f"[DEBUG] Raw model output:")
        print(f"[DEBUG] {result['raw_text']}")
        print(f"[DEBUG] Found {len(result['segments'])} segments")

        # Create audio segments with server-side encoding (low quality for minimal transfer)
        # Using: 16kHz mono MP3 @ 32kbps = ~4KB per second of audio
        segments = result['segments']
        if segments:
            print(f"[INFO] Creating per-segment audio clips ({len(segments)} segments, 16kHz mono MP3 @ 32kbps)")
            # Extract all audio segments efficiently (load audio only once)
            audio_segments = extract_audio_segments(audio_path, segments)
            print("[INFO] Completed creating audio clips")
            audio_segments_html = _render_segments_html(segments, audio_segments)
        else:
            audio_segments_html = _NO_SEGMENTS_HTML

        # Final yield with complete results
        yield raw_output, audio_segments_html
    except Exception as e:
        print(f"Error during transcription: {e}")
        print(traceback.format_exc())
        yield f"❌ Error during transcription: {str(e)}", ""
    finally:
        # The clips are embedded in the HTML as data URIs, so the temporary
        # WAV files are no longer needed once the request is finished.
        _remove_temp_files(temp_paths)
def create_gradio_interface(model_path: str, default_max_tokens: int = 8192, attn_implementation: str = "flash_attention_2"):
    """Load the model and build the Gradio interface.

    Args:
        model_path: Path to the model (HuggingFace format directory or model name)
        default_max_tokens: Default value for max_new_tokens slider
        attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')

    Returns:
        Tuple ``(demo, custom_css)``: the ``gr.Blocks`` app (not yet launched)
        and the CSS string the caller passes to ``launch()``.

    Exits the process with status 1 when the model cannot be loaded.
    """
    # Initialize model at startup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_status = initialize_model(model_path, device, attn_implementation)
    print(model_status)

    # Exit if model loading failed. initialize_model() reports failure with a
    # status string that starts with "❌"; testing against an empty prefix
    # would match every string and abort even after a successful load.
    if model_status.startswith("❌"):
        print("\n" + "="*80)
        print("💥 FATAL ERROR: Model loading failed!")
        print("="*80)
        print("Cannot start demo without a valid model. Please check:")
        print(" 1. Model path is correct")
        print(" 2. Model files are not corrupted")
        print(" 3. You have enough GPU memory")
        print(" 4. CUDA is properly installed (if using GPU)")
        print("="*80)
        sys.exit(1)

    # Custom CSS for Stop button styling
    custom_css = """
#stop-btn {
background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important;
border: none !important;
color: white !important;
}
#stop-btn:hover {
background: linear-gradient(135deg, #dc2626 0%, #b91c1c 100%) !important;
}
"""

    # Gradio 6.0+ moved theme/css to launch()
    with gr.Blocks(title="VibeVoice ASR Demo") as demo:
        gr.Markdown("# 🎙️ VibeVoice ASR Demo")
        gr.Markdown("Upload audio files or record from microphone to get speech-to-text transcription with speaker diarization.")
        gr.Markdown(f"**Model loaded from:** `{model_path}`")

        with gr.Row():
            with gr.Column(scale=1):
                # Generation parameters
                gr.Markdown("## ⚙️ Generation Parameters")
                max_tokens_slider = gr.Slider(
                    minimum=4096,
                    maximum=65536,
                    value=default_max_tokens,
                    step=4096,
                    label="Max New Tokens"
                )

                # Sampling parameters
                gr.Markdown("### 🎲 Sampling")
                do_sample_checkbox = gr.Checkbox(
                    value=False,
                    label="Enable Sampling",
                    info="Enable random sampling instead of deterministic decoding"
                )
                # Hidden until "Enable Sampling" is ticked (see the change handler below).
                with gr.Column(visible=False) as sampling_params:
                    temperature_slider = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.0,
                        step=0.1,
                        label="Temperature",
                        info="0 = greedy, higher = more random"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.05,
                        label="Top-p (Nucleus Sampling)",
                        info="1.0 = no filtering"
                    )

                # Repetition penalty (works with both greedy and sampling)
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=1.2,
                    value=1.0,
                    step=0.01,
                    label="Repetition Penalty",
                    info="1.0 = no penalty, higher = less repetition (works with greedy & sampling)"
                )

                # Context information section
                gr.Markdown("## 📋 Context Info (Optional)")
                context_info_input = gr.Textbox(
                    label="Context Information",
                    placeholder="Enter hotwords, speaker names, topics, or other context to help transcription...\nExample:\nJohn Smith\nMachine Learning\nOpenAI",
                    lines=4,
                    max_lines=8,
                    interactive=True,
                    info="Provide context like proper nouns, technical terms, or speaker names to improve accuracy"
                )

            with gr.Column(scale=2):
                # Audio input section
                gr.Markdown("## 🎵 Audio Input")
                audio_input = gr.Audio(
                    label="Upload Audio File or Record from Microphone",
                    sources=["upload", "microphone"],
                    type="filepath",
                    interactive=True,
                    buttons=["download"]
                )

                # Video file upload (MKV, MP4, AVI, etc.) - audio will be extracted
                video_file_input = gr.File(
                    label="📹 Or Upload Video File (MKV, MP4, AVI, etc.)",
                    file_types=[".mkv", ".mp4", ".avi", ".mov", ".webm", ".flv", ".m4v", ".mpeg", ".3gp"],
                    type="filepath"
                )

                with gr.Accordion("📂 Advanced: Remote Path & Time Slicing", open=False):
                    audio_path_input = gr.Textbox(
                        label="Audio path (optional)",
                        placeholder="Enter remote audio file path",
                        lines=1
                    )
                    with gr.Row():
                        start_time_input = gr.Textbox(
                            label="Start time",
                            placeholder="e.g., 0 or 00:00:00",
                            lines=1,
                            info="Leave empty to start from the beginning"
                        )
                        end_time_input = gr.Textbox(
                            label="End time",
                            placeholder="e.g., 30.5 or 00:00:30.5",
                            lines=1,
                            info="Leave empty to use full length"
                        )

                with gr.Row():
                    transcribe_button = gr.Button("🎯 Transcribe", variant="primary", size="lg", scale=3)
                    stop_button = gr.Button("⏹️ Stop", variant="secondary", size="lg", scale=1, elem_id="stop-btn")

                # Results section
                gr.Markdown("## 📝 Results")
                with gr.Tabs():
                    with gr.TabItem("Raw Output"):
                        raw_output = gr.Textbox(
                            label="Raw Transcription Output",
                            lines=8,
                            max_lines=20,
                            interactive=False
                        )
                    with gr.TabItem("Audio Segments"):
                        audio_segments_output = gr.HTML(
                            label="Play individual segments to verify accuracy"
                        )

        # Event handlers
        do_sample_checkbox.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[do_sample_checkbox],
            outputs=[sampling_params]
        )

        def reset_stop_flag():
            """Reset stop flag before starting transcription."""
            global stop_generation_flag
            stop_generation_flag = False

        def transcribe_with_video_support(
            audio_input,
            video_file_input,
            audio_path_input,
            start_time_input,
            end_time_input,
            max_new_tokens,
            temperature,
            top_p,
            do_sample,
            repetition_penalty,
            context_info
        ):
            """Wrapper that handles both audio and video file inputs."""
            # If video file is provided, use it as the audio path
            # (an explicitly typed path still wins over the uploaded video).
            effective_audio_path = audio_path_input
            if video_file_input is not None and not audio_path_input:
                effective_audio_path = video_file_input
                print(f"[INFO] Using video file: {effective_audio_path}")
            # Call original transcribe function
            yield from transcribe_audio(
                audio_input,
                effective_audio_path,
                start_time_input,
                end_time_input,
                max_new_tokens,
                temperature,
                top_p,
                do_sample,
                repetition_penalty,
                context_info
            )

        def set_stop_flag():
            """Set stop flag to interrupt generation."""
            global stop_generation_flag
            stop_generation_flag = True
            return "⏹️ Stop requested..."

        # Clear a previous stop request first, then run the streaming transcription.
        transcribe_button.click(
            fn=reset_stop_flag,
            inputs=[],
            outputs=[],
            queue=False
        ).then(
            fn=transcribe_with_video_support,
            inputs=[
                audio_input,
                video_file_input,
                audio_path_input,
                start_time_input,
                end_time_input,
                max_tokens_slider,
                temperature_slider,
                top_p_slider,
                do_sample_checkbox,
                repetition_penalty_slider,
                context_info_input
            ],
            outputs=[raw_output, audio_segments_output]
        )
        stop_button.click(
            fn=set_stop_flag,
            inputs=[],
            outputs=[raw_output],
            queue=False
        )

        # Add examples
        gr.Markdown("## 📋 Instructions")
        gr.Markdown(f"""
1. **Upload Audio**: Use the audio component to upload a file or record from microphone
- **Supported formats**: {', '.join(sorted(set([ext.lower() for ext in COMMON_AUDIO_EXTS])))}
- Optionally set **Start/End time** (seconds or hh:mm:ss) to clip before transcription
2. **Context Info (Optional)**: Provide context to improve transcription accuracy
- Add hotwords, proper nouns, speaker names, or technical terms
- One item per line or comma-separated
- Examples: "John Smith", "OpenAI", "machine learning"
3. **Adjust Parameters**: Configure generation parameters as needed
4. **Transcribe**: Click "Transcribe" to get results
5. **Review Results**:
- **Raw Output**: View the model's original output
- **Audio Segments**: Play individual segments directly to verify accuracy
**Audio Segments**: Each segment shows the time range, speaker ID, transcribed content, and an embedded audio player for immediate verification.
""")

    return demo, custom_css
def main():
    """Parse the command line, build the demo and serve it."""
    cli = argparse.ArgumentParser(description="VibeVoice ASR Gradio Demo")

    # (flag, add_argument keyword arguments) for every supported option.
    option_specs = (
        ("--model_path", dict(
            type=str,
            default="microsoft/VibeVoice-ASR",
            help="Path to the model (HuggingFace format directory or model name)",
        )),
        ("--attn_implementation", dict(
            type=str,
            default="flash_attention_2",
            help="Attention implementation to use (default: flash_attention_2)",
        )),
        ("--max_new_tokens", dict(
            type=int,
            default=32768,
            help="Default max new tokens for generation (default: 32768)",
        )),
        ("--host", dict(
            type=str,
            default="0.0.0.0",
            help="Host to bind the server to",
        )),
        ("--port", dict(
            type=int,
            default=7860,
            help="Port to bind the server to",
        )),
        ("--share", dict(
            action="store_true",
            help="Create a public link",
        )),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    args = cli.parse_args()

    # Create and launch interface
    demo, custom_css = create_gradio_interface(
        model_path=args.model_path,
        default_max_tokens=args.max_new_tokens,
        attn_implementation=args.attn_implementation,
    )

    print("🚀 Starting VibeVoice ASR Demo...")
    print(f"📍 Server will be available at: http://{args.host}:{args.port}")

    # Enable queue for concurrent request handling
    demo.queue(default_concurrency_limit=3)

    # Gradio 6.0+ moved theme/css to launch()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True,
        theme=gr.themes.Soft(),
        css=custom_css,
    )


if __name__ == "__main__":
    main()