All checks were successful
Deploy Docusaurus Site / deploy (push) Successful in 27s
174 lines
4.5 KiB
Python
174 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
VibeVoice-ASR Test Script for DGX Spark
|
|
Tests basic functionality and GPU availability
|
|
"""
|
|
|
|
import sys
|
|
import subprocess
|
|
|
|
|
|
def test_imports():
    """Check that the ``vibevoice`` package can be imported.

    Any exception raised while importing is reported as a failure instead of
    propagating, so a package that errors at import time cannot abort the
    remaining checks with a traceback.

    Returns:
        bool: True if ``import vibevoice`` succeeded, False otherwise.
    """
    print("=" * 60)
    print("Testing VibeVoice imports...")
    print("=" * 60)

    try:
        import vibevoice  # noqa: F401 -- imported only to prove it loads
    except ImportError as e:
        print(f"[FAIL] Failed to import vibevoice: {e}")
        return False
    except Exception as e:
        # Installed but raising something other than ImportError during import
        # (consistent with the other checks, which also catch Exception).
        print(f"[FAIL] vibevoice raised during import: {type(e).__name__}: {e}")
        return False

    print("[OK] vibevoice imported successfully")
    return True
|
|
|
|
|
|
def test_torch_cuda():
    """Report PyTorch/CUDA details and run a small matmul on the GPU.

    Prints the PyTorch version and CUDA availability. When CUDA is available,
    it also prints the CUDA version and, for every visible GPU, its name,
    compute capability and total memory. It then multiplies a random
    100x100 matrix on the default CUDA device.

    Returns:
        bool: True if CUDA is available and the matmul completed. False if
        CUDA is unavailable or any step raised, including a missing torch.
    """
    print("\n" + "=" * 60)
    print("Testing PyTorch CUDA...")
    print("=" * 60)

    try:
        import torch

        print(f"[INFO] PyTorch version: {torch.__version__}")
        cuda_available = torch.cuda.is_available()
        print(f"[INFO] CUDA available: {cuda_available}")

        if not cuda_available:
            print("[WARN] CUDA not available")
            return False

        print(f"[INFO] CUDA version: {torch.version.cuda}")
        print(f"[INFO] GPU count: {torch.cuda.device_count()}")

        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            print(f"[INFO] GPU {i}: {props.name}")
            print(f" Compute capability: {props.major}.{props.minor}")
            print(f" Total memory: {props.total_memory / 1024**3:.1f} GB")

        # Quick CUDA test
        x = torch.randn(100, 100, device='cuda')
        torch.matmul(x, x)
        # CUDA kernels run asynchronously. Wait for completion so that any
        # execution error is raised here, inside this try, and "[OK]" is only
        # printed once the work has actually finished.
        torch.cuda.synchronize()
        print("[OK] CUDA tensor operations working")
        return True

    except Exception as e:
        print(f"[FAIL] PyTorch CUDA test failed: {e}")
        return False
|
|
|
|
|
|
def test_flash_attention():
    """Report whether the optional ``flash_attn`` package is importable.

    flash_attn is optional, so this check never fails the suite. The check
    distinguishes three cases:

    * the package is genuinely absent;
    * the package is present but fails to import;
    * the package imports successfully.

    Returns:
        bool: Always True (flash_attn is not required).
    """
    print("\n" + "=" * 60)
    print("Testing Flash Attention...")
    print("=" * 60)

    try:
        import flash_attn
    except Exception as e:
        # Only a ModuleNotFoundError naming flash_attn itself means "not
        # installed". A missing dependency or any other import-time error
        # means the package is present but broken, so say so instead.
        if isinstance(e, ModuleNotFoundError) and e.name == "flash_attn":
            print("[WARN] flash_attn not installed (optional)")
        else:
            print(f"[WARN] flash_attn failed to import (optional): {e}")
        return True  # Not required

    # Some builds may not define __version__; avoid crashing the suite.
    version = getattr(flash_attn, "__version__", "unknown")
    print(f"[OK] flash_attn version: {version}")
    return True
|
|
|
|
|
|
def test_ffmpeg(*, timeout=30.0):
    """Check that an ``ffmpeg`` executable on PATH runs and reports a version.

    Args:
        timeout: Seconds to wait for ``ffmpeg -version`` before giving up, so
            a wedged binary cannot hang the whole suite.

    Returns:
        bool: True if ``ffmpeg -version`` exited with status 0. Otherwise
        False: not found, not executable, timed out, or non-zero exit.
    """
    print("\n" + "=" * 60)
    print("Testing FFmpeg...")
    print("=" * 60)

    try:
        result = subprocess.run(
            ["ffmpeg", "-version"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except FileNotFoundError:
        print("[FAIL] FFmpeg not found")
        return False
    except subprocess.TimeoutExpired:
        print(f"[FAIL] FFmpeg did not respond within {timeout} s")
        return False
    except OSError as e:
        # e.g. PermissionError: the binary exists but cannot be executed.
        print(f"[FAIL] FFmpeg could not be executed: {e}")
        return False

    if result.returncode != 0:
        print(f"[FAIL] FFmpeg returned error (exit code {result.returncode})")
        if result.stderr:
            print(f"[INFO] {result.stderr.strip()}")
        return False

    # partition() is safe even when stdout is empty.
    version_line = result.stdout.partition("\n")[0]
    print(f"[OK] {version_line}")
    return True
|
|
|
|
|
|
def test_asr_model():
    """Smoke-test constructing VibeVoice's ``ASRPipeline`` when a GPU exists.

    The check is skipped when torch is missing or CUDA is unavailable. A
    missing ``ASRPipeline`` and any error while constructing it are reported
    only as warnings.

    NOTE(review): every path returns True, so load failures never fail the
    suite. Confirm this leniency is intended.

    Returns:
        bool: Always True.
    """
    print("\n" + "=" * 60)
    print("Testing ASR Model Loading...")
    print("=" * 60)

    # Import torch separately so a missing torch is not misreported as a
    # missing ASRPipeline.
    try:
        import torch
    except ImportError as e:
        print(f"[SKIP] Skipping model test - PyTorch not available: {e}")
        return True

    try:
        if not torch.cuda.is_available():
            print("[SKIP] Skipping model test - no GPU available")
            return True

        # Try to load the ASR pipeline
        from vibevoice import ASRPipeline
        print("[INFO] Loading ASR pipeline...")

        # Constructed with default arguments, i.e. whatever model
        # ASRPipeline loads by default.
        pipeline = ASRPipeline()
        print("[OK] ASR pipeline loaded successfully")

        # Release the pipeline and return cached GPU memory.
        del pipeline
        torch.cuda.empty_cache()

        return True

    except ImportError as e:
        print(f"[WARN] ASRPipeline not available: {e}")
        print("[INFO] This may be normal depending on VibeVoice version")
        return True
    except Exception as e:
        print(f"[WARN] ASR model test: {e}")
        return True
|
|
|
|
|
|
def main():
    """Run every check in turn, print a summary table and return an exit code.

    Returns:
        int: 0 when every check reported success, 1 when any did not.
    """
    print("\n")
    print("*" * 60)
    print(" VibeVoice-ASR Test Suite for DGX Spark")
    print("*" * 60)

    # Each check prints its own section; they run in this order.
    results = {
        "imports": test_imports(),
        "torch_cuda": test_torch_cuda(),
        "flash_attention": test_flash_attention(),
        "ffmpeg": test_ffmpeg(),
        "asr_model": test_asr_model(),
    }

    rule = "=" * 60
    print("\n")
    print(rule)
    print("Test Summary")
    print(rule)

    for check_name, ok in results.items():
        label = "[OK]" if ok else "[FAIL]"
        print(f" {label} {check_name}")

    print(rule)

    if not all(results.values()):
        print("\nSome tests failed.")
        return 1

    print("\nAll tests passed!")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value (0 = all passed, 1 = failures) as the
    # process exit status; equivalent to sys.exit(main()).
    raise SystemExit(main())
|