OrpheusTail v2: transformers streaming engine (replaces vLLM)
Major rework: replaced vLLM sync LLM with HuggingFace transformers + TextIteratorStreamer for true token-level streaming. Pipeline: text → format_prompt → model.generate(streamer) → extract_audio_codes (regex on streaming text) → SNAC decode → PCM. Expected first-audio latency: ~1-2s (was 10-14s with vLLM). No more monkey-patching, no more AsyncLLMEngine hangs on Jetson. SNAC model loaded separately (snac_24khz) for audio decoding. All endpoints preserved, API compatible with v1. Voice cloning endpoint now honest about LoRA requirement. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
579
main.py
579
main.py
@@ -1,24 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OrpheusTail - Orpheus TTS Service
|
||||
OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)
|
||||
|
||||
FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
|
||||
Replaces VoiceTail (Bark) with better control, voice cloning, and emotion tags.
|
||||
True token-level streaming: first audio in ~1-2s instead of 10-15s.
|
||||
|
||||
Key Features:
|
||||
- Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
|
||||
- Zero-shot voice cloning from reference audio
|
||||
- Streaming support for real-time head playback
|
||||
- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
|
||||
- SNAC codec decoding with streaming audio output
|
||||
- Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
|
||||
- Voice reference storage (for future LoRA fine-tuning)
|
||||
|
||||
Endpoints:
|
||||
- POST /tts/submit - Submit TTS job (returns job_id)
|
||||
- GET /tts/status/{job_id} - Check job status
|
||||
- GET /tts/audio/{job_id} - Download generated audio
|
||||
- POST /tts/stream - Stream audio in real-time (for head)
|
||||
- POST /voice/clone - Upload reference audio for voice cloning
|
||||
- GET /voices - List available voices
|
||||
- GET /health - Health check
|
||||
Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -29,12 +23,16 @@ import asyncio
|
||||
import uuid
|
||||
import wave
|
||||
import io
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Generator
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -44,10 +42,10 @@ ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-
|
||||
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
|
||||
CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
|
||||
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
|
||||
VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices")) # For cloned voice references
|
||||
VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))
|
||||
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
|
||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
|
||||
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") # Orpheus default voice
|
||||
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
|
||||
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
|
||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
|
||||
SAMPLE_RATE = 24000
|
||||
@@ -60,34 +58,48 @@ VOICES_DIR.mkdir(exist_ok=True)
|
||||
# Jobs persistence
|
||||
JOBS_FILE = OUTPUT_DIR / "jobs.json"
|
||||
|
||||
# Built-in Orpheus voices (in order of conversational realism per docs)
|
||||
# Built-in Orpheus voices
|
||||
BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
|
||||
|
||||
# Supported emotion tags
|
||||
EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]
|
||||
|
||||
# Orpheus special tokens
|
||||
SPECIAL_TOKEN_START = 128259
|
||||
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
|
||||
EOS_TOKEN_ID = 128258
|
||||
CODE_TOKEN_OFFSET = 128266 # audio code tokens start here
|
||||
CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio
|
||||
|
||||
# SNAC streaming parameters
|
||||
SNAC_CHUNK_SIZE = 7 # tokens per SNAC group
|
||||
SNAC_INITIAL_GROUPS = 6 # wait for N groups before first decode (higher = smoother start)
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="OrpheusTail - Orpheus TTS Service",
|
||||
description="Text-to-speech with emotion control and voice cloning for Vixy",
|
||||
version="1.0.0"
|
||||
description="Text-to-speech with emotion control and streaming for Vixy",
|
||||
version="2.0.0"
|
||||
)
|
||||
|
||||
# Global model (loaded at startup)
|
||||
model = None
|
||||
# Global model references
|
||||
llm_model = None
|
||||
llm_tokenizer = None
|
||||
snac_model = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Job System (unchanged from v1)
|
||||
# ============================================================================
|
||||
|
||||
class JobStatus(str, Enum):
    """Lifecycle states of a TTS job (serialized as plain strings in jobs.json)."""
    PENDING = "PENDING"        # submitted, generation not yet started
    PROCESSING = "PROCESSING"  # generation in progress
    SUCCESS = "SUCCESS"        # audio is ready at JobInfo.audio_path
    FAILURE = "FAILURE"        # generation failed; see JobInfo.error
||||
|
||||
|
||||
@dataclass
|
||||
class JobInfo:
|
||||
"""Job information"""
|
||||
job_id: str
|
||||
text: str
|
||||
voice: str
|
||||
@@ -99,13 +111,9 @@ class JobInfo:
|
||||
created_at: str = ""
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
|
||||
# In-memory job storage
|
||||
jobs: Dict[str, JobInfo] = {}
|
||||
|
||||
|
||||
def load_jobs_from_disk():
|
||||
"""Load jobs from disk on startup"""
|
||||
global jobs
|
||||
if JOBS_FILE.exists():
|
||||
try:
|
||||
@@ -117,9 +125,7 @@ def load_jobs_from_disk():
|
||||
except Exception as e:
|
||||
print(f"Error loading jobs: {e}")
|
||||
|
||||
|
||||
def save_jobs_to_disk():
|
||||
"""Save jobs to disk"""
|
||||
try:
|
||||
data = {job_id: asdict(job) for job_id, job in jobs.items()}
|
||||
with open(JOBS_FILE, 'w') as f:
|
||||
@@ -128,14 +134,15 @@ def save_jobs_to_disk():
|
||||
print(f"Error saving jobs: {e}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cache System (unchanged from v1)
|
||||
# ============================================================================
|
||||
|
||||
def hash_text_voice(text: str, voice: str) -> str:
    """Derive a deterministic cache key for a (text, voice) pair.

    SHA-256 over the string "text|voice", so identical requests always map
    to the same cached WAV file.
    """
    return hashlib.sha256(f"{text}|{voice}".encode()).hexdigest()
||||
|
||||
|
||||
def get_from_cache(cache_key: str) -> Optional[str]:
|
||||
"""Check if audio exists in cache"""
|
||||
if not CACHE_ENABLED:
|
||||
return None
|
||||
cache_path = CACHE_DIR / f"{cache_key}.wav"
|
||||
@@ -144,9 +151,7 @@ def get_from_cache(cache_key: str) -> Optional[str]:
|
||||
return str(cache_path)
|
||||
return None
|
||||
|
||||
|
||||
def save_to_cache(cache_key: str, audio_path: str):
|
||||
"""Save generated audio to cache"""
|
||||
if not CACHE_ENABLED:
|
||||
return
|
||||
try:
|
||||
@@ -157,145 +162,200 @@ def save_to_cache(cache_key: str, audio_path: str):
|
||||
except Exception as e:
|
||||
print(f"Error saving to cache: {e}")
|
||||
|
||||
|
||||
def get_custom_voices() -> List[str]:
    """Names of user-uploaded voice references (one .wav per voice in VOICES_DIR)."""
    names: List[str] = []
    for wav in VOICES_DIR.glob("*.wav"):
        names.append(wav.stem)
    return names
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Text Chunking
|
||||
# ============================================================================
|
||||
|
||||
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """Split text into chunks at sentence boundaries.

    Sentences are grouped greedily so no chunk exceeds ``max_chars``
    (a single sentence longer than the limit still becomes its own chunk).
    Emotion tags embedded in sentences are left intact.

    Args:
        text: Input text; leading/trailing whitespace is stripped.
        max_chars: Soft upper bound on chunk length in characters.

    Returns:
        Non-empty list of chunk strings.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())

    chunks: List[str] = []
    pending: List[str] = []
    pending_len = 0

    for sentence in sentences:
        # The +1 accounts for the joining space when the chunk already has content.
        if pending and pending_len + 1 + len(sentence) > max_chars:
            chunks.append(' '.join(pending))
            pending = []
            pending_len = 0
        pending.append(sentence)
        pending_len += len(sentence) + (1 if pending_len > 0 else 0)

    if pending:
        chunks.append(' '.join(pending))

    return chunks or [text]
|
||||
|
||||
|
||||
def chunk_text_fine(text: str, max_chars: int = 200) -> List[str]:
|
||||
"""
|
||||
Split text into fine-grained chunks for streaming — every sentence or clause.
|
||||
Smaller chunks = faster first-audio, slight quality tradeoff at boundaries.
|
||||
"""
|
||||
# Split on sentence boundaries AND commas/semicolons with reasonable length
|
||||
parts = re.split(r'(?<=[.!?;])\s+', text.strip())
|
||||
# ============================================================================
|
||||
# Inference Engine (transformers — replaces vLLM)
|
||||
# ============================================================================
|
||||
|
||||
# Further split long parts on commas
|
||||
chunks = []
|
||||
for part in parts:
|
||||
if len(part) <= max_chars:
|
||||
chunks.append(part)
|
||||
else:
|
||||
# Split on commas
|
||||
sub = re.split(r',\s+', part)
|
||||
current = []
|
||||
current_len = 0
|
||||
for s in sub:
|
||||
if current and current_len + len(s) > max_chars:
|
||||
chunks.append(', '.join(current))
|
||||
current = []
|
||||
current_len = 0
|
||||
current.append(s)
|
||||
current_len += len(s)
|
||||
if current:
|
||||
chunks.append(', '.join(current))
|
||||
def format_prompt(text: str, voice: str) -> torch.Tensor:
    """Build Orpheus input_ids for the finetuned model.

    The finetuned checkpoint expects "voice: text" wrapped by the Orpheus
    special tokens (SPECIAL_TOKEN_START before, SPECIAL_TOKENS_END after).
    Returns a (1, seq_len) long tensor moved to the LLM's device.
    """
    prompt_ids = llm_tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids

    prefix = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
    suffix = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
    return torch.cat([prefix, prompt_ids, suffix], dim=1).to(llm_model.device)
|
||||
|
||||
|
||||
def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
    """Yield decoded text pieces from the LLM as soon as they are produced.

    model.generate runs on a background thread feeding a TextIteratorStreamer;
    this generator drains the streamer on the calling thread. Special tokens
    are kept because the audio codes arrive as <custom_token_N> strings.
    """
    from transformers import TextIteratorStreamer

    input_ids = format_prompt(text, voice)
    streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False)

    worker = threading.Thread(
        target=llm_model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=MAX_TOKENS,
            temperature=0.6,
            top_p=0.8,
            repetition_penalty=1.3,
            do_sample=True,
            eos_token_id=EOS_TOKEN_ID,
            streamer=streamer,
        ),
        daemon=True,
    )
    worker.start()

    yield from streamer

    worker.join()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token → Audio Pipeline
|
||||
# ============================================================================
|
||||
|
||||
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
    """Yield SNAC code values parsed from a stream of text fragments.

    Orpheus emits audio as <custom_token_N> strings that may be split across
    streamed fragments, so unconsumed text is buffered until a complete token
    appears. Values below CODE_TOKEN_OFFSET are control tokens and dropped;
    audio codes are yielded rebased to start at 0.
    """
    pattern = re.compile(r'<custom_token_(\d+)>')
    pending = ""
    for fragment in text_stream:
        pending += fragment
        consumed = 0
        for match in pattern.finditer(pending):
            consumed = match.end()
            value = int(match.group(1))
            if value >= CODE_TOKEN_OFFSET:
                yield value - CODE_TOKEN_OFFSET
        if consumed:
            # Keep only the tail (possible partial token) for the next fragment.
            pending = pending[consumed:]
|
||||
|
||||
|
||||
def redistribute_codes(codes: list) -> list:
    """Split a flat SNAC code stream into the codec's 3 hierarchical layers.

    Each complete group of SNAC_CHUNK_SIZE (7) codes is laid out as
    [L1, L2, L2, L3, L3, L3, L3]; a trailing partial group is discarded.
    Returns [layer1, layer2, layer3].
    """
    layer1: list = []
    layer2: list = []
    layer3: list = []

    usable = (len(codes) // SNAC_CHUNK_SIZE) * SNAC_CHUNK_SIZE
    for base in range(0, usable, SNAC_CHUNK_SIZE):
        group = codes[base:base + SNAC_CHUNK_SIZE]
        layer1.append(group[0])
        layer2.extend(group[1:3])
        layer3.extend(group[3:7])

    return [layer1, layer2, layer3]
|
||||
|
||||
|
||||
def snac_decode(codes: list) -> Optional[bytes]:
    """Decode a flat list of SNAC codes into 16-bit mono PCM bytes.

    Returns None when there is not at least one complete 7-code group.
    Output is little-endian int16; intended sample rate is SAMPLE_RATE (24 kHz).
    """
    layers = redistribute_codes(codes)
    if not layers[0]:
        return None

    with torch.no_grad():
        # SNAC expects one (1, n) long tensor per layer on the model's device.
        codes_tensor = [
            torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0)
            for layer in layers
        ]
        audio_hat = snac_model.decode(codes_tensor)

    # Scale the float waveform to int16 and clamp to avoid wraparound.
    # NOTE(review): assumes SNAC output is in [-1, 1] — confirm against the codec.
    audio_np = audio_hat.squeeze().cpu().numpy()
    audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
    return audio_int16.tobytes()
|
||||
|
||||
|
||||
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
    """Convert streaming SNAC codes into PCM byte chunks.

    Buffers incoming codes and, once SNAC_INITIAL_GROUPS complete groups are
    available, re-decodes a sliding window of the most recent
    SNAC_INITIAL_GROUPS groups after every newly completed group. Only the
    tail of each decode (audio contributed by the newest group) is yielded,
    so the decoder always has left context for smooth chunk boundaries.

    NOTE(review): the "new audio" size uses an approximation (~293 samples
    per code at 24 kHz), so a few samples of overlap or gap per chunk are
    possible — confirm against SNAC's exact samples-per-code ratio.
    """
    buffer: list = []
    group_count = 0
    total_codes = 0

    for code in code_stream:
        buffer.append(code)
        total_codes += 1

        # A SNAC group completes every SNAC_CHUNK_SIZE (7) codes.
        if total_codes % SNAC_CHUNK_SIZE != 0:
            continue
        group_count += 1

        # Wait for enough context before the first decode; afterwards decode
        # once per completed group. (The original `group_count % 1 == 0`
        # guard was an always-true no-op and has been removed.)
        if group_count < SNAC_INITIAL_GROUPS:
            continue

        decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
        audio = snac_decode(buffer[-decode_size:])
        if audio:
            new_samples = SNAC_CHUNK_SIZE * 293  # ~293 samples per code at 24 kHz
            yield audio[-new_samples * 2:]       # * 2 bytes per int16 sample
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# High-level generation functions
|
||||
# ============================================================================
|
||||
|
||||
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
    """Streaming pipeline: text -> LLM tokens -> SNAC codes -> PCM chunks.

    Unknown voices silently fall back to DEFAULT_VOICE. Text is split into
    sentence chunks; each chunk is generated and decoded incrementally so the
    first PCM bytes are available long before generation finishes.
    """
    if voice not in BUILTIN_VOICES:
        voice = DEFAULT_VOICE

    for piece in chunk_text(text):
        print(f"[stream] Generating: {piece[:80]}...")
        started = time.time()
        announced = False
        pcm_chunks = decode_audio_stream(
            extract_audio_codes(generate_tokens_streaming(piece, voice))
        )
        for pcm in pcm_chunks:
            if not announced:
                print(f"[stream] First audio in {time.time() - started:.2f}s")
                announced = True
            yield pcm
        print("[stream] Done")
|
||||
|
||||
|
||||
def generate_speech_sync(text: str, voice: str) -> bytes:
    """Generate complete audio for the job system (blocking).

    Unknown voices fall back to DEFAULT_VOICE. Text is chunked at sentence
    boundaries, each chunk is run through the streaming LLM -> SNAC pipeline,
    and the concatenated PCM is wrapped in a single mono 16-bit WAV container.

    Returns:
        WAV file bytes at SAMPLE_RATE.

    Raises:
        ValueError: if the pipeline produced no audio at all.
    """
    if voice not in BUILTIN_VOICES:
        voice = DEFAULT_VOICE

    pcm_parts: list = []
    for piece in chunk_text(text):
        print(f"Generating: {piece[:80]}...")
        codes = extract_audio_codes(generate_tokens_streaming(piece, voice))
        for pcm in decode_audio_stream(codes):
            pcm_parts.append(pcm)

    if not pcm_parts:
        raise ValueError("No audio generated")

    raw_pcm = b''.join(pcm_parts)

    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wf:
        wf.setnchannels(1)           # mono
        wf.setsampwidth(2)           # 16-bit samples
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(raw_pcm)

    wav_bytes = wav_buffer.getvalue()
    print(f"Generated WAV: {len(wav_bytes)} bytes")
    return wav_bytes
|
||||
|
||||
|
||||
def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
|
||||
"""Save audio bytes to WAV file."""
|
||||
output_path = OUTPUT_DIR / f"{job_id}.wav"
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
@@ -303,16 +363,14 @@ def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
|
||||
|
||||
|
||||
async def generate_speech_background(job_id: str, text: str, voice: str):
|
||||
"""Background task for speech generation (async)."""
|
||||
"""Background task for speech generation."""
|
||||
try:
|
||||
jobs[job_id].status = JobStatus.PROCESSING
|
||||
jobs[job_id].progress = 25
|
||||
save_jobs_to_disk()
|
||||
|
||||
# Check cache first
|
||||
cache_key = hash_text_voice(text, voice)
|
||||
cached_path = get_from_cache(cache_key)
|
||||
|
||||
if cached_path:
|
||||
jobs[job_id].audio_path = cached_path
|
||||
jobs[job_id].status = JobStatus.SUCCESS
|
||||
@@ -320,34 +378,25 @@ async def generate_speech_background(job_id: str, text: str, voice: str):
|
||||
jobs[job_id].cached = True
|
||||
jobs[job_id].completed_at = datetime.now().isoformat()
|
||||
save_jobs_to_disk()
|
||||
print(f"Job {job_id} completed from cache")
|
||||
return
|
||||
|
||||
# Generate audio - call sync function directly (blocks but let's test if it works)
|
||||
jobs[job_id].progress = 50
|
||||
save_jobs_to_disk()
|
||||
|
||||
print(f"Generating audio for job {job_id}...")
|
||||
audio_bytes = generate_speech_sync(text, voice)
|
||||
audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)
|
||||
|
||||
# Save to file
|
||||
jobs[job_id].progress = 75
|
||||
save_jobs_to_disk()
|
||||
|
||||
output_path = save_audio_to_file(job_id, audio_bytes)
|
||||
|
||||
# Save to cache
|
||||
save_to_cache(cache_key, output_path)
|
||||
|
||||
# Complete
|
||||
jobs[job_id].audio_path = output_path
|
||||
jobs[job_id].status = JobStatus.SUCCESS
|
||||
jobs[job_id].progress = 100
|
||||
jobs[job_id].completed_at = datetime.now().isoformat()
|
||||
save_jobs_to_disk()
|
||||
|
||||
print(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Job {job_id} failed: {e}")
|
||||
import traceback
|
||||
@@ -358,12 +407,10 @@ async def generate_speech_background(job_id: str, text: str, voice: str):
|
||||
|
||||
|
||||
async def cleanup_old_jobs():
|
||||
"""Background task to cleanup old jobs and files."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
|
||||
cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)
|
||||
|
||||
to_delete = []
|
||||
for job_id, job in jobs.items():
|
||||
try:
|
||||
@@ -374,76 +421,50 @@ async def cleanup_old_jobs():
|
||||
to_delete.append(job_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
for job_id in to_delete:
|
||||
del jobs[job_id]
|
||||
|
||||
if to_delete:
|
||||
save_jobs_to_disk()
|
||||
print(f"Cleanup: deleted {len(to_delete)} old jobs")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in cleanup task: {e}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Startup
|
||||
# ============================================================================
|
||||
|
||||
@app.on_event("startup")
async def startup():
    """Load the LLM, the SNAC codec, and persisted jobs; start the cleanup loop."""
    global llm_model, llm_tokenizer, snac_model

    print("=" * 60)
    print("OrpheusTail v2 — transformers streaming engine")
    print(f"Model: {ORPHEUS_MODEL}")
    print(f"Max Model Len: {MAX_MODEL_LEN}")
    print(f"Max Tokens: {MAX_TOKENS}")
    print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
    print(f"Default Voice: {DEFAULT_VOICE}")
    print("=" * 60)

    # Load the Orpheus language model via transformers (vLLM's AsyncLLMEngine
    # hung on Jetson; plain generate + TextIteratorStreamer streams natively).
    print("Loading Orpheus LLM...")
    from transformers import AutoModelForCausalLM, AutoTokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
    llm_model = AutoModelForCausalLM.from_pretrained(
        ORPHEUS_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    llm_model.eval()
    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")

    # Load the SNAC codec that turns Orpheus audio codes into waveforms.
    print("Loading SNAC audio codec...")
    from snac import SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.to("cuda")
    print("✓ SNAC loaded")

    # Restore job bookkeeping from disk.
    load_jobs_from_disk()

    # Start cleanup task
    asyncio.create_task(cleanup_old_jobs())

    print("✓ OrpheusTail v2 ready")
|
||||
|
||||
# === Pydantic Models ===
|
||||
|
||||
# ============================================================================
|
||||
# Pydantic Models
|
||||
# ============================================================================
|
||||
|
||||
class TTSRequest(BaseModel):
    """TTS job submission request"""
    text: str                    # text to speak; may include emotion tags like <laugh>
    voice: str = DEFAULT_VOICE   # voice name; unknown names fall back to the default
||||
|
||||
|
||||
class TTSStreamRequest(BaseModel):
    """TTS streaming request (for head playback)"""
    text: str                    # text to speak; may include emotion tags
    voice: str = DEFAULT_VOICE   # voice name; unknown names fall back to the default
||||
|
||||
|
||||
class JobResponse(BaseModel):
    """Job submission response"""
    job_id: str   # UUID to poll via /tts/status/{job_id}
    status: str   # initial JobStatus value (PENDING)
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""Job status response"""
|
||||
job_id: str
|
||||
status: str
|
||||
progress: int
|
||||
@@ -481,23 +499,23 @@ class StatusResponse(BaseModel):
|
||||
audio_url: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class VoicesResponse(BaseModel):
    """Available voices response"""
    builtin: List[str]        # the 8 Orpheus built-in voices
    custom: List[str]         # uploaded reference voices (not yet synthesizable)
    default: str              # voice used when none/unknown is requested
    emotion_tags: List[str]   # tags accepted inline in request text
||||
|
||||
|
||||
# === Endpoints ===
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/")
def root():
    """Service metadata and a directory of available endpoints."""
    return {
        "service": "OrpheusTail - Orpheus TTS Service",
        "version": "2.0.0",
        "engine": "transformers (streaming)",
        "model": ORPHEUS_MODEL,
        "default_voice": DEFAULT_VOICE,
        "emotion_tags": EMOTION_TAGS,
        "endpoints": {
            "/tts/submit": "POST - Submit TTS job",
            "/tts/status/{job_id}": "GET - Check job status",
            "/tts/audio/{job_id}": "GET - Download audio",
            "/tts/stream": "POST - Stream audio (true token-level)",
            "/voice/clone": "POST - Upload voice reference",
            "/voices": "GET - List available voices",
            "/health": "GET - Health check"
        }
    }
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Health probe: reports model load state, cache setting, and voice count."""
    custom_count = len(get_custom_voices())
    return {
        "status": "healthy",
        "model_loaded": llm_model is not None and snac_model is not None,
        "engine": "transformers",
        "cache_enabled": CACHE_ENABLED,
        "voices_available": len(BUILTIN_VOICES) + custom_count
    }
|
||||
|
||||
|
||||
@app.get("/voices", response_model=VoicesResponse)
|
||||
def list_voices():
|
||||
"""List all available voices"""
|
||||
return VoicesResponse(
|
||||
builtin=BUILTIN_VOICES,
|
||||
custom=get_custom_voices(),
|
||||
@@ -535,197 +550,113 @@ def list_voices():
|
||||
)
|
||||
|
||||
|
||||
# --- TTS endpoints ---
|
||||
|
||||
@app.post("/tts/submit", response_model=JobResponse)
async def submit_tts_job(request: TTSRequest):
    """Queue a TTS job and return its id immediately.

    Generation runs in a background asyncio task; clients poll
    /tts/status/{job_id} and download from /tts/audio/{job_id}.
    """
    job_id = str(uuid.uuid4())

    jobs[job_id] = JobInfo(
        job_id=job_id, text=request.text, voice=request.voice,
        status=JobStatus.PENDING, progress=0,
        created_at=datetime.now().isoformat()
    )
    save_jobs_to_disk()

    asyncio.create_task(generate_speech_background(job_id, request.text, request.voice))
    return JobResponse(job_id=job_id, status=JobStatus.PENDING)
|
||||
|
||||
|
||||
@app.get("/tts/status/{job_id}", response_model=StatusResponse)
async def get_job_status(job_id: str):
    """Report progress for a submitted TTS job (404 when unknown)."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    response = StatusResponse(
        job_id=job_id, status=job.status, progress=job.progress, cached=job.cached)
    if job.status == JobStatus.SUCCESS:
        response.audio_url = f"/tts/audio/{job_id}"
    elif job.status == JobStatus.FAILURE:
        response.error = job.error

    return response
|
||||
|
||||
|
||||
@app.get("/tts/audio/{job_id}")
async def get_audio(job_id: str):
    """Download the finished WAV for a job.

    404 for unknown jobs or missing files; 400 while the job is not SUCCESS.
    """
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    job = jobs[job_id]

    if job.status != JobStatus.SUCCESS:
        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")
    if not job.audio_path or not Path(job.audio_path).exists():
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")
|
||||
|
||||
|
||||
@app.post("/tts/stream")
async def stream_tts(request: TTSStreamRequest):
    """
    Stream TTS audio with true token-level streaming.
    First audio arrives after ~28 tokens (~1-2 seconds).

    The body is raw 16-bit mono PCM at 24 kHz (media type audio/pcm),
    not a WAV container.
    """
    if llm_model is None or snac_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # generate_stream validates the voice and falls back to DEFAULT_VOICE;
    # the generator itself is handed straight to the streaming response.
    return StreamingResponse(
        generate_stream(request.text, request.voice),
        media_type="audio/pcm",
    )
|
||||
|
||||
# --- Voice endpoints ---
|
||||
|
||||
@app.post("/voice/clone")
async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
    """Store reference audio for a custom voice (kept for future LoRA fine-tuning).

    The finetuned Orpheus checkpoint only speaks its 8 built-in voices, so the
    upload is saved but cannot be synthesized yet.
    """
    if not name.isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")

    voice_path = VOICES_DIR / f"{name}.wav"

    try:
        content = await audio.read()
        with open(voice_path, 'wb') as f:
            f.write(content)

        return {
            "status": "saved",
            "voice_name": name,
            "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
                       f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
            "builtin_voices": BUILTIN_VOICES,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}")
|
||||
|
||||
|
||||
@app.delete("/voice/{name}")
async def delete_voice(name: str):
    """Remove a custom voice's reference audio (built-in voices are protected)."""
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot delete built-in voice")

    voice_path = VOICES_DIR / f"{name}.wav"
    if not voice_path.exists():
        raise HTTPException(status_code=404, detail="Voice not found")

    voice_path.unlink()
    return {"status": "success", "message": f"Voice '{name}' deleted"}
|
||||
|
||||
|
||||
@app.delete("/tts/job/{job_id}")
async def delete_job(job_id: str):
    """Delete a job record and its generated audio file.

    The file unlink is best-effort: a missing or locked file does not block
    removal of the job record. (The original bare `except:` swallowed every
    exception, including KeyboardInterrupt; narrowed to OSError.)
    """
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    job = jobs[job_id]

    if job.audio_path:
        try:
            Path(job.audio_path).unlink()
        except OSError:
            # File already gone (e.g. removed by the cleanup task) — fine.
            pass

    del jobs[job_id]
    save_jobs_to_disk()

    return {"message": f"Job {job_id} deleted"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Port 8766 matches VoiceTail so this service is a drop-in replacement.
    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)
|
||||
|
||||
Reference in New Issue
Block a user