OrpheusTail v2: transformers streaming engine (replaces vLLM)

Major rework: replaced vLLM sync LLM with HuggingFace transformers
+ TextIteratorStreamer for true token-level streaming.
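
The core pattern, as a minimal standalone sketch (gpt2 here is just a
placeholder; the service uses ORPHEUS_MODEL plus the sampling params below):

    # generate() blocks, so it runs on a worker thread; the caller drains
    # the streamer and sees text the moment each token is decoded.
    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True)
    Thread(target=model.generate,
           kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer)).start()
    for piece in streamer:
        print(piece, end="", flush=True)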

Pipeline: text → format_prompt → model.generate(streamer) →
extract_audio_codes (regex on streaming text) → SNAC decode → PCM
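
Client-side consumption, for reference (a sketch; assumes the service is on
localhost:8766, and that the stream is raw 16-bit mono PCM at 24 kHz):

    # Hypothetical client: collect streamed PCM into a WAV file. A real
    # head would feed chunks straight to the audio device instead.
    import wave
    import requests

    resp = requests.post(
        "http://localhost:8766/tts/stream",
        json={"text": "Hello there <chuckle> how are you?", "voice": "tara"},
        stream=True,
    )
    resp.raise_for_status()
    with wave.open("out.wav", "wb") as wf:
        wf.setnchannels(1)      # mono
        wf.setsampwidth(2)      # 16-bit
        wf.setframerate(24000)  # SAMPLE_RATE
        for chunk in resp.iter_content(chunk_size=4096):
            wf.writeframes(chunk)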

Expected first-audio latency: ~1-2s (was 10-14s with vLLM).
No more monkey-patching, no more AsyncLLMEngine hangs on Jetson.

SNAC model loaded separately (snac_24khz) for audio decoding.
All endpoints preserved, API compatible with v1.
Voice cloning endpoint now honest about LoRA requirement.
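
Sanity check for the separately loaded SNAC decoder (a sketch, assuming the
snac package; the 7-codes-per-frame layout splits 1/2/4 across its 3 layers):

    # Hypothetical smoke test: decode one all-zero frame per layer.
    import torch
    from snac import SNAC

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    codes = [
        torch.zeros(1, 1, dtype=torch.long),  # layer 1: 1 code per frame
        torch.zeros(1, 2, dtype=torch.long),  # layer 2: 2 codes per frame
        torch.zeros(1, 4, dtype=torch.long),  # layer 3: 4 codes per frame
    ]
    with torch.no_grad():
        audio = snac.decode(codes)
    print(audio.shape)  # expect about (1, 1, 2048): one frame ~= 2048 samples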

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Alex committed 2026-04-13 08:38:30 -05:00
parent cfc9b1a5a0
commit d650fd06b9

main.py (577 lines changed)

@@ -1,24 +1,18 @@
 #!/usr/bin/env python3
 """
-OrpheusTail - Orpheus TTS Service
+OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)
 FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
-Replaces VoiceTail (Bark) with better control, voice cloning, and emotion tags.
+True token-level streaming: first audio in ~1-2s instead of 10-15s.
 Key Features:
 - Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
-- Zero-shot voice cloning from reference audio
-- Streaming support for real-time head playback
+- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
+- SNAC codec decoding with streaming audio output
 - Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
-- Voice reference storage (for future LoRA fine-tuning)
-Endpoints:
-- POST /tts/submit - Submit TTS job (returns job_id)
-- GET /tts/status/{job_id} - Check job status
-- GET /tts/audio/{job_id} - Download generated audio
-- POST /tts/stream - Stream audio in real-time (for head)
-- POST /voice/clone - Upload reference audio for voice cloning
-- GET /voices - List available voices
-- GET /health - Health check
+Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
 """
 import os
@@ -29,12 +23,16 @@ import asyncio
 import uuid
 import wave
 import io
+import threading
+import time
+import numpy as np
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Generator
 from dataclasses import dataclass, asdict
 from enum import Enum
+import torch
 from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
 from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
@@ -44,10 +42,10 @@ ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-
 CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
 CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
 OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
-VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))  # For cloned voice references
+VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))
 RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
 CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
-DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")  # Orpheus default voice
+DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
 MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
 MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
 SAMPLE_RATE = 24000
@@ -60,34 +58,48 @@ VOICES_DIR.mkdir(exist_ok=True)
 # Jobs persistence
 JOBS_FILE = OUTPUT_DIR / "jobs.json"
-# Built-in Orpheus voices (in order of conversational realism per docs)
+# Built-in Orpheus voices
 BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
 # Supported emotion tags
 EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]
+# Orpheus special tokens
+SPECIAL_TOKEN_START = 128259
+SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
+EOS_TOKEN_ID = 128258
+CODE_TOKEN_OFFSET = 128266  # audio code tokens start here
+CODE_REMOVE_TOKEN_ID = 128258  # this token signals end, not audio
+# SNAC streaming parameters
+SNAC_CHUNK_SIZE = 7  # tokens per SNAC group
+SNAC_INITIAL_GROUPS = 6  # wait for N groups before first decode (higher = smoother start)
 # Initialize FastAPI
 app = FastAPI(
     title="OrpheusTail - Orpheus TTS Service",
-    description="Text-to-speech with emotion control and voice cloning for Vixy",
-    version="1.0.0"
+    description="Text-to-speech with emotion control and streaming for Vixy",
+    version="2.0.0"
 )
-# Global model (loaded at startup)
-model = None
+# Global model references
+llm_model = None
+llm_tokenizer = None
+snac_model = None
+# ============================================================================
+# Job System (unchanged from v1)
+# ============================================================================
 class JobStatus(str, Enum):
-    """Job status enum"""
     PENDING = "PENDING"
     PROCESSING = "PROCESSING"
     SUCCESS = "SUCCESS"
     FAILURE = "FAILURE"
 @dataclass
 class JobInfo:
-    """Job information"""
     job_id: str
     text: str
     voice: str
@@ -99,13 +111,9 @@ class JobInfo:
created_at: str = "" created_at: str = ""
completed_at: Optional[str] = None completed_at: Optional[str] = None
# In-memory job storage
jobs: Dict[str, JobInfo] = {} jobs: Dict[str, JobInfo] = {}
def load_jobs_from_disk(): def load_jobs_from_disk():
"""Load jobs from disk on startup"""
global jobs global jobs
if JOBS_FILE.exists(): if JOBS_FILE.exists():
try: try:
@@ -117,9 +125,7 @@ def load_jobs_from_disk():
         except Exception as e:
             print(f"Error loading jobs: {e}")
 def save_jobs_to_disk():
-    """Save jobs to disk"""
     try:
         data = {job_id: asdict(job) for job_id, job in jobs.items()}
         with open(JOBS_FILE, 'w') as f:
@@ -128,14 +134,15 @@ def save_jobs_to_disk():
print(f"Error saving jobs: {e}") print(f"Error saving jobs: {e}")
# ============================================================================
# Cache System (unchanged from v1)
# ============================================================================
def hash_text_voice(text: str, voice: str) -> str: def hash_text_voice(text: str, voice: str) -> str:
"""Generate cache key from text + voice"""
content = f"{text}|{voice}" content = f"{text}|{voice}"
return hashlib.sha256(content.encode()).hexdigest() return hashlib.sha256(content.encode()).hexdigest()
def get_from_cache(cache_key: str) -> Optional[str]: def get_from_cache(cache_key: str) -> Optional[str]:
"""Check if audio exists in cache"""
if not CACHE_ENABLED: if not CACHE_ENABLED:
return None return None
cache_path = CACHE_DIR / f"{cache_key}.wav" cache_path = CACHE_DIR / f"{cache_key}.wav"
@@ -144,9 +151,7 @@ def get_from_cache(cache_key: str) -> Optional[str]:
         return str(cache_path)
     return None
 def save_to_cache(cache_key: str, audio_path: str):
-    """Save generated audio to cache"""
     if not CACHE_ENABLED:
         return
     try:
@@ -157,145 +162,200 @@ def save_to_cache(cache_key: str, audio_path: str):
     except Exception as e:
         print(f"Error saving to cache: {e}")
 def get_custom_voices() -> List[str]:
-    """Get list of custom cloned voices"""
-    voices = []
-    for voice_file in VOICES_DIR.glob("*.wav"):
-        voices.append(voice_file.stem)
-    return voices
+    return [f.stem for f in VOICES_DIR.glob("*.wav")]
+# ============================================================================
+# Text Chunking
+# ============================================================================
 def chunk_text(text: str, max_chars: int = 800) -> List[str]:
-    """
-    Split text into chunks at sentence boundaries for sequential TTS generation.
-    Ensures no chunk exceeds max_chars (unless a single sentence is longer).
-    Preserves emotion tags within sentences.
-    """
-    # Split on sentence-ending punctuation followed by whitespace
+    """Split text into chunks at sentence boundaries."""
     sentences = re.split(r'(?<=[.!?])\s+', text.strip())
     chunks = []
     current_chunk = []
     current_len = 0
     for sentence in sentences:
         sentence_len = len(sentence)
-        # If adding this sentence would exceed the limit, finalize current chunk
         if current_chunk and current_len + 1 + sentence_len > max_chars:
             chunks.append(' '.join(current_chunk))
             current_chunk = []
             current_len = 0
         current_chunk.append(sentence)
         current_len += sentence_len + (1 if current_len > 0 else 0)
-    # Don't forget the last chunk
     if current_chunk:
         chunks.append(' '.join(current_chunk))
     return chunks if chunks else [text]
-def chunk_text_fine(text: str, max_chars: int = 200) -> List[str]:
-    """
-    Split text into fine-grained chunks for streaming — every sentence or clause.
-    Smaller chunks = faster first-audio, slight quality tradeoff at boundaries.
-    """
-    # Split on sentence boundaries AND commas/semicolons with reasonable length
-    parts = re.split(r'(?<=[.!?;])\s+', text.strip())
-    # Further split long parts on commas
-    chunks = []
-    for part in parts:
-        if len(part) <= max_chars:
-            chunks.append(part)
-        else:
-            # Split on commas
-            sub = re.split(r',\s+', part)
-            current = []
-            current_len = 0
-            for s in sub:
-                if current and current_len + len(s) > max_chars:
-                    chunks.append(', '.join(current))
-                    current = []
-                    current_len = 0
-                current.append(s)
-                current_len += len(s)
-            if current:
-                chunks.append(', '.join(current))
-    # Filter empty chunks
-    return [c.strip() for c in chunks if c.strip()]
+# ============================================================================
+# Inference Engine (transformers — replaces vLLM)
+# ============================================================================
+def format_prompt(text: str, voice: str) -> torch.Tensor:
+    """Format prompt for Orpheus finetuned model. Returns input_ids tensor."""
+    full_text = f"{voice}: {text}"
+    input_ids = llm_tokenizer(full_text, return_tensors="pt").input_ids
+    # Wrap with Orpheus special tokens
+    start = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
+    end = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
+    return torch.cat([start, input_ids, end], dim=1).to(llm_model.device)
+def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
+    """Stream token strings as the model generates them."""
+    from transformers import TextIteratorStreamer
+    input_ids = format_prompt(text, voice)
+    # skip_prompt=True: stream only newly generated tokens, not the echoed prompt
+    streamer = TextIteratorStreamer(llm_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    gen_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=MAX_TOKENS,
+        temperature=0.6,
+        top_p=0.8,
+        repetition_penalty=1.3,
+        do_sample=True,
+        eos_token_id=EOS_TOKEN_ID,
+        streamer=streamer,
+    )
+    # generate() blocks, so run it on a worker thread and drain the streamer here
+    thread = threading.Thread(target=llm_model.generate, kwargs=gen_kwargs, daemon=True)
+    thread.start()
+    for text_chunk in streamer:
+        yield text_chunk
+    thread.join()
+# ============================================================================
+# Token → Audio Pipeline
+# ============================================================================
+def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
+    """Extract SNAC audio codes from streamed text. Handles partial tokens across chunks."""
+    buffer = ""
+    for chunk in text_stream:
+        buffer += chunk
+        # Extract all complete <custom_token_XXXX> patterns
+        while True:
+            match = re.search(r'<custom_token_(\d+)>', buffer)
+            if not match:
+                break
+            token_id = int(match.group(1))
+            buffer = buffer[match.end():]
+            # Only yield actual audio codes (>= CODE_TOKEN_OFFSET)
+            if token_id >= CODE_TOKEN_OFFSET:
+                yield token_id - CODE_TOKEN_OFFSET
+def redistribute_codes(codes: list) -> list:
+    """Redistribute flat code list into SNAC's 3 hierarchical layers.
+    Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]"""
+    layer1, layer2, layer3 = [], [], []
+    for i in range(0, len(codes), SNAC_CHUNK_SIZE):
+        group = codes[i:i + SNAC_CHUNK_SIZE]
+        if len(group) < SNAC_CHUNK_SIZE:
+            break  # drop a trailing partial group; SNAC needs whole frames
+        layer1.append(group[0])
+        layer2.extend(group[1:3])
+        layer3.extend(group[3:7])
+    return [layer1, layer2, layer3]
+def snac_decode(codes: list) -> Optional[bytes]:
+    """Decode SNAC codes to PCM audio bytes."""
+    layers = redistribute_codes(codes)
+    if not layers[0]:
+        return None
+    device = next(snac_model.parameters()).device  # nn.Module has no .device attribute
+    with torch.no_grad():
+        codes_tensor = [
+            torch.tensor(layer, device=device, dtype=torch.long).unsqueeze(0)
+            for layer in layers
+        ]
+        audio_hat = snac_model.decode(codes_tensor)
+        audio_np = audio_hat.squeeze().cpu().numpy()
+        audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
+        return audio_int16.tobytes()
+def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
+    """Convert streaming SNAC codes to PCM audio chunks."""
+    buffer = []
+    group_count = 0
+    total_codes = 0
+    # Each group of 7 codes produces 2048 samples at 24kHz; 2 bytes per 16-bit sample
+    group_bytes = 2048 * 2
+    for code in code_stream:
+        buffer.append(code)
+        total_codes += 1
+        if total_codes % SNAC_CHUNK_SIZE == 0:
+            group_count += 1
+            if group_count >= SNAC_INITIAL_GROUPS:
+                # Decode a sliding window of the last N groups for context
+                decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
+                audio = snac_decode(buffer[-decode_size:])
+                if audio:
+                    if group_count == SNAC_INITIAL_GROUPS:
+                        # First window: emit everything decoded so far
+                        yield audio
+                    else:
+                        # Later windows: yield only the NEW audio (avoid overlap)
+                        yield audio[-group_bytes:]
+    # Flush short utterances that never filled the initial window
+    if 0 < group_count < SNAC_INITIAL_GROUPS:
+        audio = snac_decode(buffer)
+        if audio:
+            yield audio
+# ============================================================================
+# High-level generation functions
+# ============================================================================
+def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
+    """Full streaming pipeline: text → tokens → audio chunks."""
+    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
+    for chunk in chunk_text(text):
+        print(f"[stream] Generating: {chunk[:80]}...")
+        t0 = time.time()
+        first_audio = True
+        text_stream = generate_tokens_streaming(chunk, voice)
+        code_stream = extract_audio_codes(text_stream)
+        for pcm_chunk in decode_audio_stream(code_stream):
+            if first_audio:
+                print(f"[stream] First audio in {time.time() - t0:.2f}s")
+                first_audio = False
+            yield pcm_chunk
+    print("[stream] Done")
 def generate_speech_sync(text: str, voice: str) -> bytes:
-    """
-    Generate speech using Orpheus model (synchronous).
-    Args:
-        text: Text to convert (may include emotion tags)
-        voice: Voice name (built-in or custom)
-    Returns:
-        WAV audio bytes
-    """
-    global model
-    import numpy as np
-    # Check if it's a custom voice (needs reference audio)
-    custom_voice_path = VOICES_DIR / f"{voice}.wav"
-    if custom_voice_path.exists():
-        print(f"Custom voice '{voice}' - voice cloning to be implemented")
-        voice = DEFAULT_VOICE
-    elif voice not in BUILTIN_VOICES:
-        print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
-        voice = DEFAULT_VOICE
-    # Split long text into chunks for sequential generation
-    text_chunks = chunk_text(text)
-    print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
+    """Generate complete audio (for job system). Returns WAV bytes."""
+    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
     all_pcm = []
-    for chunk_idx, chunk in enumerate(text_chunks):
-        print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-        audio_chunks = []
-        syn_tokens = model.generate_speech(
-            prompt=chunk,
-            voice=voice,
-            max_tokens=MAX_TOKENS,
-        )
-        for i, audio_chunk in enumerate(syn_tokens):
-            audio_chunks.append(audio_chunk)
-        print(f"  -> {len(audio_chunks)} audio frames")
-        if audio_chunks:
-            all_pcm.append(b''.join(audio_chunks))
+    for chunk in chunk_text(text):
+        print(f"Generating: {chunk[:80]}...")
+        text_stream = generate_tokens_streaming(chunk, voice)
+        code_stream = extract_audio_codes(text_stream)
+        for pcm_chunk in decode_audio_stream(code_stream):
+            all_pcm.append(pcm_chunk)
     if not all_pcm:
-        raise ValueError("No audio chunks generated")
-    # Concatenate raw PCM from all chunks, wrap in single WAV
+        raise ValueError("No audio generated")
     audio_bytes_raw = b''.join(all_pcm)
     buffer = io.BytesIO()
     with wave.open(buffer, 'wb') as wf:
         wf.setnchannels(1)
-        wf.setsampwidth(2)  # 16-bit
+        wf.setsampwidth(2)
         wf.setframerate(SAMPLE_RATE)
         wf.writeframes(audio_bytes_raw)
-    print(f"Generated WAV: {len(buffer.getvalue())} bytes")
     return buffer.getvalue()
 def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
-    """Save audio bytes to WAV file."""
     output_path = OUTPUT_DIR / f"{job_id}.wav"
     with open(output_path, 'wb') as f:
         f.write(audio_bytes)
@@ -303,16 +363,14 @@ def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
 async def generate_speech_background(job_id: str, text: str, voice: str):
-    """Background task for speech generation (async)."""
+    """Background task for speech generation."""
     try:
         jobs[job_id].status = JobStatus.PROCESSING
         jobs[job_id].progress = 25
         save_jobs_to_disk()
-        # Check cache first
         cache_key = hash_text_voice(text, voice)
         cached_path = get_from_cache(cache_key)
         if cached_path:
             jobs[job_id].audio_path = cached_path
             jobs[job_id].status = JobStatus.SUCCESS
@@ -320,34 +378,25 @@ async def generate_speech_background(job_id: str, text: str, voice: str):
             jobs[job_id].cached = True
             jobs[job_id].completed_at = datetime.now().isoformat()
             save_jobs_to_disk()
-            print(f"Job {job_id} completed from cache")
             return
-        # Generate audio - call sync function directly (blocks but let's test if it works)
         jobs[job_id].progress = 50
         save_jobs_to_disk()
-        print(f"Generating audio for job {job_id}...")
-        audio_bytes = generate_speech_sync(text, voice)
-        # Save to file
+        audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)
         jobs[job_id].progress = 75
         save_jobs_to_disk()
         output_path = save_audio_to_file(job_id, audio_bytes)
-        # Save to cache
         save_to_cache(cache_key, output_path)
-        # Complete
         jobs[job_id].audio_path = output_path
         jobs[job_id].status = JobStatus.SUCCESS
         jobs[job_id].progress = 100
         jobs[job_id].completed_at = datetime.now().isoformat()
         save_jobs_to_disk()
-        print(f"Job {job_id} completed successfully")
     except Exception as e:
         print(f"Job {job_id} failed: {e}")
         import traceback
@@ -358,12 +407,10 @@ async def generate_speech_background(job_id: str, text: str, voice: str):
 async def cleanup_old_jobs():
-    """Background task to cleanup old jobs and files."""
     while True:
         try:
             await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
             cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)
             to_delete = []
             for job_id, job in jobs.items():
                 try:
@@ -374,76 +421,50 @@ async def cleanup_old_jobs():
                     to_delete.append(job_id)
                 except:
                     pass
             for job_id in to_delete:
                 del jobs[job_id]
             if to_delete:
                 save_jobs_to_disk()
                 print(f"Cleanup: deleted {len(to_delete)} old jobs")
         except Exception as e:
             print(f"Error in cleanup task: {e}")
+# ============================================================================
+# Startup
+# ============================================================================
 @app.on_event("startup")
 async def startup():
-    """Load model and jobs on startup"""
-    global model
+    global llm_model, llm_tokenizer, snac_model
     print("=" * 60)
-    print("OrpheusTail - Orpheus TTS Service Starting")
+    print("OrpheusTail v2 — transformers streaming engine")
     print(f"Model: {ORPHEUS_MODEL}")
-    print(f"Max Model Len: {MAX_MODEL_LEN}")
     print(f"Max Tokens: {MAX_TOKENS}")
     print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
     print(f"Default Voice: {DEFAULT_VOICE}")
     print("=" * 60)
-    # Import and load Orpheus model
-    print("Loading Orpheus model (this may take a moment)...")
-    from orpheus_tts import OrpheusModel
-    from vllm import AsyncLLMEngine
-    from vllm.engine.arg_utils import AsyncEngineArgs
-    # Monkey-patch OrpheusModel to use sync LLM (AsyncLLMEngine hangs on Jetson)
-    original_setup_engine = OrpheusModel._setup_engine
-    def patched_setup_engine(self):
-        model_name = self._map_model_params(self.model_name)
-        from vllm import LLM
-        return LLM(
-            model=model_name,
-            max_model_len=MAX_MODEL_LEN,
-            gpu_memory_utilization=0.85,
-            enforce_eager=False,
-        )
-    OrpheusModel._setup_engine = patched_setup_engine
-    # Sync token generation
-    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
-                                     temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
-                                     stop_token_ids=[49158], repetition_penalty=1.3):
-        from vllm import SamplingParams
-        import re
-        prompt_string = self._format_prompt(prompt, voice)
-        print(prompt)
-        sampling_params = SamplingParams(
-            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
-            stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
-        )
-        outputs = self.engine.generate([prompt_string], sampling_params)
-        for output in outputs:
-            text = output.outputs[0].text
-            print(f"Raw output (first 500 chars): {text[:500]}")
-            tokens = re.findall(r'<custom_token_\d+>', text)
-            print(f"Found {len(tokens)} tokens")
-            for token in tokens:
-                yield token
-    OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
-    model = OrpheusModel(model_name=ORPHEUS_MODEL)
-    print("✓ Orpheus model loaded successfully")
+    # Load LLM
+    print("Loading Orpheus LLM...")
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
+    llm_model = AutoModelForCausalLM.from_pretrained(
+        ORPHEUS_MODEL,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    llm_model.eval()
+    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")
+    # Load SNAC decoder
+    print("Loading SNAC audio codec...")
+    from snac import SNAC
+    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
+    if torch.cuda.is_available():
+        snac_model = snac_model.to("cuda")
+    print("✓ SNAC loaded")
     # Load jobs from disk
     load_jobs_from_disk()
@@ -451,29 +472,26 @@ async def startup():
     # Start cleanup task
     asyncio.create_task(cleanup_old_jobs())
+    print("✓ OrpheusTail v2 ready")
-# === Pydantic Models ===
+# ============================================================================
+# Pydantic Models
+# ============================================================================
 class TTSRequest(BaseModel):
-    """TTS job submission request"""
     text: str
     voice: str = DEFAULT_VOICE
 class TTSStreamRequest(BaseModel):
-    """TTS streaming request (for head playback)"""
     text: str
     voice: str = DEFAULT_VOICE
 class JobResponse(BaseModel):
-    """Job submission response"""
     job_id: str
     status: str
 class StatusResponse(BaseModel):
-    """Job status response"""
     job_id: str
     status: str
     progress: int
@@ -481,23 +499,23 @@ class StatusResponse(BaseModel):
     audio_url: Optional[str] = None
     error: Optional[str] = None
 class VoicesResponse(BaseModel):
-    """Available voices response"""
     builtin: List[str]
     custom: List[str]
     default: str
     emotion_tags: List[str]
-# === Endpoints ===
+# ============================================================================
+# Endpoints
+# ============================================================================
 @app.get("/")
 def root():
-    """Root endpoint"""
     return {
         "service": "OrpheusTail - Orpheus TTS Service",
-        "version": "1.0.0",
+        "version": "2.0.0",
+        "engine": "transformers (streaming)",
         "model": ORPHEUS_MODEL,
         "default_voice": DEFAULT_VOICE,
         "emotion_tags": EMOTION_TAGS,
@@ -505,28 +523,25 @@ def root():
"/tts/submit": "POST - Submit TTS job", "/tts/submit": "POST - Submit TTS job",
"/tts/status/{job_id}": "GET - Check job status", "/tts/status/{job_id}": "GET - Check job status",
"/tts/audio/{job_id}": "GET - Download audio", "/tts/audio/{job_id}": "GET - Download audio",
"/tts/stream": "POST - Stream audio (for head)", "/tts/stream": "POST - Stream audio (true token-level)",
"/voice/clone": "POST - Upload voice reference", "/voice/clone": "POST - Upload voice reference",
"/voices": "GET - List available voices", "/voices": "GET - List available voices",
"/health": "GET - Health check" "/health": "GET - Health check"
} }
} }
@app.get("/health") @app.get("/health")
def health(): def health():
"""Health check"""
return { return {
"status": "healthy", "status": "healthy",
"model_loaded": model is not None, "model_loaded": llm_model is not None and snac_model is not None,
"engine": "transformers",
"cache_enabled": CACHE_ENABLED, "cache_enabled": CACHE_ENABLED,
"voices_available": len(BUILTIN_VOICES) + len(get_custom_voices()) "voices_available": len(BUILTIN_VOICES) + len(get_custom_voices())
} }
@app.get("/voices", response_model=VoicesResponse) @app.get("/voices", response_model=VoicesResponse)
def list_voices(): def list_voices():
"""List all available voices"""
return VoicesResponse( return VoicesResponse(
builtin=BUILTIN_VOICES, builtin=BUILTIN_VOICES,
custom=get_custom_voices(), custom=get_custom_voices(),
@@ -535,197 +550,113 @@ def list_voices():
     )
-# --- TTS endpoints ---
 @app.post("/tts/submit", response_model=JobResponse)
 async def submit_tts_job(request: TTSRequest):
-    """Submit a TTS job for processing."""
     job_id = str(uuid.uuid4())
     job = JobInfo(
-        job_id=job_id,
-        text=request.text,
-        voice=request.voice,
-        status=JobStatus.PENDING,
-        progress=0,
+        job_id=job_id, text=request.text, voice=request.voice,
+        status=JobStatus.PENDING, progress=0,
         created_at=datetime.now().isoformat()
     )
     jobs[job_id] = job
     save_jobs_to_disk()
-    # Use asyncio.create_task for proper async execution
-    asyncio.create_task(
-        generate_speech_background(job_id, request.text, request.voice)
-    )
-    print(f"Job {job_id} submitted: '{request.text[:50]}...' with voice '{request.voice}'")
+    asyncio.create_task(generate_speech_background(job_id, request.text, request.voice))
     return JobResponse(job_id=job_id, status=JobStatus.PENDING)
 @app.get("/tts/status/{job_id}", response_model=StatusResponse)
 async def get_job_status(job_id: str):
-    """Get status of a TTS job."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")
     job = jobs[job_id]
     response = StatusResponse(
-        job_id=job_id,
-        status=job.status,
-        progress=job.progress,
-        cached=job.cached
-    )
+        job_id=job_id, status=job.status, progress=job.progress, cached=job.cached)
     if job.status == JobStatus.SUCCESS:
         response.audio_url = f"/tts/audio/{job_id}"
     elif job.status == JobStatus.FAILURE:
         response.error = job.error
     return response
 @app.get("/tts/audio/{job_id}")
 async def get_audio(job_id: str):
-    """Retrieve generated audio file."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")
     job = jobs[job_id]
     if job.status != JobStatus.SUCCESS:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Audio not ready. Job status: {job.status}"
-        )
+        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")
     if not job.audio_path or not Path(job.audio_path).exists():
         raise HTTPException(status_code=404, detail="Audio file not found")
-    return FileResponse(
-        job.audio_path,
-        media_type="audio/wav",
-        filename=f"{job_id}.wav"
-    )
+    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")
@app.post("/tts/stream") @app.post("/tts/stream")
async def stream_tts(request: TTSStreamRequest): async def stream_tts(request: TTSStreamRequest):
""" """
Stream TTS audio with sentence-level chunking. Stream TTS audio with true token-level streaming.
First audio arrives after ~28 tokens (~1-2 seconds).
Splits text into small chunks (sentences/clauses) and generates each
independently. First chunk's audio starts playing while later chunks
are still generating. Reduces perceived latency significantly.
""" """
global model if llm_model is None or snac_model is None:
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded") raise HTTPException(status_code=503, detail="Model not loaded")
voice = request.voice def audio_generator():
if voice not in BUILTIN_VOICES: for pcm_chunk in generate_stream(request.text, request.voice):
voice = DEFAULT_VOICE yield pcm_chunk
def sync_audio_generator(): return StreamingResponse(audio_generator(), media_type="audio/pcm")
"""Generate audio per-sentence, yielding as each finishes."""
try:
# Split into fine-grained chunks for faster first-audio
text_chunks = chunk_text_fine(request.text)
print(f"[stream] {len(text_chunks)} chunk(s): {[c[:40] for c in text_chunks]}")
for chunk_idx, chunk in enumerate(text_chunks):
print(f" Generating chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:60]}...")
syn_tokens = model.generate_speech(
prompt=chunk,
voice=voice,
max_tokens=MAX_TOKENS,
)
for audio_chunk in syn_tokens:
yield audio_chunk
except Exception as e:
print(f"Stream error: {e}")
raise
return StreamingResponse(
sync_audio_generator(),
media_type="audio/pcm"
)
-# --- Voice endpoints ---
 @app.post("/voice/clone")
-async def upload_voice_reference(
-    name: str,
-    audio: UploadFile = File(...),
-):
-    """
-    Upload a reference audio file for voice cloning.
-    Args:
-        name: Name for this custom voice
-        audio: WAV audio file (5-30 seconds recommended)
-    """
+async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
+    """Upload reference audio for voice cloning (saved for future LoRA fine-tuning)."""
     if not name.isalnum():
         raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")
     if name in BUILTIN_VOICES:
         raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")
-    # Save the reference audio
     voice_path = VOICES_DIR / f"{name}.wav"
     try:
         content = await audio.read()
         with open(voice_path, 'wb') as f:
             f.write(content)
         return {
-            "status": "success",
+            "status": "saved",
             "voice_name": name,
-            "message": f"Voice '{name}' saved. Use voice='{name}' in TTS requests."
+            "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
+                       f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
+            "builtin_voices": BUILTIN_VOICES,
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}")
 @app.delete("/voice/{name}")
 async def delete_voice(name: str):
-    """Delete a custom voice."""
     if name in BUILTIN_VOICES:
         raise HTTPException(status_code=400, detail="Cannot delete built-in voice")
     voice_path = VOICES_DIR / f"{name}.wav"
     if not voice_path.exists():
         raise HTTPException(status_code=404, detail="Voice not found")
     voice_path.unlink()
     return {"status": "success", "message": f"Voice '{name}' deleted"}
 @app.delete("/tts/job/{job_id}")
 async def delete_job(job_id: str):
-    """Delete a job and its audio file."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")
     job = jobs[job_id]
     if job.audio_path and Path(job.audio_path).exists():
         try:
             Path(job.audio_path).unlink()
         except:
             pass
     del jobs[job_id]
     save_jobs_to_disk()
     return {"message": f"Job {job_id} deleted"}
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=8766,  # Same port as VoiceTail for drop-in replacement
-        reload=False
-    )
+    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)