#!/usr/bin/env python3
"""
OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)

FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
True token-level streaming: first audio in ~1-2s instead of 10-15s.

Key Features:
- Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
- SNAC codec decoding with streaming audio output
- Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
- Voice reference storage (for future LoRA fine-tuning)

Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
"""
import os
import re
import json
import hashlib
import asyncio
import uuid
import wave
import io
import threading
import time
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Generator
from dataclasses import dataclass, asdict
from enum import Enum

import torch
from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

# Configuration from environment
ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-prod")
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
SAMPLE_RATE = 24000  # SNAC 24 kHz codec output rate

# Ensure directories exist
CACHE_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
VOICES_DIR.mkdir(exist_ok=True)

# Jobs persistence
JOBS_FILE = OUTPUT_DIR / "jobs.json"

# Built-in Orpheus voices
BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]

# Supported emotion tags — the eight tags documented for the Orpheus
# finetuned model. (Restored: the angle-bracketed literals had been stripped
# from this file by an HTML-unaware tool.)
EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>",
                "<sniffle>", "<groan>", "<yawn>", "<gasp>"]

# Orpheus special tokens
SPECIAL_TOKEN_START = 128259
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
EOS_TOKEN_ID = 128258

# When decoded by the tokenizer, audio codes appear as <custom_token_N>
# where N = token_id - 128256.
# Audio codes start at token_id 128266, which decodes as <custom_token_10>.
CODE_TOKEN_OFFSET = 10  # in decoded text space (token_id 128266 → custom_token_10)
CODE_REMOVE_TOKEN_ID = 128258  # this token signals end, not audio

# SNAC streaming parameters
SNAC_CHUNK_SIZE = 7  # tokens per SNAC group
SNAC_INITIAL_GROUPS = 6  # wait for N groups before first decode (higher = smoother start)

# Pattern for audio-code tokens in decoded text. Compiled once: it runs in
# the per-token hot loop of extract_audio_codes().
CUSTOM_TOKEN_RE = re.compile(r'<custom_token_(\d+)>')

# Initialize FastAPI
app = FastAPI(
    title="OrpheusTail - Orpheus TTS Service",
    description="Text-to-speech with emotion control and streaming for Vixy",
    version="2.0.0"
)

# Global model references (populated on startup)
llm_model = None
llm_tokenizer = None
snac_model = None


# ============================================================================
# Job System (unchanged from v1)
# ============================================================================

class JobStatus(str, Enum):
    PENDING = "PENDING"
    PROCESSING = "PROCESSING"
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"


@dataclass
class JobInfo:
    """Persisted record of one TTS generation job."""
    job_id: str
    text: str
    voice: str
    status: JobStatus
    progress: int = 0
    audio_path: Optional[str] = None
    error: Optional[str] = None
    cached: bool = False
    created_at: str = ""
    completed_at: Optional[str] = None


jobs: Dict[str, JobInfo] = {}


def load_jobs_from_disk():
    """Restore the in-memory job table from JOBS_FILE, if present."""
    global jobs
    if JOBS_FILE.exists():
        try:
            with open(JOBS_FILE, 'r') as f:
                data = json.load(f)
            for job_id, job_dict in data.items():
                jobs[job_id] = JobInfo(**job_dict)
            print(f"Loaded {len(jobs)} jobs from disk")
        except Exception as e:
            print(f"Error loading jobs: {e}")


def save_jobs_to_disk():
    """Persist the in-memory job table to JOBS_FILE (best-effort)."""
    try:
        data = {job_id: asdict(job) for job_id, job in jobs.items()}
        with open(JOBS_FILE, 'w') as f:
            json.dump(data, f, indent=2)
    except Exception as e:
        print(f"Error saving jobs: {e}")


# ============================================================================
# Cache System (unchanged from v1)
# ============================================================================

def hash_text_voice(text: str, voice: str) -> str:
    """Deterministic cache key for a (text, voice) pair."""
    content = f"{text}|{voice}"
    return hashlib.sha256(content.encode()).hexdigest()


def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached WAV path for cache_key, or None on miss/disabled."""
    if not CACHE_ENABLED:
        return None
    cache_path = CACHE_DIR / f"{cache_key}.wav"
    if cache_path.exists():
        print(f"Cache hit: {cache_key}")
        return str(cache_path)
    return None


def save_to_cache(cache_key: str, audio_path: str):
    """Copy a generated WAV into the cache (best-effort)."""
    if not CACHE_ENABLED:
        return
    try:
        import shutil
        cache_path = CACHE_DIR / f"{cache_key}.wav"
        shutil.copy(audio_path, cache_path)
        print(f"Saved to cache: {cache_key}")
    except Exception as e:
        print(f"Error saving to cache: {e}")


def get_custom_voices() -> List[str]:
    """Names of uploaded reference voices (WAV files in VOICES_DIR)."""
    return [f.stem for f in VOICES_DIR.glob("*.wav")]


# ============================================================================
# Text Chunking
# ============================================================================

def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """Split text into chunks at sentence boundaries."""
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks = []
    current_chunk = []
    current_len = 0
    for sentence in sentences:
        sentence_len = len(sentence)
        # +1 accounts for the joining space when the chunk is non-empty
        if current_chunk and current_len + 1 + sentence_len > max_chars:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_len = 0
        current_chunk.append(sentence)
        current_len += sentence_len + (1 if current_len > 0 else 0)
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks if chunks else [text]


# ============================================================================
# Inference Engine (transformers — replaces vLLM)
# ============================================================================

def format_prompt(text: str, voice: str) -> torch.Tensor:
    """Format prompt for Orpheus finetuned model. Returns input_ids tensor."""
    full_text = f"{voice}: {text}"
    input_ids = llm_tokenizer(full_text, return_tensors="pt").input_ids
    # Wrap with Orpheus special tokens
    start = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
    end = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
    return torch.cat([start, input_ids, end], dim=1).to(llm_model.device)


def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
    """Stream token strings as the model generates them.

    Runs model.generate() on a daemon thread and yields decoded text
    fragments from a TextIteratorStreamer as they arrive.
    """
    from transformers import TextIteratorStreamer
    input_ids = format_prompt(text, voice)
    # skip_special_tokens=False: the <custom_token_N> audio codes ARE
    # special tokens — we need them in the decoded text.
    streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False)
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_TOKENS,
        temperature=0.6,
        top_p=0.8,
        repetition_penalty=1.3,
        do_sample=True,
        eos_token_id=EOS_TOKEN_ID,
        streamer=streamer,
    )
    thread = threading.Thread(target=llm_model.generate, kwargs=gen_kwargs, daemon=True)
    thread.start()
    for text_chunk in streamer:
        yield text_chunk
    thread.join()


# ============================================================================
# Token → Audio Pipeline
# ============================================================================

def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
    """Extract SNAC audio codes from streamed text.

    Handles partial tokens across chunks: decoded fragments may split a
    <custom_token_N> marker, so unmatched text is buffered until the next
    fragment completes it. Yields per-layer codes in [0, 4096) after
    subtracting the base and per-position offsets.
    """
    buffer = ""
    count = 0
    skipped = 0
    for chunk in text_stream:
        buffer += chunk
        while True:
            # NOTE: pattern restored — it had been stripped to r'' by an
            # HTML-unaware tool, which broke the entire audio pipeline.
            match = CUSTOM_TOKEN_RE.search(buffer)
            if not match:
                break
            token_id = int(match.group(1))
            buffer = buffer[match.end():]
            if token_id >= CODE_TOKEN_OFFSET:
                pos_in_group = count % 7
                # token_id is already decoded (e.g., 2061 for custom_token_2061).
                # Subtract the base offset (10) and per-layer offset (pos * 4096).
                code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096)
                if count < 14:
                    print(f"[codes] custom_token_{token_id} pos={pos_in_group} code={code}")
                if 0 <= code < 4096:
                    count += 1
                    yield code
                else:
                    skipped += 1
                    count += 1
    print(f"[codes] Total: {count} extracted, {skipped} skipped")


def redistribute_codes(codes: list) -> list:
    """Redistribute flat code list into SNAC's 3 hierarchical layers.

    Each group of 7 maps as: [0]=L1, [1]=L2, [2]=L3, [3]=L3, [4]=L2,
    [5]=L3, [6]=L3 (from orpheus_tts.decoder.convert_to_audio).
    """
    layer1, layer2, layer3 = [], [], []
    for i in range(0, len(codes), SNAC_CHUNK_SIZE):
        group = codes[i:i + SNAC_CHUNK_SIZE]
        if len(group) < SNAC_CHUNK_SIZE:
            break  # drop trailing incomplete group
        layer1.append(group[0])
        layer2.append(group[1])
        layer2.append(group[4])
        layer3.append(group[2])
        layer3.append(group[3])
        layer3.append(group[5])
        layer3.append(group[6])
    return [layer1, layer2, layer3]


SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
    """Convert streaming SNAC codes to PCM audio chunks.

    Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio:
    accumulate all codes, decode the last 28 every 7 new codes, slice
    audio_hat[:,:,2048:4096] for the non-overlapping portion.
    """
    from orpheus_tts.decoder import convert_to_audio as _original_convert
    buffer = []
    total_codes = 0
    for code in code_stream:
        buffer.append(code)
        total_codes += 1
        if total_codes == 1:
            print(f"[snac] First code received: {code}")
        # The original decoder triggers every 7 codes after 28 minimum
        if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27:
            # Pass the last 28 codes, matching the original exactly
            audio_bytes = _original_convert(buffer[-28:], total_codes)
            if audio_bytes is not None:
                if total_codes == 28:
                    print(f"[snac] First audio at {total_codes} codes")
                yield audio_bytes
    print(f"[snac] Total codes: {total_codes}")


# ============================================================================
# High-level generation functions
# ============================================================================

def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
    """Full streaming pipeline: text → tokens → audio chunks."""
    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    total_bytes = 0
    for chunk in chunk_text(text):
        print(f"[stream] Generating: {chunk[:80]}...")
        t0 = time.time()
        first_audio = True
        try:
            text_stream = generate_tokens_streaming(chunk, voice)
            code_stream = extract_audio_codes(text_stream)
            for pcm_chunk in decode_audio_stream(code_stream):
                if first_audio:
                    print(f"[stream] First audio in {time.time() - t0:.2f}s ({len(pcm_chunk)} bytes)")
                    first_audio = False
                total_bytes += len(pcm_chunk)
                yield pcm_chunk
        except Exception as e:
            # Best-effort per chunk: log and continue with the next one
            print(f"[stream] ERROR in pipeline: {e}")
            import traceback
            traceback.print_exc()
    print(f"[stream] Done — {total_bytes} bytes total")


def generate_speech_sync(text: str, voice: str) -> bytes:
    """Generate complete audio (for job system). Returns WAV bytes."""
    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    all_pcm = []
    for chunk in chunk_text(text):
        print(f"Generating: {chunk[:80]}...")
        text_stream = generate_tokens_streaming(chunk, voice)
        code_stream = extract_audio_codes(text_stream)
        for pcm_chunk in decode_audio_stream(code_stream):
            all_pcm.append(pcm_chunk)
    if not all_pcm:
        raise ValueError("No audio generated")
    audio_bytes_raw = b''.join(all_pcm)
    # Wrap raw 16-bit mono PCM in a WAV container
    buffer = io.BytesIO()
    with wave.open(buffer, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(audio_bytes_raw)
    return buffer.getvalue()


def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
    """Write WAV bytes to OUTPUT_DIR/<job_id>.wav and return the path."""
    output_path = OUTPUT_DIR / f"{job_id}.wav"
    with open(output_path, 'wb') as f:
        f.write(audio_bytes)
    return str(output_path)


async def generate_speech_background(job_id: str, text: str, voice: str):
    """Background task for speech generation."""
    try:
        jobs[job_id].status = JobStatus.PROCESSING
        jobs[job_id].progress = 25
        save_jobs_to_disk()

        cache_key = hash_text_voice(text, voice)
        cached_path = get_from_cache(cache_key)
        if cached_path:
            jobs[job_id].audio_path = cached_path
            jobs[job_id].status = JobStatus.SUCCESS
            jobs[job_id].progress = 100
            jobs[job_id].cached = True
            jobs[job_id].completed_at = datetime.now().isoformat()
            save_jobs_to_disk()
            return

        jobs[job_id].progress = 50
        save_jobs_to_disk()
        # Run the blocking generation off the event loop
        audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)
        jobs[job_id].progress = 75
        save_jobs_to_disk()

        output_path = save_audio_to_file(job_id, audio_bytes)
        save_to_cache(cache_key, output_path)

        jobs[job_id].audio_path = output_path
        jobs[job_id].status = JobStatus.SUCCESS
        jobs[job_id].progress = 100
        jobs[job_id].completed_at = datetime.now().isoformat()
        save_jobs_to_disk()
    except Exception as e:
        print(f"Job {job_id} failed: {e}")
        import traceback
        traceback.print_exc()
        jobs[job_id].status = JobStatus.FAILURE
        jobs[job_id].error = str(e)
        save_jobs_to_disk()


async def cleanup_old_jobs():
    """Periodic task: delete jobs (and their audio) older than RETENTION_DAYS."""
    while True:
        try:
            await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)
            to_delete = []
            for job_id, job in jobs.items():
                try:
                    created = datetime.fromisoformat(job.created_at)
                    if created < cutoff:
                        if job.audio_path and Path(job.audio_path).exists():
                            Path(job.audio_path).unlink()
                        to_delete.append(job_id)
                except Exception:
                    # malformed created_at or unlink failure: skip this job
                    pass
            for job_id in to_delete:
                del jobs[job_id]
            if to_delete:
                save_jobs_to_disk()
                print(f"Cleanup: deleted {len(to_delete)} old jobs")
        except Exception as e:
            print(f"Error in cleanup task: {e}")


# ============================================================================
# Startup
# ============================================================================

@app.on_event("startup")
async def startup():
    global llm_model, llm_tokenizer, snac_model
    print("=" * 60)
    print("OrpheusTail v2 — transformers streaming engine")
    print(f"Model: {ORPHEUS_MODEL}")
    print(f"Max Tokens: {MAX_TOKENS}")
    print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
    print(f"Default Voice: {DEFAULT_VOICE}")
    print("=" * 60)

    # Load LLM
    print("Loading Orpheus LLM...")
    from transformers import AutoModelForCausalLM, AutoTokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
    llm_model = AutoModelForCausalLM.from_pretrained(
        ORPHEUS_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    llm_model.eval()
    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")

    # Load SNAC decoder
    print("Loading SNAC audio codec...")
    from snac import SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.to("cuda")
    print("✓ SNAC loaded")

    # Load jobs from disk
    load_jobs_from_disk()
    # Start cleanup task
    asyncio.create_task(cleanup_old_jobs())
    print("✓ OrpheusTail v2 ready")


# ============================================================================
# Pydantic Models
# ============================================================================

class TTSRequest(BaseModel):
    text: str
    voice: str = DEFAULT_VOICE


class TTSStreamRequest(BaseModel):
    text: str
    voice: str = DEFAULT_VOICE


class JobResponse(BaseModel):
    job_id: str
    status: str


class StatusResponse(BaseModel):
    job_id: str
    status: str
    progress: int
    cached: bool = False
    audio_url: Optional[str] = None
    error: Optional[str] = None


class VoicesResponse(BaseModel):
    builtin: List[str]
    custom: List[str]
    default: str
    emotion_tags: List[str]


# ============================================================================
# Endpoints
# ============================================================================

@app.get("/")
def root():
    return {
        "service": "OrpheusTail - Orpheus TTS Service",
        "version": "2.0.0",
        "engine": "transformers (streaming)",
        "model": ORPHEUS_MODEL,
        "default_voice": DEFAULT_VOICE,
        "emotion_tags": EMOTION_TAGS,
        "endpoints": {
            "/tts/submit": "POST - Submit TTS job",
            "/tts/status/{job_id}": "GET - Check job status",
            "/tts/audio/{job_id}": "GET - Download audio",
            "/tts/stream": "POST - Stream audio (true token-level)",
            "/voice/clone": "POST - Upload voice reference",
            "/voices": "GET - List available voices",
            "/health": "GET - Health check"
        }
    }


@app.get("/health")
def health():
    return {
        "status": "healthy",
        "model_loaded": llm_model is not None and snac_model is not None,
        "engine": "transformers",
        "cache_enabled": CACHE_ENABLED,
        "voices_available": len(BUILTIN_VOICES) + len(get_custom_voices())
    }


@app.get("/voices", response_model=VoicesResponse)
def list_voices():
    return VoicesResponse(
        builtin=BUILTIN_VOICES,
        custom=get_custom_voices(),
        default=DEFAULT_VOICE,
        emotion_tags=EMOTION_TAGS
    )


# --- TTS endpoints ---

@app.post("/tts/submit", response_model=JobResponse)
async def submit_tts_job(request: TTSRequest):
    job_id = str(uuid.uuid4())
    job = JobInfo(
        job_id=job_id,
        text=request.text,
        voice=request.voice,
        status=JobStatus.PENDING,
        progress=0,
        created_at=datetime.now().isoformat()
    )
    jobs[job_id] = job
    save_jobs_to_disk()
    asyncio.create_task(generate_speech_background(job_id, request.text, request.voice))
    return JobResponse(job_id=job_id, status=JobStatus.PENDING)


@app.get("/tts/status/{job_id}", response_model=StatusResponse)
async def get_job_status(job_id: str):
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    job = jobs[job_id]
    response = StatusResponse(
        job_id=job_id,
        status=job.status,
        progress=job.progress,
        cached=job.cached)
    if job.status == JobStatus.SUCCESS:
        response.audio_url = f"/tts/audio/{job_id}"
    elif job.status == JobStatus.FAILURE:
        response.error = job.error
    return response


@app.get("/tts/audio/{job_id}")
async def get_audio(job_id: str):
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    job = jobs[job_id]
    if job.status != JobStatus.SUCCESS:
        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")
    if not job.audio_path or not Path(job.audio_path).exists():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")


@app.post("/tts/stream")
async def stream_tts(request: TTSStreamRequest):
    """
    Stream TTS audio with true token-level streaming.
    First audio arrives after ~28 tokens (~1-2 seconds).
    """
    if llm_model is None or snac_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def audio_generator():
        for pcm_chunk in generate_stream(request.text, request.voice):
            yield pcm_chunk

    return StreamingResponse(audio_generator(), media_type="audio/pcm")


# --- Voice endpoints ---

@app.post("/voice/clone")
async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
    """Upload reference audio for voice cloning (saved for future LoRA fine-tuning)."""
    if not name.isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")
    voice_path = VOICES_DIR / f"{name}.wav"
    try:
        content = await audio.read()
        with open(voice_path, 'wb') as f:
            f.write(content)
        return {
            "status": "saved",
            "voice_name": name,
            "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
                       f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
            "builtin_voices": BUILTIN_VOICES,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}")


@app.delete("/voice/{name}")
async def delete_voice(name: str):
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot delete built-in voice")
    voice_path = VOICES_DIR / f"{name}.wav"
    if not voice_path.exists():
        raise HTTPException(status_code=404, detail="Voice not found")
    voice_path.unlink()
    return {"status": "success", "message": f"Voice '{name}' deleted"}


@app.delete("/tts/job/{job_id}")
async def delete_job(job_id: str):
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    job = jobs[job_id]
    if job.audio_path and Path(job.audio_path).exists():
        try:
            Path(job.audio_path).unlink()
        except OSError:
            pass  # best-effort: stale path should not block job deletion
    del jobs[job_id]
    save_jobs_to_disk()
    return {"message": f"Job {job_id} deleted"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)