Our custom SNAC redistribution had wrong layer mapping (positions 1,2 vs 1,4 for layer 2) and incorrect audio slicing. Switched to importing convert_to_audio directly from orpheus_tts.decoder which handles the sliding window, layer redistribution, and 2048:4096 audio slice correctly. Audio now sounds clean with only a subtle boundary artifact on the first token group (inherent to SNAC streaming, not our code). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
681 lines
23 KiB
Python
681 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)
|
|
|
|
FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
|
|
True token-level streaming: first audio in ~1-2s instead of 10-15s.
|
|
|
|
Key Features:
|
|
- Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
|
|
- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
|
|
- SNAC codec decoding with streaming audio output
|
|
- Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
|
|
- Voice reference storage (for future LoRA fine-tuning)
|
|
|
|
Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import json
|
|
import hashlib
|
|
import asyncio
|
|
import uuid
|
|
import wave
|
|
import io
|
|
import threading
|
|
import time
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Generator
|
|
from dataclasses import dataclass, asdict
|
|
from enum import Enum
|
|
|
|
import torch
|
|
from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
# Configuration from environment
ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-prod")
CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))  # jobs/audio older than this are purged
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))  # NOTE(review): unused in this file — likely a vLLM-era leftover
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))  # max_new_tokens per generate() call
SAMPLE_RATE = 24000  # SNAC 24 kHz codec output rate

# Ensure directories exist
CACHE_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
VOICES_DIR.mkdir(exist_ok=True)

# Jobs persistence
JOBS_FILE = OUTPUT_DIR / "jobs.json"

# Built-in Orpheus voices
BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]

# Supported emotion tags
EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]

# Orpheus special tokens (extended Llama vocabulary IDs)
SPECIAL_TOKEN_START = 128259
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
EOS_TOKEN_ID = 128258
# When decoded by tokenizer, audio codes appear as <custom_token_N> where N = token_id - 128256
# Audio codes start at token_id 128266, which decodes as <custom_token_10>
CODE_TOKEN_OFFSET = 10  # in decoded text space (token_id 128266 → custom_token_10)
CODE_REMOVE_TOKEN_ID = 128258  # this token signals end, not audio
# NOTE(review): CODE_REMOVE_TOKEN_ID appears unused — extract_audio_codes filters
# by CODE_TOKEN_OFFSET instead; verify before deleting.

# SNAC streaming parameters
SNAC_CHUNK_SIZE = 7  # tokens per SNAC group
SNAC_INITIAL_GROUPS = 6  # wait for N groups before first decode (higher = smoother start)
# NOTE(review): SNAC_INITIAL_GROUPS looks unused — decode_audio_stream hard-codes a
# 28-code (4-group) minimum; confirm before removing.
|
|
|
|
# Initialize FastAPI
app = FastAPI(
    title="OrpheusTail - Orpheus TTS Service",
    description="Text-to-speech with emotion control and streaming for Vixy",
    version="2.0.0"
)

# Global model references — populated once in startup(); None until then.
llm_model = None      # HuggingFace AutoModelForCausalLM (Orpheus LLM)
llm_tokenizer = None  # matching AutoTokenizer
snac_model = None     # SNAC neural audio codec (decoder side)
|
|
|
|
|
|
# ============================================================================
|
|
# Job System (unchanged from v1)
|
|
# ============================================================================
|
|
|
|
class JobStatus(str, Enum):
    """Lifecycle states of a TTS job (str-valued so it JSON-serializes cleanly)."""
    PENDING = "PENDING"        # accepted, not started
    PROCESSING = "PROCESSING"  # generation in progress
    SUCCESS = "SUCCESS"        # audio ready at audio_path
    FAILURE = "FAILURE"        # failed; see JobInfo.error
|
|
|
|
@dataclass
class JobInfo:
    """Record for one TTS generation job, persisted to jobs.json via asdict()."""
    job_id: str                       # UUID4 string
    text: str                         # input text to synthesize
    voice: str                        # requested voice name
    status: JobStatus
    progress: int = 0                 # coarse 0-100 progress indicator
    audio_path: Optional[str] = None  # path to the generated WAV once done
    error: Optional[str] = None       # failure message when status == FAILURE
    cached: bool = False              # True when served from the WAV cache
    created_at: str = ""              # ISO-8601 timestamp
    completed_at: Optional[str] = None

# In-memory job registry, mirrored to JOBS_FILE on every mutation.
jobs: Dict[str, JobInfo] = {}
|
|
|
|
def load_jobs_from_disk():
    """Rehydrate the in-memory job registry from JOBS_FILE, if it exists.

    Errors are logged and swallowed on purpose: a corrupt jobs file must
    not prevent the service from starting.
    """
    global jobs
    if not JOBS_FILE.exists():
        return
    try:
        with open(JOBS_FILE, 'r') as fh:
            stored = json.load(fh)
        for jid, fields in stored.items():
            jobs[jid] = JobInfo(**fields)
        print(f"Loaded {len(jobs)} jobs from disk")
    except Exception as e:
        print(f"Error loading jobs: {e}")
|
|
|
|
def save_jobs_to_disk():
    """Serialize the whole job registry to JOBS_FILE (best-effort)."""
    try:
        snapshot = {jid: asdict(info) for jid, info in jobs.items()}
        with open(JOBS_FILE, 'w') as fh:
            json.dump(snapshot, fh, indent=2)
    except Exception as e:
        print(f"Error saving jobs: {e}")
|
|
|
|
|
|
# ============================================================================
|
|
# Cache System (unchanged from v1)
|
|
# ============================================================================
|
|
|
|
def hash_text_voice(text: str, voice: str) -> str:
    """Cache key for a (text, voice) pair: SHA-256 hex digest of "text|voice"."""
    return hashlib.sha256(f"{text}|{voice}".encode()).hexdigest()
|
|
|
|
def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached WAV path for cache_key, or None on miss / cache disabled."""
    if not CACHE_ENABLED:
        return None
    candidate = CACHE_DIR / f"{cache_key}.wav"
    if not candidate.exists():
        return None
    print(f"Cache hit: {cache_key}")
    return str(candidate)
|
|
|
|
def save_to_cache(cache_key: str, audio_path: str):
    """Copy a finished WAV into the cache directory (best-effort, errors logged)."""
    if not CACHE_ENABLED:
        return
    try:
        import shutil
        destination = CACHE_DIR / f"{cache_key}.wav"
        shutil.copy(audio_path, destination)
        print(f"Saved to cache: {cache_key}")
    except Exception as e:
        print(f"Error saving to cache: {e}")
|
|
|
|
def get_custom_voices() -> List[str]:
    """Names (file stems) of user-uploaded .wav voice references in VOICES_DIR."""
    return [path.stem for path in VOICES_DIR.glob("*.wav")]
|
|
|
|
|
|
# ============================================================================
|
|
# Text Chunking
|
|
# ============================================================================
|
|
|
|
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """Split *text* into chunks of at most ~max_chars at sentence boundaries.

    Sentences end at . ! or ? followed by whitespace. A single sentence
    longer than max_chars is kept intact in its own chunk (never split
    mid-sentence).
    """
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks: List[str] = []
    pending: List[str] = []
    pending_len = 0
    for piece in pieces:
        # Would adding this sentence (plus a joining space) overflow the chunk?
        if pending and pending_len + 1 + len(piece) > max_chars:
            chunks.append(' '.join(pending))
            pending = []
            pending_len = 0
        if pending_len > 0:
            pending_len += 1  # joining space
        pending.append(piece)
        pending_len += len(piece)
    if pending:
        chunks.append(' '.join(pending))
    return chunks or [text]
|
|
|
|
|
|
# ============================================================================
|
|
# Inference Engine (transformers — replaces vLLM)
|
|
# ============================================================================
|
|
|
|
def format_prompt(text: str, voice: str) -> torch.Tensor:
    """Build Orpheus input_ids: [START] + tokenized "voice: text" + END markers.

    Returns a (1, seq_len) long tensor moved to the LLM's device.
    """
    prompt_ids = llm_tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids

    # Wrap the text prompt with Orpheus' special framing tokens.
    prefix = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
    suffix = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
    return torch.cat([prefix, prompt_ids, suffix], dim=1).to(llm_model.device)
|
|
|
|
|
|
def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
    """Stream decoded token text as the model generates it.

    Runs model.generate() on a worker thread while this generator drains
    the TextIteratorStreamer. skip_special_tokens=False is essential so
    the <custom_token_N> audio markers survive decoding.
    """
    from transformers import TextIteratorStreamer

    input_ids = format_prompt(text, voice)

    streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False)

    # Sampling parameters — presumably the Orpheus reference defaults;
    # NOTE(review): confirm against the model card before tuning.
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_TOKENS,
        temperature=0.6,
        top_p=0.8,
        repetition_penalty=1.3,
        do_sample=True,
        eos_token_id=EOS_TOKEN_ID,
        streamer=streamer,
    )

    # daemon=True: if the consumer abandons this generator early, the
    # generate() thread will not keep the process alive at shutdown.
    thread = threading.Thread(target=llm_model.generate, kwargs=gen_kwargs, daemon=True)

    thread.start()

    # Yields text fragments as they are decoded; ends when generate() finishes.
    for text_chunk in streamer:
        yield text_chunk

    thread.join()
|
|
|
|
|
|
# ============================================================================
|
|
# Token → Audio Pipeline
|
|
# ============================================================================
|
|
|
|
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
    """Extract SNAC audio codes from streamed text, handling partial tokens
    that straddle chunk boundaries.

    Each <custom_token_N> with N >= CODE_TOKEN_OFFSET is converted to a
    per-layer code in [0, 4096) by subtracting the base offset plus a
    position-dependent offset (pos_in_group * 4096). Out-of-range results
    are dropped but still advance the 7-token group position.
    """
    token_re = re.compile(r'<custom_token_(\d+)>')
    pending = ""
    emitted = 0
    dropped = 0
    for piece in text_stream:
        pending += piece
        while True:
            found = token_re.search(pending)
            if found is None:
                break
            decoded_id = int(found.group(1))
            pending = pending[found.end():]
            if decoded_id < CODE_TOKEN_OFFSET:
                # <custom_token_0..9> are control markers (e.g. EOS), not audio.
                continue
            slot = emitted % 7
            # decoded_id is already in decoded-text space (custom_token_N → N);
            # remove the base offset (10) and the per-slot offset (slot * 4096).
            code = decoded_id - CODE_TOKEN_OFFSET - (slot * 4096)
            if emitted < 14:
                print(f"[codes] custom_token_{decoded_id} pos={slot} code={code}")
            emitted += 1
            if 0 <= code < 4096:
                yield code
            else:
                dropped += 1
    print(f"[codes] Total: {emitted} extracted, {dropped} skipped")
|
|
|
|
|
|
def redistribute_codes(codes: list) -> list:
    """Redistribute a flat code list into SNAC's 3 hierarchical layers.

    Each complete group of 7 codes maps as:
        index 0            -> layer 1
        indices 1, 4       -> layer 2
        indices 2, 3, 5, 6 -> layer 3
    (matches orpheus_tts.decoder.convert_to_audio). A trailing incomplete
    group is discarded.
    """
    coarse, mid, fine = [], [], []
    full_groups = len(codes) // SNAC_CHUNK_SIZE
    for g in range(full_groups):
        c0, c1, c2, c3, c4, c5, c6 = codes[g * SNAC_CHUNK_SIZE:(g + 1) * SNAC_CHUNK_SIZE]
        coarse.append(c0)
        mid.extend((c1, c4))
        fine.extend((c2, c3, c5, c6))
    return [coarse, mid, fine]
|
|
|
|
|
|
# Device for SNAC decoding. NOTE(review): appears unused — decode_audio_stream
# delegates to orpheus_tts.decoder.convert_to_audio, which manages its own device;
# confirm before removing.
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
    """Convert streaming SNAC codes to PCM audio chunks.

    Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio:
    accumulate all codes, decode the last 28 every 7 new codes,
    slice audio_hat[:,:,2048:4096] for the non-overlapping portion.

    Yields raw PCM byte chunks (no WAV header) — presumably 16-bit mono at
    SAMPLE_RATE, as generate_speech_sync writes them with those WAV params.
    """
    from orpheus_tts.decoder import convert_to_audio as _original_convert

    buffer = []       # every code seen so far; only the trailing 28 are decoded
    total_codes = 0

    for code in code_stream:
        buffer.append(code)
        total_codes += 1

        if total_codes == 1:
            print(f"[snac] First code received: {code}")

        # The original decoder triggers every 7 codes after 28 minimum
        if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27:
            # Pass the last 28 codes (a 4-group sliding window), matching the
            # original exactly; convert_to_audio keeps the non-overlapping slice.
            audio_bytes = _original_convert(buffer[-28:], total_codes)
            if audio_bytes is not None:
                if total_codes == 28:
                    print(f"[snac] First audio at {total_codes} codes")
                yield audio_bytes

    print(f"[snac] Total codes: {total_codes}")
|
|
|
|
|
|
# ============================================================================
|
|
# High-level generation functions
|
|
# ============================================================================
|
|
|
|
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
    """Full streaming pipeline: text → tokens → audio chunks.

    Yields raw PCM chunks (no WAV header). Unknown voices silently fall
    back to DEFAULT_VOICE. An error in one text chunk is logged and the
    remaining chunks still stream — deliberate best-effort behavior for a
    live audio response.
    """
    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    total_bytes = 0
    for chunk in chunk_text(text):
        print(f"[stream] Generating: {chunk[:80]}...")
        t0 = time.time()
        first_audio = True
        try:
            text_stream = generate_tokens_streaming(chunk, voice)
            code_stream = extract_audio_codes(text_stream)
            for pcm_chunk in decode_audio_stream(code_stream):
                if first_audio:
                    # Time-to-first-audio is the key latency metric here.
                    print(f"[stream] First audio in {time.time() - t0:.2f}s ({len(pcm_chunk)} bytes)")
                    first_audio = False
                total_bytes += len(pcm_chunk)
                yield pcm_chunk
        except Exception as e:
            # Best-effort: log and continue with the next text chunk.
            print(f"[stream] ERROR in pipeline: {e}")
            import traceback
            traceback.print_exc()
    print(f"[stream] Done — {total_bytes} bytes total")
|
|
|
|
|
|
def generate_speech_sync(text: str, voice: str) -> bytes:
    """Generate complete audio for the job system and return WAV file bytes.

    Unknown voices fall back to DEFAULT_VOICE.
    Raises ValueError if the pipeline produced no audio at all.
    """
    selected = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
    pcm_parts = []
    for segment in chunk_text(text):
        print(f"Generating: {segment[:80]}...")
        tokens = generate_tokens_streaming(segment, selected)
        codes = extract_audio_codes(tokens)
        pcm_parts.extend(decode_audio_stream(codes))

    if not pcm_parts:
        raise ValueError("No audio generated")

    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wav:
        wav.setnchannels(1)            # mono
        wav.setsampwidth(2)            # 16-bit samples
        wav.setframerate(SAMPLE_RATE)
        wav.writeframes(b''.join(pcm_parts))
    return wav_buffer.getvalue()
|
|
|
|
|
|
def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
    """Write WAV bytes to OUTPUT_DIR/<job_id>.wav and return the path as a string."""
    destination = OUTPUT_DIR / f"{job_id}.wav"
    destination.write_bytes(audio_bytes)
    return str(destination)
|
|
|
|
|
|
async def generate_speech_background(job_id: str, text: str, voice: str):
    """Background task for speech generation.

    Drives the job through PROCESSING → SUCCESS/FAILURE, persisting the
    registry after every state change. Serves from the WAV cache when the
    (text, voice) pair was generated before.
    """
    try:
        jobs[job_id].status = JobStatus.PROCESSING
        jobs[job_id].progress = 25
        save_jobs_to_disk()

        # Cache short-circuit: reuse an earlier render of the same text+voice.
        cache_key = hash_text_voice(text, voice)
        cached_path = get_from_cache(cache_key)
        if cached_path:
            jobs[job_id].audio_path = cached_path
            jobs[job_id].status = JobStatus.SUCCESS
            jobs[job_id].progress = 100
            jobs[job_id].cached = True
            jobs[job_id].completed_at = datetime.now().isoformat()
            save_jobs_to_disk()
            return

        jobs[job_id].progress = 50
        save_jobs_to_disk()

        # Run the blocking generation pipeline off the event loop.
        audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)

        jobs[job_id].progress = 75
        save_jobs_to_disk()

        output_path = save_audio_to_file(job_id, audio_bytes)
        save_to_cache(cache_key, output_path)

        jobs[job_id].audio_path = output_path
        jobs[job_id].status = JobStatus.SUCCESS
        jobs[job_id].progress = 100
        jobs[job_id].completed_at = datetime.now().isoformat()
        save_jobs_to_disk()

    except Exception as e:
        # Record the failure on the job so clients polling /tts/status see it.
        print(f"Job {job_id} failed: {e}")
        import traceback
        traceback.print_exc()
        jobs[job_id].status = JobStatus.FAILURE
        jobs[job_id].error = str(e)
        save_jobs_to_disk()
|
|
|
|
|
|
async def cleanup_old_jobs():
    """Periodic background task: every CLEANUP_INTERVAL_HOURS, delete jobs
    (and their audio files) older than RETENTION_DAYS.

    Runs forever. Only Exception is caught at the loop level, so
    asyncio.CancelledError still propagates and the task can be cancelled
    cleanly (the previous bare `except:` swallowed it).
    """
    while True:
        try:
            await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)
            to_delete = []
            for job_id, job in jobs.items():
                try:
                    created = datetime.fromisoformat(job.created_at)
                    if created < cutoff:
                        if job.audio_path and Path(job.audio_path).exists():
                            Path(job.audio_path).unlink()
                        to_delete.append(job_id)
                except (TypeError, ValueError, OSError):
                    # TypeError/ValueError: missing or unparsable created_at
                    # (e.g. the "" default on legacy jobs); OSError: audio file
                    # vanished or is undeletable. Skip this job, keep cleaning.
                    pass
            for job_id in to_delete:
                del jobs[job_id]
            if to_delete:
                save_jobs_to_disk()
                print(f"Cleanup: deleted {len(to_delete)} old jobs")
        except Exception as e:
            print(f"Error in cleanup task: {e}")
|
|
|
|
|
|
# ============================================================================
|
|
# Startup
|
|
# ============================================================================
|
|
|
|
@app.on_event("startup")
async def startup():
    """Load the Orpheus LLM and SNAC codec, restore persisted jobs, and
    launch the periodic cleanup task.

    NOTE(review): @app.on_event is deprecated in recent FastAPI in favor
    of lifespan handlers — works today, but worth migrating.
    """
    global llm_model, llm_tokenizer, snac_model

    print("=" * 60)
    print("OrpheusTail v2 — transformers streaming engine")
    print(f"Model: {ORPHEUS_MODEL}")
    print(f"Max Tokens: {MAX_TOKENS}")
    print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
    print(f"Default Voice: {DEFAULT_VOICE}")
    print("=" * 60)

    # Load LLM (bfloat16, device_map="auto" places it on the available GPU/CPU).
    print("Loading Orpheus LLM...")
    from transformers import AutoModelForCausalLM, AutoTokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
    llm_model = AutoModelForCausalLM.from_pretrained(
        ORPHEUS_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    llm_model.eval()
    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")

    # Load SNAC decoder (24 kHz variant, matching SAMPLE_RATE).
    print("Loading SNAC audio codec...")
    from snac import SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.to("cuda")
    print("✓ SNAC loaded")

    # Load jobs from disk
    load_jobs_from_disk()

    # Start cleanup task
    asyncio.create_task(cleanup_old_jobs())

    print("✓ OrpheusTail v2 ready")
|
|
|
|
|
|
# ============================================================================
|
|
# Pydantic Models
|
|
# ============================================================================
|
|
|
|
class TTSRequest(BaseModel):
    """Body for POST /tts/submit."""
    text: str                   # text to synthesize (may include emotion tags)
    voice: str = DEFAULT_VOICE  # built-in voice name; unknown names fall back server-side
|
|
|
|
class TTSStreamRequest(BaseModel):
    """Body for POST /tts/stream (same shape as TTSRequest, kept separate per endpoint)."""
    text: str
    voice: str = DEFAULT_VOICE
|
|
|
|
class JobResponse(BaseModel):
    """Response for POST /tts/submit: the new job's id and initial status."""
    job_id: str
    status: str
|
|
|
|
class StatusResponse(BaseModel):
    """Response for GET /tts/status/{job_id}."""
    job_id: str
    status: str
    progress: int                    # coarse 0-100
    cached: bool = False             # True when served from the WAV cache
    audio_url: Optional[str] = None  # set when status == SUCCESS
    error: Optional[str] = None      # set when status == FAILURE
|
|
|
|
class VoicesResponse(BaseModel):
    """Response for GET /voices."""
    builtin: List[str]       # the 8 Orpheus finetune voices
    custom: List[str]        # uploaded reference names (not yet usable for synthesis)
    default: str
    emotion_tags: List[str]  # inline tags the model understands
|
|
|
|
|
|
# ============================================================================
|
|
# Endpoints
|
|
# ============================================================================
|
|
|
|
@app.get("/")
def root():
    """Service metadata and a directory of available endpoints."""
    endpoint_map = {
        "/tts/submit": "POST - Submit TTS job",
        "/tts/status/{job_id}": "GET - Check job status",
        "/tts/audio/{job_id}": "GET - Download audio",
        "/tts/stream": "POST - Stream audio (true token-level)",
        "/voice/clone": "POST - Upload voice reference",
        "/voices": "GET - List available voices",
        "/health": "GET - Health check",
    }
    return {
        "service": "OrpheusTail - Orpheus TTS Service",
        "version": "2.0.0",
        "engine": "transformers (streaming)",
        "model": ORPHEUS_MODEL,
        "default_voice": DEFAULT_VOICE,
        "emotion_tags": EMOTION_TAGS,
        "endpoints": endpoint_map,
    }
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: model readiness, cache state, and voice count."""
    models_ready = llm_model is not None and snac_model is not None
    voice_count = len(BUILTIN_VOICES) + len(get_custom_voices())
    return {
        "status": "healthy",
        "model_loaded": models_ready,
        "engine": "transformers",
        "cache_enabled": CACHE_ENABLED,
        "voices_available": voice_count,
    }
|
|
|
|
@app.get("/voices", response_model=VoicesResponse)
def list_voices():
    """List built-in and uploaded voices plus the supported emotion tags."""
    payload = VoicesResponse(
        builtin=BUILTIN_VOICES,
        custom=get_custom_voices(),
        default=DEFAULT_VOICE,
        emotion_tags=EMOTION_TAGS,
    )
    return payload
|
|
|
|
|
|
# --- TTS endpoints ---
|
|
|
|
# Strong references to fire-and-forget tasks: the event loop keeps only weak
# references to tasks, so an otherwise-unreferenced background task can be
# garbage-collected before it finishes (per asyncio.create_task docs).
_background_tasks: set = set()


@app.post("/tts/submit", response_model=JobResponse)
async def submit_tts_job(request: TTSRequest):
    """Create a TTS job, persist it, and start generation in the background.

    Returns immediately with the job id; poll /tts/status/{job_id} for progress.
    """
    job_id = str(uuid.uuid4())
    job = JobInfo(
        job_id=job_id, text=request.text, voice=request.voice,
        status=JobStatus.PENDING, progress=0,
        created_at=datetime.now().isoformat()
    )
    jobs[job_id] = job
    save_jobs_to_disk()
    task = asyncio.create_task(generate_speech_background(job_id, request.text, request.voice))
    # Keep the task alive until it completes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return JobResponse(job_id=job_id, status=JobStatus.PENDING)
|
|
|
|
|
|
@app.get("/tts/status/{job_id}", response_model=StatusResponse)
async def get_job_status(job_id: str):
    """Report a job's progress; adds audio_url on success or error on failure."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    response = StatusResponse(
        job_id=job_id,
        status=job.status,
        progress=job.progress,
        cached=job.cached,
    )
    if job.status == JobStatus.SUCCESS:
        response.audio_url = f"/tts/audio/{job_id}"
    elif job.status == JobStatus.FAILURE:
        response.error = job.error
    return response
|
|
|
|
|
|
@app.get("/tts/audio/{job_id}")
async def get_audio(job_id: str):
    """Download the finished WAV for a successful job."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.status != JobStatus.SUCCESS:
        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")
    if not job.audio_path or not Path(job.audio_path).exists():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")
|
|
|
|
|
|
@app.post("/tts/stream")
async def stream_tts(request: TTSStreamRequest):
    """
    Stream TTS audio with true token-level streaming.
    First audio arrives after ~28 tokens (~1-2 seconds).
    """
    if llm_model is None or snac_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # generate_stream is itself a generator of PCM chunks; hand it to
    # StreamingResponse directly instead of wrapping it in another generator.
    pcm_stream = generate_stream(request.text, request.voice)
    return StreamingResponse(pcm_stream, media_type="audio/pcm")
|
|
|
|
|
|
# --- Voice endpoints ---
|
|
|
|
@app.post("/voice/clone")
async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
    """Upload reference audio for voice cloning (saved for future LoRA fine-tuning)."""
    if not name.isalnum():
        raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")
    destination = VOICES_DIR / f"{name}.wav"
    try:
        payload = await audio.read()
        destination.write_bytes(payload)
        return {
            "status": "saved",
            "voice_name": name,
            "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
                       f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
            "builtin_voices": BUILTIN_VOICES,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}")
|
|
|
|
@app.delete("/voice/{name}")
async def delete_voice(name: str):
    """Remove an uploaded custom voice reference."""
    if name in BUILTIN_VOICES:
        raise HTTPException(status_code=400, detail="Cannot delete built-in voice")
    target = VOICES_DIR / f"{name}.wav"
    if not target.exists():
        raise HTTPException(status_code=404, detail="Voice not found")
    target.unlink()
    return {"status": "success", "message": f"Voice '{name}' deleted"}
|
|
|
|
@app.delete("/tts/job/{job_id}")
async def delete_job(job_id: str):
    """Delete a job record and its audio file (if any).

    File removal is best-effort: only OSError is swallowed (the previous
    bare `except:` caught everything, including KeyboardInterrupt).
    """
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    job = jobs[job_id]
    if job.audio_path and Path(job.audio_path).exists():
        try:
            Path(job.audio_path).unlink()
        except OSError:
            # File may be held open or already gone — still drop the record.
            pass
    del jobs[job_id]
    save_jobs_to_disk()
    return {"message": f"Job {job_id} deleted"}
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn
    # Serve on all interfaces, port 8766; reload disabled in production.
    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)
|