Files
orpheus-tts/main.py
Alex 57a2e24101 Fix SNAC decoding: correct token offset + device attribute
- CODE_TOKEN_OFFSET is 10 in decoded text (not 128266 in token ID space)
  because tokenizer.decode() maps 128266 → <custom_token_10>
- Fixed 'SNAC object has no attribute device' — use explicit SNAC_DEVICE
- Added debug logging for pipeline visibility
- Audio now generates correctly: 442KB for "Hello world"

True streaming pipeline verified: text → TextIteratorStreamer →
regex extraction → SNAC decode → PCM bytes. The bottleneck is
Jetson inference speed (~12s for first 42 tokens on a 3B model),
not the streaming infrastructure.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 16:41:14 -05:00

696 lines
24 KiB
Python

#!/usr/bin/env python3
"""
OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)
FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
True token-level streaming: first audio in ~1-2s instead of 10-15s.
Key Features:
- Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
- SNAC codec decoding with streaming audio output
- Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
- Voice reference storage (for future LoRA fine-tuning)
Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
"""
import os
import re
import json
import hashlib
import asyncio
import uuid
import wave
import io
import threading
import time
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Generator
from dataclasses import dataclass, asdict
from enum import Enum
import torch
from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
# Configuration from environment
ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-prod")  # HF model id
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"  # reuse wavs for repeated text+voice
CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))  # uploaded voice-clone references
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))  # job + audio retention before cleanup
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))  # NOTE(review): unused in this module — presumably a vLLM leftover; confirm
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))  # max_new_tokens per generate() call
SAMPLE_RATE = 24000  # SNAC 24 kHz codec output rate (mono, 16-bit PCM downstream)
# Ensure directories exist
CACHE_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
VOICES_DIR.mkdir(exist_ok=True)
# Jobs persistence
JOBS_FILE = OUTPUT_DIR / "jobs.json"
# Built-in Orpheus voices
BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
# Supported emotion tags
EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]
# Orpheus special tokens (token-id space)
SPECIAL_TOKEN_START = 128259
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
EOS_TOKEN_ID = 128258
# When decoded by tokenizer, audio codes appear as <custom_token_N> where N = token_id - 128256
# Audio codes start at token_id 128266, which decodes as <custom_token_10>
CODE_TOKEN_OFFSET = 10  # in decoded text space (token_id 128266 → custom_token_10)
CODE_REMOVE_TOKEN_ID = 128258  # this token signals end, not audio (NOTE(review): unused in this module)
# SNAC streaming parameters
SNAC_CHUNK_SIZE = 7  # tokens per SNAC group
SNAC_INITIAL_GROUPS = 6  # wait for N groups before first decode (higher = smoother start)
# Initialize FastAPI
app = FastAPI(
    title="OrpheusTail - Orpheus TTS Service",
    description="Text-to-speech with emotion control and streaming for Vixy",
    version="2.0.0"
)
# Global model references — populated once in the startup() hook
llm_model = None
llm_tokenizer = None
snac_model = None
# ============================================================================
# Job System (unchanged from v1)
# ============================================================================
class JobStatus(str, Enum):
    """Lifecycle states of a TTS job (str-valued so it JSON-serializes cleanly)."""
    PENDING = "PENDING"
    PROCESSING = "PROCESSING"
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"
@dataclass
class JobInfo:
    """Persistent record of one TTS job (round-tripped to jobs.json via asdict)."""
    job_id: str
    text: str
    voice: str
    status: JobStatus
    progress: int = 0                   # coarse 0-100 progress
    audio_path: Optional[str] = None    # path to the finished wav once SUCCESS
    error: Optional[str] = None         # populated on FAILURE
    cached: bool = False                # True when the result came from the cache
    created_at: str = ""                # ISO-8601 timestamp
    completed_at: Optional[str] = None  # ISO-8601 timestamp, set on terminal states

# In-memory job registry, mirrored to JOBS_FILE on every state change.
jobs: Dict[str, JobInfo] = {}
def load_jobs_from_disk():
    """Restore the in-memory job registry from JOBS_FILE, if it exists."""
    global jobs
    if not JOBS_FILE.exists():
        return
    try:
        with open(JOBS_FILE, 'r') as fh:
            stored = json.load(fh)
            for key, record in stored.items():
                jobs[key] = JobInfo(**record)
        print(f"Loaded {len(jobs)} jobs from disk")
    except Exception as exc:
        # Corrupt or unreadable jobs file: start with whatever loaded so far.
        print(f"Error loading jobs: {exc}")
def save_jobs_to_disk():
    """Mirror the in-memory job registry to JOBS_FILE as pretty-printed JSON."""
    try:
        serializable = {key: asdict(info) for key, info in jobs.items()}
        with open(JOBS_FILE, 'w') as fh:
            json.dump(serializable, fh, indent=2)
    except Exception as exc:
        # Persistence is best-effort; a failed save must not kill the request.
        print(f"Error saving jobs: {exc}")
# ============================================================================
# Cache System (unchanged from v1)
# ============================================================================
def hash_text_voice(text: str, voice: str) -> str:
    """Deterministic cache key: SHA-256 hex digest over "text|voice"."""
    return hashlib.sha256(f"{text}|{voice}".encode()).hexdigest()
def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the path of a cached wav for this key, or None on miss/disabled."""
    if not CACHE_ENABLED:
        return None
    candidate = CACHE_DIR / f"{cache_key}.wav"
    if not candidate.exists():
        return None
    print(f"Cache hit: {cache_key}")
    return str(candidate)
def save_to_cache(cache_key: str, audio_path: str):
    """Best-effort copy of a finished wav into the cache directory."""
    if not CACHE_ENABLED:
        return
    try:
        import shutil
        destination = CACHE_DIR / f"{cache_key}.wav"
        shutil.copy(audio_path, destination)
        print(f"Saved to cache: {cache_key}")
    except Exception as exc:
        # Cache population must never fail the generation itself.
        print(f"Error saving to cache: {exc}")
def get_custom_voices() -> List[str]:
    """Names of uploaded reference voices (one per .wav in VOICES_DIR)."""
    return [wav.stem for wav in VOICES_DIR.glob("*.wav")]
# ============================================================================
# Text Chunking
# ============================================================================
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """Split text into chunks at sentence boundaries.

    Sentences are greedily packed into chunks of at most max_chars (counting
    the joining spaces); a single sentence longer than max_chars becomes its
    own oversized chunk rather than being split mid-sentence.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks: List[str] = []
    pending: List[str] = []
    pending_len = 0
    for sentence in sentences:
        # Flush the current chunk when adding this sentence would overflow.
        if pending and pending_len + 1 + len(sentence) > max_chars:
            chunks.append(' '.join(pending))
            pending, pending_len = [], 0
        pending.append(sentence)
        pending_len += len(sentence) + (1 if pending_len > 0 else 0)
    if pending:
        chunks.append(' '.join(pending))
    return chunks or [text]
# ============================================================================
# Inference Engine (transformers — replaces vLLM)
# ============================================================================
def format_prompt(text: str, voice: str) -> torch.Tensor:
    """Build Orpheus input_ids: [start marker] + "voice: text" + end markers.

    Returns a (1, seq_len) long tensor on the LLM's device.
    """
    prompt_ids = llm_tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids
    prefix = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
    suffix = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
    return torch.cat([prefix, prompt_ids, suffix], dim=1).to(llm_model.device)
def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
    """Yield decoded text fragments as the LLM produces them.

    model.generate runs on a daemon thread while this generator drains the
    TextIteratorStreamer, so fragments arrive token-by-token instead of after
    the full generation completes.
    """
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False)
    worker = threading.Thread(
        target=llm_model.generate,
        kwargs=dict(
            input_ids=format_prompt(text, voice),
            max_new_tokens=MAX_TOKENS,
            temperature=0.6,
            top_p=0.8,
            repetition_penalty=1.3,
            do_sample=True,
            eos_token_id=EOS_TOKEN_ID,
            streamer=streamer,
        ),
        daemon=True,
    )
    worker.start()
    yield from streamer
    worker.join()
# ============================================================================
# Token → Audio Pipeline
# ============================================================================
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
    """Pull SNAC audio codes out of the streamed decoded text.

    Fragments are accumulated so a <custom_token_N> marker split across two
    chunks is still matched. For each marker, the decoded id minus the base
    offset (CODE_TOKEN_OFFSET) and the position-in-group offset (pos * 4096)
    gives the per-layer code; only codes landing in 0-4095 are yielded.
    """
    marker = re.compile(r'<custom_token_(\d+)>')
    pending = ""
    count = 0
    skipped = 0
    for fragment in text_stream:
        pending += fragment
        match = marker.search(pending)
        while match:
            decoded_id = int(match.group(1))
            # Drop everything up to and including the matched marker.
            pending = pending[match.end():]
            if decoded_id >= CODE_TOKEN_OFFSET:
                position = count % 7
                code = decoded_id - CODE_TOKEN_OFFSET - (position * 4096)
                if count < 14:
                    print(f"[codes] custom_token_{decoded_id} pos={position} code={code}")
                if 0 <= code < 4096:
                    count += 1
                    yield code
                else:
                    # Out-of-range code: keep the group position advancing.
                    skipped += 1
                    count += 1
            match = marker.search(pending)
    print(f"[codes] Total: {count} extracted, {skipped} skipped")
def redistribute_codes(codes: list) -> list:
    """Regroup a flat code list into SNAC's 3 hierarchical layers.

    Each complete group of SNAC_CHUNK_SIZE (7) codes maps to
    [L1, L2, L2, L3, L3, L3, L3]; a trailing partial group is discarded.
    Codes are assumed already offset-corrected into the 0-4095 range.
    """
    coarse: list = []
    mid: list = []
    fine: list = []
    full_groups = len(codes) // SNAC_CHUNK_SIZE
    for g in range(full_groups):
        group = codes[g * SNAC_CHUNK_SIZE:(g + 1) * SNAC_CHUNK_SIZE]
        coarse.append(group[0])
        mid.extend(group[1:3])
        fine.extend(group[3:7])
    return [coarse, mid, fine]
# Device for the SNAC vocoder (the SNAC object itself exposes no .device attribute).
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def snac_decode(codes: list) -> Optional[bytes]:
    """Decode offset-corrected SNAC codes into 16-bit mono PCM bytes.

    Returns None when there is no complete group to decode or decoding fails.
    """
    coarse, mid, fine = redistribute_codes(codes)
    if not coarse:
        return None
    try:
        with torch.no_grad():
            tensors = [
                torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
                for layer in (coarse, mid, fine)
            ]
            waveform = snac_model.decode(tensors)
            samples = waveform.squeeze().cpu().numpy()
            # Scale float audio to int16, clamping to the representable range.
            pcm16 = (samples * 32767).clip(-32768, 32767).astype(np.int16)
        return pcm16.tobytes()
    except Exception as exc:
        print(f"[snac] Decode error: {exc}")
        return None
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
    """Convert streaming SNAC codes to PCM audio chunks.

    Decodes a sliding window of the last SNAC_INITIAL_GROUPS groups each time
    a new 7-code group completes, so every decode has left context for smooth
    audio. BUG FIX: the previous version yielded the ENTIRE decoded window on
    every new group, so each group's audio was emitted up to
    SNAC_INITIAL_GROUPS times (duplicated/stuttering output). Now the first
    decode emits the whole window and each subsequent decode emits only the
    newest group's worth of samples.
    """
    buffer = []
    total_codes = 0
    emitted_first_window = False
    for code in code_stream:
        buffer.append(code)
        total_codes += 1
        if total_codes == 1:
            print(f"[snac] First code received: {code}")
        # Only decode on a complete group boundary.
        if total_codes % SNAC_CHUNK_SIZE != 0:
            continue
        groups = total_codes // SNAC_CHUNK_SIZE
        if groups < SNAC_INITIAL_GROUPS:
            continue
        decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
        codes_to_decode = buffer[-decode_size:]
        if groups == SNAC_INITIAL_GROUPS:
            print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}")
        audio = snac_decode(codes_to_decode)
        if audio:
            if not emitted_first_window:
                # First decode: emit all SNAC_INITIAL_GROUPS groups of audio.
                yield audio
                emitted_first_window = True
            else:
                # The window advanced by exactly one group since the last
                # decode, so only the trailing 1/SNAC_INITIAL_GROUPS slice of
                # the decoded audio is new; derive its size from the decode
                # itself rather than hard-coding samples-per-group.
                per_group = len(audio) // SNAC_INITIAL_GROUPS
                yield audio[-per_group:]
        elif groups == SNAC_INITIAL_GROUPS:
            print(f"[snac] WARNING: decode returned None")
    print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}")
# ============================================================================
# High-level generation functions
# ============================================================================
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
    """Full streaming pipeline: text → tokens → audio chunks."""
    selected = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    total_bytes = 0
    for piece in chunk_text(text):
        print(f"[stream] Generating: {piece[:80]}...")
        started = time.time()
        waiting_for_first = True
        try:
            codes = extract_audio_codes(generate_tokens_streaming(piece, selected))
            for pcm in decode_audio_stream(codes):
                if waiting_for_first:
                    print(f"[stream] First audio in {time.time() - started:.2f}s ({len(pcm)} bytes)")
                    waiting_for_first = False
                total_bytes += len(pcm)
                yield pcm
        except Exception as exc:
            # Keep streaming alive across per-chunk failures; log and move on.
            print(f"[stream] ERROR in pipeline: {exc}")
            import traceback
            traceback.print_exc()
    print(f"[stream] Done — {total_bytes} bytes total")
def generate_speech_sync(text: str, voice: str) -> bytes:
    """Generate complete audio (for job system). Returns WAV bytes.

    Raises ValueError when the pipeline produced no audio at all.
    """
    selected = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    pcm_parts: List[bytes] = []
    for piece in chunk_text(text):
        print(f"Generating: {piece[:80]}...")
        codes = extract_audio_codes(generate_tokens_streaming(piece, selected))
        pcm_parts.extend(decode_audio_stream(codes))
    if not pcm_parts:
        raise ValueError("No audio generated")
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wf:
        wf.setnchannels(1)            # mono
        wf.setsampwidth(2)            # 16-bit samples
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(b''.join(pcm_parts))
    return wav_buffer.getvalue()
def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
    """Write finished WAV bytes to OUTPUT_DIR/<job_id>.wav and return the path."""
    destination = OUTPUT_DIR / f"{job_id}.wav"
    destination.write_bytes(audio_bytes)
    return str(destination)
async def generate_speech_background(job_id: str, text: str, voice: str):
    """Background task: run generation for one job, updating status/progress.

    Terminal states: SUCCESS (with audio_path set) or FAILURE (with error).
    The registry is flushed to disk after every state change.
    """
    try:
        job = jobs[job_id]
        job.status = JobStatus.PROCESSING
        job.progress = 25
        save_jobs_to_disk()
        cache_key = hash_text_voice(text, voice)
        cached_path = get_from_cache(cache_key)
        if cached_path:
            # Cache hit — finish immediately without touching the model.
            job.audio_path = cached_path
            job.status = JobStatus.SUCCESS
            job.progress = 100
            job.cached = True
            job.completed_at = datetime.now().isoformat()
            save_jobs_to_disk()
            return
        job.progress = 50
        save_jobs_to_disk()
        # Run the blocking generation off the event loop.
        audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)
        job.progress = 75
        save_jobs_to_disk()
        output_path = save_audio_to_file(job_id, audio_bytes)
        save_to_cache(cache_key, output_path)
        job.audio_path = output_path
        job.status = JobStatus.SUCCESS
        job.progress = 100
        job.completed_at = datetime.now().isoformat()
        save_jobs_to_disk()
    except Exception as exc:
        print(f"Job {job_id} failed: {exc}")
        import traceback
        traceback.print_exc()
        jobs[job_id].status = JobStatus.FAILURE
        jobs[job_id].error = str(exc)
        save_jobs_to_disk()
async def cleanup_old_jobs():
    """Periodic task: delete jobs (and their audio files) older than RETENTION_DAYS.

    Runs forever, sleeping CLEANUP_INTERVAL_HOURS between passes. Each job is
    handled best-effort: malformed timestamps or filesystem errors skip that
    job without aborting the pass.
    """
    while True:
        try:
            await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)
            to_delete = []
            for job_id, job in jobs.items():
                try:
                    created = datetime.fromisoformat(job.created_at)
                    if created < cutoff:
                        if job.audio_path and Path(job.audio_path).exists():
                            Path(job.audio_path).unlink()
                        to_delete.append(job_id)
                except Exception:
                    # Fix: was a bare `except:` — that also swallows
                    # BaseException (KeyboardInterrupt/SystemExit).
                    pass
            for job_id in to_delete:
                del jobs[job_id]
            if to_delete:
                save_jobs_to_disk()
                print(f"Cleanup: deleted {len(to_delete)} old jobs")
        except Exception as e:
            # Never let one failed pass kill the loop.
            print(f"Error in cleanup task: {e}")
# ============================================================================
# Startup
# ============================================================================
@app.on_event("startup")
async def startup():
    """Load the Orpheus LLM and SNAC vocoder, restore jobs, start cleanup loop."""
    global llm_model, llm_tokenizer, snac_model
    print("=" * 60)
    print("OrpheusTail v2 — transformers streaming engine")
    print(f"Model: {ORPHEUS_MODEL}")
    print(f"Max Tokens: {MAX_TOKENS}")
    print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
    print(f"Default Voice: {DEFAULT_VOICE}")
    print("=" * 60)
    # Load LLM
    print("Loading Orpheus LLM...")
    from transformers import AutoModelForCausalLM, AutoTokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
    llm_model = AutoModelForCausalLM.from_pretrained(
        ORPHEUS_MODEL,
        torch_dtype=torch.bfloat16,  # halve memory vs fp32
        device_map="auto",           # place on GPU automatically when available
    )
    llm_model.eval()
    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")
    # Load SNAC decoder (24 kHz codec matching SAMPLE_RATE)
    print("Loading SNAC audio codec...")
    from snac import SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.to("cuda")  # mirrors the SNAC_DEVICE selection
    print("✓ SNAC loaded")
    # Load jobs from disk
    load_jobs_from_disk()
    # Start cleanup task
    asyncio.create_task(cleanup_old_jobs())
    print("✓ OrpheusTail v2 ready")
# ============================================================================
# Pydantic Models
# ============================================================================
class TTSRequest(BaseModel):
    """Request body for /tts/submit."""
    text: str
    voice: str = DEFAULT_VOICE

class TTSStreamRequest(BaseModel):
    """Request body for /tts/stream."""
    text: str
    voice: str = DEFAULT_VOICE

class JobResponse(BaseModel):
    """Minimal acknowledgement returned when a job is queued."""
    job_id: str
    status: str

class StatusResponse(BaseModel):
    """Job progress report; audio_url set on success, error set on failure."""
    job_id: str
    status: str
    progress: int
    cached: bool = False
    audio_url: Optional[str] = None
    error: Optional[str] = None

class VoicesResponse(BaseModel):
    """Catalog of built-in/custom voices and supported emotion tags."""
    builtin: List[str]
    custom: List[str]
    default: str
    emotion_tags: List[str]
# ============================================================================
# Endpoints
# ============================================================================
@app.get("/")
def root():
    """Service metadata plus a human-readable map of the API endpoints."""
    endpoint_map = {
        "/tts/submit": "POST - Submit TTS job",
        "/tts/status/{job_id}": "GET - Check job status",
        "/tts/audio/{job_id}": "GET - Download audio",
        "/tts/stream": "POST - Stream audio (true token-level)",
        "/voice/clone": "POST - Upload voice reference",
        "/voices": "GET - List available voices",
        "/health": "GET - Health check",
    }
    return {
        "service": "OrpheusTail - Orpheus TTS Service",
        "version": "2.0.0",
        "engine": "transformers (streaming)",
        "model": ORPHEUS_MODEL,
        "default_voice": DEFAULT_VOICE,
        "emotion_tags": EMOTION_TAGS,
        "endpoints": endpoint_map,
    }
@app.get("/health")
def health():
    """Liveness/readiness probe: model state, cache flag, voice count."""
    models_ready = llm_model is not None and snac_model is not None
    voice_count = len(BUILTIN_VOICES) + len(get_custom_voices())
    return {
        "status": "healthy",
        "model_loaded": models_ready,
        "engine": "transformers",
        "cache_enabled": CACHE_ENABLED,
        "voices_available": voice_count,
    }
@app.get("/voices", response_model=VoicesResponse)
def list_voices():
    """Enumerate built-in voices, uploaded custom voices, and emotion tags."""
    custom_voices = get_custom_voices()
    return VoicesResponse(
        builtin=BUILTIN_VOICES,
        custom=custom_voices,
        default=DEFAULT_VOICE,
        emotion_tags=EMOTION_TAGS,
    )
# --- TTS endpoints ---
@app.post("/tts/submit", response_model=JobResponse)
async def submit_tts_job(request: TTSRequest):
    """Queue a TTS job and return its id immediately; generation runs async."""
    new_id = str(uuid.uuid4())
    jobs[new_id] = JobInfo(
        job_id=new_id,
        text=request.text,
        voice=request.voice,
        status=JobStatus.PENDING,
        progress=0,
        created_at=datetime.now().isoformat(),
    )
    save_jobs_to_disk()
    # Fire-and-forget: the background task owns all further status updates.
    asyncio.create_task(generate_speech_background(new_id, request.text, request.voice))
    return JobResponse(job_id=new_id, status=JobStatus.PENDING)
@app.get("/tts/status/{job_id}", response_model=StatusResponse)
async def get_job_status(job_id: str):
    """Report a job's progress; audio_url on success, error detail on failure."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    response = StatusResponse(
        job_id=job_id,
        status=job.status,
        progress=job.progress,
        cached=job.cached,
    )
    if job.status == JobStatus.SUCCESS:
        response.audio_url = f"/tts/audio/{job_id}"
    elif job.status == JobStatus.FAILURE:
        response.error = job.error
    return response
@app.get("/tts/audio/{job_id}")
async def get_audio(job_id: str):
    """Download the finished wav for a successful job."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.status != JobStatus.SUCCESS:
        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")
    if not (job.audio_path and Path(job.audio_path).exists()):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")
@app.post("/tts/stream")
async def stream_tts(request: TTSStreamRequest):
    """
    Stream TTS audio with true token-level streaming.
    First audio arrives after ~28 tokens (~1-2 seconds).
    """
    if llm_model is None or snac_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # generate_stream is itself a generator of PCM chunks — pass it directly.
    return StreamingResponse(
        generate_stream(request.text, request.voice),
        media_type="audio/pcm",
    )
# --- Voice endpoints ---
@app.post("/voice/clone")
async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
    """Upload reference audio for voice cloning (saved for future LoRA fine-tuning)."""
    if not name.isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")
    destination = VOICES_DIR / f"{name}.wav"
    try:
        payload = await audio.read()
        destination.write_bytes(payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to save voice: {exc}")
    return {
        "status": "saved",
        "voice_name": name,
        "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
        f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
        "builtin_voices": BUILTIN_VOICES,
    }
@app.delete("/voice/{name}")
async def delete_voice(name: str):
    """Remove an uploaded custom voice; built-ins are protected."""
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot delete built-in voice")
    target = VOICES_DIR / f"{name}.wav"
    if not target.exists():
        raise HTTPException(status_code=404, detail="Voice not found")
    target.unlink()
    return {"status": "success", "message": f"Voice '{name}' deleted"}
@app.delete("/tts/job/{job_id}")
async def delete_job(job_id: str):
    """Delete a job record and (best-effort) its audio file."""
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    job = jobs[job_id]
    if job.audio_path and Path(job.audio_path).exists():
        try:
            Path(job.audio_path).unlink()
        except OSError:
            # Fix: was a bare `except:` that swallowed every exception type.
            # File removal is best-effort; the record is deleted regardless.
            pass
    del jobs[job_id]
    save_jobs_to_disk()
    return {"message": f"Job {job_id} deleted"}
if __name__ == "__main__":
    # Standalone entry point; reload disabled because model loading is too
    # heavy to restart on every source edit.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)