#!/usr/bin/env python3 """ OrpheusTail - Orpheus TTS Service (v2 — transformers streaming) FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin. True token-level streaming: first audio in ~1-2s instead of 10-15s. Key Features: - Emotion tags: , , , , , , , - True token-level streaming via HuggingFace transformers + TextIteratorStreamer - SNAC codec decoding with streaming audio output - Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe - Voice reference storage (for future LoRA fine-tuning) Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively) """ import os import re import json import hashlib import asyncio import uuid import wave import io import threading import time import numpy as np from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Generator from dataclasses import dataclass, asdict from enum import Enum import torch from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel # Configuration from environment ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-prod") CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true" CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache")) OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output")) VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices")) RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10")) CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1")) DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192")) MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000")) SAMPLE_RATE = 24000 # Ensure directories exist CACHE_DIR.mkdir(exist_ok=True) OUTPUT_DIR.mkdir(exist_ok=True) VOICES_DIR.mkdir(exist_ok=True) # Jobs persistence JOBS_FILE = OUTPUT_DIR / "jobs.json" # Built-in Orpheus voices BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"] # Supported emotion tags EMOTION_TAGS = ["", "", "", "", "", "", "", ""] # Orpheus special tokens SPECIAL_TOKEN_START = 128259 SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257] EOS_TOKEN_ID = 128258 CODE_TOKEN_OFFSET = 128266 # audio code tokens start here CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio # SNAC streaming parameters SNAC_CHUNK_SIZE = 7 # tokens per SNAC group SNAC_INITIAL_GROUPS = 6 # wait for N groups before first decode (higher = smoother start) # Initialize FastAPI app = FastAPI( title="OrpheusTail - Orpheus TTS Service", description="Text-to-speech with emotion control and streaming for Vixy", version="2.0.0" ) # Global model references llm_model = None llm_tokenizer = None snac_model = None # ============================================================================ # Job System (unchanged from v1) # ============================================================================ class JobStatus(str, Enum): PENDING = "PENDING" PROCESSING = "PROCESSING" SUCCESS = "SUCCESS" FAILURE = "FAILURE" @dataclass class JobInfo: job_id: str text: str voice: str status: JobStatus progress: int = 0 audio_path: Optional[str] = None error: Optional[str] = None cached: bool = False created_at: str = "" completed_at: Optional[str] = None jobs: Dict[str, JobInfo] = {} def load_jobs_from_disk(): global jobs if JOBS_FILE.exists(): try: with open(JOBS_FILE, 'r') as f: data = json.load(f) for job_id, job_dict in data.items(): jobs[job_id] = JobInfo(**job_dict) print(f"Loaded {len(jobs)} jobs from disk") except Exception as e: print(f"Error loading jobs: {e}") def save_jobs_to_disk(): try: data = {job_id: asdict(job) for job_id, job in jobs.items()} with open(JOBS_FILE, 'w') as f: json.dump(data, f, indent=2) except Exception as e: print(f"Error saving jobs: {e}") # ============================================================================ # Cache System (unchanged from v1) # ============================================================================ def hash_text_voice(text: str, voice: str) -> str: content = f"{text}|{voice}" return hashlib.sha256(content.encode()).hexdigest() def get_from_cache(cache_key: str) -> Optional[str]: if not CACHE_ENABLED: return None cache_path = CACHE_DIR / f"{cache_key}.wav" if cache_path.exists(): print(f"Cache hit: {cache_key}") return str(cache_path) return None def save_to_cache(cache_key: str, audio_path: str): if not CACHE_ENABLED: return try: import shutil cache_path = CACHE_DIR / f"{cache_key}.wav" shutil.copy(audio_path, cache_path) print(f"Saved to cache: {cache_key}") except Exception as e: print(f"Error saving to cache: {e}") def get_custom_voices() -> List[str]: return [f.stem for f in VOICES_DIR.glob("*.wav")] # ============================================================================ # Text Chunking # ============================================================================ def chunk_text(text: str, max_chars: int = 800) -> List[str]: """Split text into chunks at sentence boundaries.""" sentences = re.split(r'(?<=[.!?])\s+', text.strip()) chunks = [] current_chunk = [] current_len = 0 for sentence in sentences: sentence_len = len(sentence) if current_chunk and current_len + 1 + sentence_len > max_chars: chunks.append(' '.join(current_chunk)) current_chunk = [] current_len = 0 current_chunk.append(sentence) current_len += sentence_len + (1 if current_len > 0 else 0) if current_chunk: chunks.append(' '.join(current_chunk)) return chunks if chunks else [text] # ============================================================================ # Inference Engine (transformers — replaces vLLM) # ============================================================================ def format_prompt(text: str, voice: str) -> torch.Tensor: """Format prompt for Orpheus finetuned model. Returns input_ids tensor.""" full_text = f"{voice}: {text}" input_ids = llm_tokenizer(full_text, return_tensors="pt").input_ids # Wrap with Orpheus special tokens start = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long) end = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long) return torch.cat([start, input_ids, end], dim=1).to(llm_model.device) def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]: """Stream token strings as the model generates them.""" from transformers import TextIteratorStreamer input_ids = format_prompt(text, voice) streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False) gen_kwargs = dict( input_ids=input_ids, max_new_tokens=MAX_TOKENS, temperature=0.6, top_p=0.8, repetition_penalty=1.3, do_sample=True, eos_token_id=EOS_TOKEN_ID, streamer=streamer, ) thread = threading.Thread(target=llm_model.generate, kwargs=gen_kwargs, daemon=True) thread.start() for text_chunk in streamer: yield text_chunk thread.join() # ============================================================================ # Token → Audio Pipeline # ============================================================================ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]: """Extract SNAC audio codes from streamed text. Handles partial tokens across chunks. Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets.""" buffer = "" count = 0 for chunk in text_stream: buffer += chunk while True: match = re.search(r'', buffer) if not match: break token_id = int(match.group(1)) buffer = buffer[match.end():] if token_id >= CODE_TOKEN_OFFSET: # Subtract base offset + per-position layer offset (4096 per layer) # Position in group of 7: determines which SNAC layer pos_in_group = count % 7 code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096) if 0 <= code < 4096: count += 1 yield code else: # Out of range — skip but still count position count += 1 def redistribute_codes(codes: list) -> list: """Redistribute flat code list into SNAC's 3 hierarchical layers. Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d] Codes are already offset-corrected to 0-4095 range per layer.""" layer1, layer2, layer3 = [], [], [] for i in range(0, len(codes), SNAC_CHUNK_SIZE): group = codes[i:i + SNAC_CHUNK_SIZE] if len(group) < SNAC_CHUNK_SIZE: break layer1.append(group[0]) layer2.extend(group[1:3]) layer3.extend(group[3:7]) return [layer1, layer2, layer3] def snac_decode(codes: list) -> Optional[bytes]: """Decode SNAC codes to PCM audio bytes.""" layers = redistribute_codes(codes) if not layers[0]: return None with torch.no_grad(): codes_tensor = [ torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0) for layer in layers ] audio_hat = snac_model.decode(codes_tensor) audio_np = audio_hat.squeeze().cpu().numpy() audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16) return audio_int16.tobytes() def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]: """Convert streaming SNAC codes to PCM audio chunks.""" buffer = [] group_count = 0 total_codes = 0 for code in code_stream: buffer.append(code) total_codes += 1 if total_codes % SNAC_CHUNK_SIZE == 0: group_count += 1 if group_count >= SNAC_INITIAL_GROUPS and group_count % 1 == 0: # Decode the last N groups decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE codes_to_decode = buffer[-decode_size:] audio = snac_decode(codes_to_decode) if audio: # Yield only the NEW audio (avoid overlap) # Each group of 7 codes produces ~2048 samples new_samples = SNAC_CHUNK_SIZE * 293 # ~293 samples per code at 24kHz yield audio[-new_samples * 2:] # *2 for 16-bit (2 bytes per sample) # ============================================================================ # High-level generation functions # ============================================================================ def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]: """Full streaming pipeline: text → tokens → audio chunks.""" voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE for chunk in chunk_text(text): print(f"[stream] Generating: {chunk[:80]}...") t0 = time.time() first_audio = True text_stream = generate_tokens_streaming(chunk, voice) code_stream = extract_audio_codes(text_stream) for pcm_chunk in decode_audio_stream(code_stream): if first_audio: print(f"[stream] First audio in {time.time() - t0:.2f}s") first_audio = False yield pcm_chunk print("[stream] Done") def generate_speech_sync(text: str, voice: str) -> bytes: """Generate complete audio (for job system). Returns WAV bytes.""" voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE all_pcm = [] for chunk in chunk_text(text): print(f"Generating: {chunk[:80]}...") text_stream = generate_tokens_streaming(chunk, voice) code_stream = extract_audio_codes(text_stream) for pcm_chunk in decode_audio_stream(code_stream): all_pcm.append(pcm_chunk) if not all_pcm: raise ValueError("No audio generated") audio_bytes_raw = b''.join(all_pcm) buffer = io.BytesIO() with wave.open(buffer, 'wb') as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(SAMPLE_RATE) wf.writeframes(audio_bytes_raw) return buffer.getvalue() def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str: output_path = OUTPUT_DIR / f"{job_id}.wav" with open(output_path, 'wb') as f: f.write(audio_bytes) return str(output_path) async def generate_speech_background(job_id: str, text: str, voice: str): """Background task for speech generation.""" try: jobs[job_id].status = JobStatus.PROCESSING jobs[job_id].progress = 25 save_jobs_to_disk() cache_key = hash_text_voice(text, voice) cached_path = get_from_cache(cache_key) if cached_path: jobs[job_id].audio_path = cached_path jobs[job_id].status = JobStatus.SUCCESS jobs[job_id].progress = 100 jobs[job_id].cached = True jobs[job_id].completed_at = datetime.now().isoformat() save_jobs_to_disk() return jobs[job_id].progress = 50 save_jobs_to_disk() audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice) jobs[job_id].progress = 75 save_jobs_to_disk() output_path = save_audio_to_file(job_id, audio_bytes) save_to_cache(cache_key, output_path) jobs[job_id].audio_path = output_path jobs[job_id].status = JobStatus.SUCCESS jobs[job_id].progress = 100 jobs[job_id].completed_at = datetime.now().isoformat() save_jobs_to_disk() except Exception as e: print(f"Job {job_id} failed: {e}") import traceback traceback.print_exc() jobs[job_id].status = JobStatus.FAILURE jobs[job_id].error = str(e) save_jobs_to_disk() async def cleanup_old_jobs(): while True: try: await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600) cutoff = datetime.now() - timedelta(days=RETENTION_DAYS) to_delete = [] for job_id, job in jobs.items(): try: created = datetime.fromisoformat(job.created_at) if created < cutoff: if job.audio_path and Path(job.audio_path).exists(): Path(job.audio_path).unlink() to_delete.append(job_id) except: pass for job_id in to_delete: del jobs[job_id] if to_delete: save_jobs_to_disk() print(f"Cleanup: deleted {len(to_delete)} old jobs") except Exception as e: print(f"Error in cleanup task: {e}") # ============================================================================ # Startup # ============================================================================ @app.on_event("startup") async def startup(): global llm_model, llm_tokenizer, snac_model print("=" * 60) print("OrpheusTail v2 — transformers streaming engine") print(f"Model: {ORPHEUS_MODEL}") print(f"Max Tokens: {MAX_TOKENS}") print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}") print(f"Default Voice: {DEFAULT_VOICE}") print("=" * 60) # Load LLM print("Loading Orpheus LLM...") from transformers import AutoModelForCausalLM, AutoTokenizer llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL) llm_model = AutoModelForCausalLM.from_pretrained( ORPHEUS_MODEL, torch_dtype=torch.bfloat16, device_map="auto", ) llm_model.eval() print(f"✓ LLM loaded ({ORPHEUS_MODEL})") # Load SNAC decoder print("Loading SNAC audio codec...") from snac import SNAC snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval() if torch.cuda.is_available(): snac_model = snac_model.to("cuda") print("✓ SNAC loaded") # Load jobs from disk load_jobs_from_disk() # Start cleanup task asyncio.create_task(cleanup_old_jobs()) print("✓ OrpheusTail v2 ready") # ============================================================================ # Pydantic Models # ============================================================================ class TTSRequest(BaseModel): text: str voice: str = DEFAULT_VOICE class TTSStreamRequest(BaseModel): text: str voice: str = DEFAULT_VOICE class JobResponse(BaseModel): job_id: str status: str class StatusResponse(BaseModel): job_id: str status: str progress: int cached: bool = False audio_url: Optional[str] = None error: Optional[str] = None class VoicesResponse(BaseModel): builtin: List[str] custom: List[str] default: str emotion_tags: List[str] # ============================================================================ # Endpoints # ============================================================================ @app.get("/") def root(): return { "service": "OrpheusTail - Orpheus TTS Service", "version": "2.0.0", "engine": "transformers (streaming)", "model": ORPHEUS_MODEL, "default_voice": DEFAULT_VOICE, "emotion_tags": EMOTION_TAGS, "endpoints": { "/tts/submit": "POST - Submit TTS job", "/tts/status/{job_id}": "GET - Check job status", "/tts/audio/{job_id}": "GET - Download audio", "/tts/stream": "POST - Stream audio (true token-level)", "/voice/clone": "POST - Upload voice reference", "/voices": "GET - List available voices", "/health": "GET - Health check" } } @app.get("/health") def health(): return { "status": "healthy", "model_loaded": llm_model is not None and snac_model is not None, "engine": "transformers", "cache_enabled": CACHE_ENABLED, "voices_available": len(BUILTIN_VOICES) + len(get_custom_voices()) } @app.get("/voices", response_model=VoicesResponse) def list_voices(): return VoicesResponse( builtin=BUILTIN_VOICES, custom=get_custom_voices(), default=DEFAULT_VOICE, emotion_tags=EMOTION_TAGS ) # --- TTS endpoints --- @app.post("/tts/submit", response_model=JobResponse) async def submit_tts_job(request: TTSRequest): job_id = str(uuid.uuid4()) job = JobInfo( job_id=job_id, text=request.text, voice=request.voice, status=JobStatus.PENDING, progress=0, created_at=datetime.now().isoformat() ) jobs[job_id] = job save_jobs_to_disk() asyncio.create_task(generate_speech_background(job_id, request.text, request.voice)) return JobResponse(job_id=job_id, status=JobStatus.PENDING) @app.get("/tts/status/{job_id}", response_model=StatusResponse) async def get_job_status(job_id: str): if job_id not in jobs: raise HTTPException(status_code=404, detail="Job not found") job = jobs[job_id] response = StatusResponse( job_id=job_id, status=job.status, progress=job.progress, cached=job.cached) if job.status == JobStatus.SUCCESS: response.audio_url = f"/tts/audio/{job_id}" elif job.status == JobStatus.FAILURE: response.error = job.error return response @app.get("/tts/audio/{job_id}") async def get_audio(job_id: str): if job_id not in jobs: raise HTTPException(status_code=404, detail="Job not found") job = jobs[job_id] if job.status != JobStatus.SUCCESS: raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}") if not job.audio_path or not Path(job.audio_path).exists(): raise HTTPException(status_code=404, detail="Audio file not found") return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav") @app.post("/tts/stream") async def stream_tts(request: TTSStreamRequest): """ Stream TTS audio with true token-level streaming. First audio arrives after ~28 tokens (~1-2 seconds). """ if llm_model is None or snac_model is None: raise HTTPException(status_code=503, detail="Model not loaded") def audio_generator(): for pcm_chunk in generate_stream(request.text, request.voice): yield pcm_chunk return StreamingResponse(audio_generator(), media_type="audio/pcm") # --- Voice endpoints --- @app.post("/voice/clone") async def upload_voice_reference(name: str, audio: UploadFile = File(...)): """Upload reference audio for voice cloning (saved for future LoRA fine-tuning).""" if not name.isalnum(): raise HTTPException(status_code=400, detail="Voice name must be alphanumeric") if name in BUILTIN_VOICES: raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice") voice_path = VOICES_DIR / f"{name}.wav" try: content = await audio.read() with open(voice_path, 'wb') as f: f.write(content) return { "status": "saved", "voice_name": name, "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. " f"Custom voice cloning requires LoRA fine-tuning (coming soon).", "builtin_voices": BUILTIN_VOICES, } except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}") @app.delete("/voice/{name}") async def delete_voice(name: str): if name in BUILTIN_VOICES: raise HTTPException(status_code=400, detail="Cannot delete built-in voice") voice_path = VOICES_DIR / f"{name}.wav" if not voice_path.exists(): raise HTTPException(status_code=404, detail="Voice not found") voice_path.unlink() return {"status": "success", "message": f"Voice '{name}' deleted"} @app.delete("/tts/job/{job_id}") async def delete_job(job_id: str): if job_id not in jobs: raise HTTPException(status_code=404, detail="Job not found") job = jobs[job_id] if job.audio_path and Path(job.audio_path).exists(): try: Path(job.audio_path).unlink() except: pass del jobs[job_id] save_jobs_to_disk() return {"message": f"Job {job_id} deleted"} if __name__ == "__main__": import uvicorn uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)