OrpheusTail v2: transformers streaming engine (replaces vLLM)
Major rework: replaced the vLLM sync LLM with HuggingFace transformers +
TextIteratorStreamer for true token-level streaming.

Pipeline: text → format_prompt → model.generate(streamer) →
extract_audio_codes (regex on the streaming text) → SNAC decode → PCM.

Expected first-audio latency: ~1-2s (was 10-14s with vLLM). No more
monkey-patching, no more AsyncLLMEngine hangs on Jetson. The SNAC model
(snac_24khz) is loaded separately for audio decoding. All endpoints are
preserved; the API is compatible with v1. The voice cloning endpoint is
now honest about the LoRA requirement.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
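A minimal client sketch for the new streaming endpoint (illustrative, not part of this commit; assumes the service is listening on localhost:8766 and that the requests and sounddevice packages are installed):

    import requests
    import sounddevice as sd

    SAMPLE_RATE = 24000  # the service emits 24 kHz, 16-bit mono PCM

    with requests.post(
        "http://localhost:8766/tts/stream",
        json={"text": "Hello from Orpheus <chuckle>", "voice": "tara"},
        stream=True,
    ) as resp:
        resp.raise_for_status()
        # RawOutputStream accepts raw bytes, matching the audio/pcm response body
        with sd.RawOutputStream(samplerate=SAMPLE_RATE, channels=1, dtype="int16") as out:
            for pcm in resp.iter_content(chunk_size=4096):
                out.write(pcm)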
--- a/main.py
+++ b/main.py
@@ -1,24 +1,18 @@
 #!/usr/bin/env python3
 """
-OrpheusTail - Orpheus TTS Service
+OrpheusTail - Orpheus TTS Service (v2 — transformers streaming)

 FastAPI server for Orpheus text-to-speech generation on Jetson AGX Orin.
-Replaces VoiceTail (Bark) with better control, voice cloning, and emotion tags.
+True token-level streaming: first audio in ~1-2s instead of 10-15s.

 Key Features:
 - Emotion tags: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>
-- Zero-shot voice cloning from reference audio
-- Streaming support for real-time head playback
+- True token-level streaming via HuggingFace transformers + TextIteratorStreamer
+- SNAC codec decoding with streaming audio output
 - Built-in voices: tara, leah, jess, leo, dan, mia, zac, zoe
+- Voice reference storage (for future LoRA fine-tuning)

-Endpoints:
-- POST /tts/submit - Submit TTS job (returns job_id)
-- GET /tts/status/{job_id} - Check job status
-- GET /tts/audio/{job_id} - Download generated audio
-- POST /tts/stream - Stream audio in real-time (for head)
-- POST /voice/clone - Upload reference audio for voice cloning
-- GET /voices - List available voices
-- GET /health - Health check
+Engine: HuggingFace transformers (replaced vLLM — simpler, streams natively)
 """

 import os
@@ -29,12 +23,16 @@ import asyncio
 import uuid
 import wave
 import io
+import threading
+import time
+import numpy as np
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Generator
 from dataclasses import dataclass, asdict
 from enum import Enum

+import torch
 from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, File
 from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
@@ -44,10 +42,10 @@ ORPHEUS_MODEL = os.getenv("ORPHEUS_MODEL", "canopylabs/orpheus-tts-0.1-finetune-
 CACHE_ENABLED = os.getenv("CACHE_ENABLED", "true").lower() == "true"
 CACHE_DIR = Path(os.getenv("CACHE_DIR", "cache"))
 OUTPUT_DIR = Path(os.getenv("OUTPUT_DIR", "output"))
-VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))  # For cloned voice references
+VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))
 RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
 CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
-DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")  # Orpheus default voice
+DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")
 MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
 MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
 SAMPLE_RATE = 24000
@@ -60,34 +58,48 @@ VOICES_DIR.mkdir(exist_ok=True)
 # Jobs persistence
 JOBS_FILE = OUTPUT_DIR / "jobs.json"

-# Built-in Orpheus voices (in order of conversational realism per docs)
+# Built-in Orpheus voices
 BUILTIN_VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]

 # Supported emotion tags
 EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]

+# Orpheus special tokens
+SPECIAL_TOKEN_START = 128259
+SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
+EOS_TOKEN_ID = 128258
+CODE_TOKEN_OFFSET = 128266  # audio code tokens start here
+CODE_REMOVE_TOKEN_ID = 128258  # this token signals end, not audio
+
+# SNAC streaming parameters
+SNAC_CHUNK_SIZE = 7  # tokens per SNAC group
+SNAC_INITIAL_GROUPS = 6  # wait for N groups before first decode (higher = smoother start)
+
 # Initialize FastAPI
 app = FastAPI(
     title="OrpheusTail - Orpheus TTS Service",
-    description="Text-to-speech with emotion control and voice cloning for Vixy",
-    version="1.0.0"
+    description="Text-to-speech with emotion control and streaming for Vixy",
+    version="2.0.0"
 )

-# Global model (loaded at startup)
-model = None
+# Global model references
+llm_model = None
+llm_tokenizer = None
+snac_model = None
+
+
+# ============================================================================
+# Job System (unchanged from v1)
+# ============================================================================

 class JobStatus(str, Enum):
-    """Job status enum"""
     PENDING = "PENDING"
     PROCESSING = "PROCESSING"
     SUCCESS = "SUCCESS"
     FAILURE = "FAILURE"


 @dataclass
 class JobInfo:
-    """Job information"""
     job_id: str
     text: str
     voice: str
@@ -99,13 +111,9 @@ class JobInfo:
     created_at: str = ""
     completed_at: Optional[str] = None


-# In-memory job storage
 jobs: Dict[str, JobInfo] = {}


 def load_jobs_from_disk():
-    """Load jobs from disk on startup"""
     global jobs
     if JOBS_FILE.exists():
         try:
@@ -117,9 +125,7 @@ def load_jobs_from_disk():
     except Exception as e:
         print(f"Error loading jobs: {e}")


 def save_jobs_to_disk():
-    """Save jobs to disk"""
     try:
         data = {job_id: asdict(job) for job_id, job in jobs.items()}
         with open(JOBS_FILE, 'w') as f:
@@ -128,14 +134,15 @@ def save_jobs_to_disk():
         print(f"Error saving jobs: {e}")


+# ============================================================================
+# Cache System (unchanged from v1)
+# ============================================================================
 def hash_text_voice(text: str, voice: str) -> str:
-    """Generate cache key from text + voice"""
     content = f"{text}|{voice}"
     return hashlib.sha256(content.encode()).hexdigest()


 def get_from_cache(cache_key: str) -> Optional[str]:
-    """Check if audio exists in cache"""
     if not CACHE_ENABLED:
         return None
     cache_path = CACHE_DIR / f"{cache_key}.wav"
@@ -144,9 +151,7 @@ def get_from_cache(cache_key: str) -> Optional[str]:
         return str(cache_path)
     return None


 def save_to_cache(cache_key: str, audio_path: str):
-    """Save generated audio to cache"""
     if not CACHE_ENABLED:
         return
     try:
@@ -157,145 +162,200 @@ def save_to_cache(cache_key: str, audio_path: str):
     except Exception as e:
         print(f"Error saving to cache: {e}")


 def get_custom_voices() -> List[str]:
-    """Get list of custom cloned voices"""
-    voices = []
-    for voice_file in VOICES_DIR.glob("*.wav"):
-        voices.append(voice_file.stem)
-    return voices
+    return [f.stem for f in VOICES_DIR.glob("*.wav")]


+# ============================================================================
+# Text Chunking
+# ============================================================================
+
 def chunk_text(text: str, max_chars: int = 800) -> List[str]:
-    """
-    Split text into chunks at sentence boundaries for sequential TTS generation.
-
-    Ensures no chunk exceeds max_chars (unless a single sentence is longer).
-    Preserves emotion tags within sentences.
-    """
-    # Split on sentence-ending punctuation followed by whitespace
+    """Split text into chunks at sentence boundaries."""
     sentences = re.split(r'(?<=[.!?])\s+', text.strip())

     chunks = []
     current_chunk = []
     current_len = 0

     for sentence in sentences:
         sentence_len = len(sentence)
-        # If adding this sentence would exceed the limit, finalize current chunk
         if current_chunk and current_len + 1 + sentence_len > max_chars:
             chunks.append(' '.join(current_chunk))
             current_chunk = []
             current_len = 0
         current_chunk.append(sentence)
         current_len += sentence_len + (1 if current_len > 0 else 0)

-    # Don't forget the last chunk
     if current_chunk:
         chunks.append(' '.join(current_chunk))

     return chunks if chunks else [text]


-def chunk_text_fine(text: str, max_chars: int = 200) -> List[str]:
-    """
-    Split text into fine-grained chunks for streaming — every sentence or clause.
-    Smaller chunks = faster first-audio, slight quality tradeoff at boundaries.
-    """
-    # Split on sentence boundaries AND commas/semicolons with reasonable length
-    parts = re.split(r'(?<=[.!?;])\s+', text.strip())
-
-    # Further split long parts on commas
-    chunks = []
-    for part in parts:
-        if len(part) <= max_chars:
-            chunks.append(part)
-        else:
-            # Split on commas
-            sub = re.split(r',\s+', part)
-            current = []
-            current_len = 0
-            for s in sub:
-                if current and current_len + len(s) > max_chars:
-                    chunks.append(', '.join(current))
-                    current = []
-                    current_len = 0
-                current.append(s)
-                current_len += len(s)
-            if current:
-                chunks.append(', '.join(current))
-
-    # Filter empty chunks
-    return [c.strip() for c in chunks if c.strip()]
+# ============================================================================
+# Inference Engine (transformers — replaces vLLM)
+# ============================================================================
+
+def format_prompt(text: str, voice: str) -> torch.Tensor:
+    """Format prompt for Orpheus finetuned model. Returns input_ids tensor."""
+    full_text = f"{voice}: {text}"
+    input_ids = llm_tokenizer(full_text, return_tensors="pt").input_ids
+
+    # Wrap with Orpheus special tokens
+    start = torch.tensor([[SPECIAL_TOKEN_START]], dtype=torch.long)
+    end = torch.tensor([SPECIAL_TOKENS_END], dtype=torch.long)
+    return torch.cat([start, input_ids, end], dim=1).to(llm_model.device)
+
+
+def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, None]:
+    """Stream token strings as the model generates them."""
+    from transformers import TextIteratorStreamer
+
+    input_ids = format_prompt(text, voice)
+
+    streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=False)
+
+    gen_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=MAX_TOKENS,
+        temperature=0.6,
+        top_p=0.8,
+        repetition_penalty=1.3,
+        do_sample=True,
+        eos_token_id=EOS_TOKEN_ID,
+        streamer=streamer,
+    )
+
+    thread = threading.Thread(target=llm_model.generate, kwargs=gen_kwargs, daemon=True)
+    thread.start()
+
+    for text_chunk in streamer:
+        yield text_chunk
+
+    thread.join()
+
+
+# ============================================================================
+# Token → Audio Pipeline
+# ============================================================================
+
+def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
+    """Extract SNAC audio codes from streamed text. Handles partial tokens across chunks."""
+    buffer = ""
+    for chunk in text_stream:
+        buffer += chunk
+        # Extract all complete <custom_token_XXXX> patterns
+        while True:
+            match = re.search(r'<custom_token_(\d+)>', buffer)
+            if not match:
+                break
+            token_id = int(match.group(1))
+            buffer = buffer[match.end():]
+            # Only yield actual audio codes (>= CODE_TOKEN_OFFSET)
+            if token_id >= CODE_TOKEN_OFFSET:
+                yield token_id - CODE_TOKEN_OFFSET
+
+
+def redistribute_codes(codes: list) -> list:
+    """Redistribute flat code list into SNAC's 3 hierarchical layers.
+    Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]"""
+    layer1, layer2, layer3 = [], [], []
+    for i in range(0, len(codes), SNAC_CHUNK_SIZE):
+        group = codes[i:i + SNAC_CHUNK_SIZE]
+        if len(group) < SNAC_CHUNK_SIZE:
+            break
+        layer1.append(group[0])
+        layer2.extend(group[1:3])
+        layer3.extend(group[3:7])
+    return [layer1, layer2, layer3]
+
+
+def snac_decode(codes: list) -> Optional[bytes]:
+    """Decode SNAC codes to PCM audio bytes."""
+    layers = redistribute_codes(codes)
+    if not layers[0]:
+        return None
+
+    with torch.no_grad():
+        codes_tensor = [
+            torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0)
+            for layer in layers
+        ]
+        audio_hat = snac_model.decode(codes_tensor)
+
+    audio_np = audio_hat.squeeze().cpu().numpy()
+    audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
+    return audio_int16.tobytes()
+
+
+def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
+    """Convert streaming SNAC codes to PCM audio chunks."""
+    buffer = []
+    group_count = 0
+    total_codes = 0
+
+    for code in code_stream:
+        buffer.append(code)
+        total_codes += 1
+
+        if total_codes % SNAC_CHUNK_SIZE == 0:
+            group_count += 1
+
+            if group_count >= SNAC_INITIAL_GROUPS and group_count % 1 == 0:
+                # Decode the last N groups
+                decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
+                codes_to_decode = buffer[-decode_size:]
+                audio = snac_decode(codes_to_decode)
+                if audio:
+                    # Yield only the NEW audio (avoid overlap)
+                    # Each group of 7 codes produces ~2048 samples
+                    new_samples = SNAC_CHUNK_SIZE * 293  # ~293 samples per code at 24kHz
+                    yield audio[-new_samples * 2:]  # *2 for 16-bit (2 bytes per sample)
+
+
+# ============================================================================
+# High-level generation functions
+# ============================================================================
+
+def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
+    """Full streaming pipeline: text → tokens → audio chunks."""
+    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
+    for chunk in chunk_text(text):
+        print(f"[stream] Generating: {chunk[:80]}...")
+        t0 = time.time()
+        first_audio = True
+        text_stream = generate_tokens_streaming(chunk, voice)
+        code_stream = extract_audio_codes(text_stream)
+        for pcm_chunk in decode_audio_stream(code_stream):
+            if first_audio:
+                print(f"[stream] First audio in {time.time() - t0:.2f}s")
+                first_audio = False
+            yield pcm_chunk
+    print("[stream] Done")


 def generate_speech_sync(text: str, voice: str) -> bytes:
-    """
-    Generate speech using Orpheus model (synchronous).
-
-    Args:
-        text: Text to convert (may include emotion tags)
-        voice: Voice name (built-in or custom)
-
-    Returns:
-        WAV audio bytes
-    """
-    global model
-    import numpy as np
-
-    # Check if it's a custom voice (needs reference audio)
-    custom_voice_path = VOICES_DIR / f"{voice}.wav"
-
-    if custom_voice_path.exists():
-        print(f"Custom voice '{voice}' - voice cloning to be implemented")
-        voice = DEFAULT_VOICE
-    elif voice not in BUILTIN_VOICES:
-        print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
-        voice = DEFAULT_VOICE
-
-    # Split long text into chunks for sequential generation
-    text_chunks = chunk_text(text)
-    print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
+    """Generate complete audio (for job system). Returns WAV bytes."""
+    voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE

     all_pcm = []
-    for chunk_idx, chunk in enumerate(text_chunks):
-        print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-        audio_chunks = []
-
-        syn_tokens = model.generate_speech(
-            prompt=chunk,
-            voice=voice,
-            max_tokens=MAX_TOKENS,
-        )
-
-        for i, audio_chunk in enumerate(syn_tokens):
-            audio_chunks.append(audio_chunk)
-
-        print(f"  -> {len(audio_chunks)} audio frames")
-
-        if audio_chunks:
-            all_pcm.append(b''.join(audio_chunks))
+    for chunk in chunk_text(text):
+        print(f"Generating: {chunk[:80]}...")
+        text_stream = generate_tokens_streaming(chunk, voice)
+        code_stream = extract_audio_codes(text_stream)
+        for pcm_chunk in decode_audio_stream(code_stream):
+            all_pcm.append(pcm_chunk)

     if not all_pcm:
-        raise ValueError("No audio chunks generated")
+        raise ValueError("No audio generated")

-    # Concatenate raw PCM from all chunks, wrap in single WAV
     audio_bytes_raw = b''.join(all_pcm)

     buffer = io.BytesIO()
     with wave.open(buffer, 'wb') as wf:
         wf.setnchannels(1)
-        wf.setsampwidth(2)  # 16-bit
+        wf.setsampwidth(2)
         wf.setframerate(SAMPLE_RATE)
         wf.writeframes(audio_bytes_raw)

-    print(f"Generated WAV: {len(buffer.getvalue())} bytes")
     return buffer.getvalue()


 def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:
-    """Save audio bytes to WAV file."""
     output_path = OUTPUT_DIR / f"{job_id}.wav"
     with open(output_path, 'wb') as f:
         f.write(audio_bytes)
@@ -303,16 +363,14 @@ def save_audio_to_file(job_id: str, audio_bytes: bytes) -> str:


 async def generate_speech_background(job_id: str, text: str, voice: str):
-    """Background task for speech generation (async)."""
+    """Background task for speech generation."""
     try:
         jobs[job_id].status = JobStatus.PROCESSING
         jobs[job_id].progress = 25
         save_jobs_to_disk()

-        # Check cache first
         cache_key = hash_text_voice(text, voice)
         cached_path = get_from_cache(cache_key)

         if cached_path:
             jobs[job_id].audio_path = cached_path
             jobs[job_id].status = JobStatus.SUCCESS
@@ -320,34 +378,25 @@ async def generate_speech_background(job_id: str, text: str, voice: str):
             jobs[job_id].cached = True
             jobs[job_id].completed_at = datetime.now().isoformat()
             save_jobs_to_disk()
-            print(f"Job {job_id} completed from cache")
             return

-        # Generate audio - call sync function directly (blocks but let's test if it works)
         jobs[job_id].progress = 50
         save_jobs_to_disk()

-        print(f"Generating audio for job {job_id}...")
-        audio_bytes = generate_speech_sync(text, voice)
+        audio_bytes = await asyncio.to_thread(generate_speech_sync, text, voice)

-        # Save to file
         jobs[job_id].progress = 75
         save_jobs_to_disk()

         output_path = save_audio_to_file(job_id, audio_bytes)

-        # Save to cache
         save_to_cache(cache_key, output_path)

-        # Complete
         jobs[job_id].audio_path = output_path
         jobs[job_id].status = JobStatus.SUCCESS
         jobs[job_id].progress = 100
         jobs[job_id].completed_at = datetime.now().isoformat()
         save_jobs_to_disk()

-        print(f"Job {job_id} completed successfully")
-
     except Exception as e:
         print(f"Job {job_id} failed: {e}")
         import traceback
@@ -358,12 +407,10 @@ async def generate_speech_background(job_id: str, text: str, voice: str):


 async def cleanup_old_jobs():
-    """Background task to cleanup old jobs and files."""
     while True:
         try:
             await asyncio.sleep(CLEANUP_INTERVAL_HOURS * 3600)
             cutoff = datetime.now() - timedelta(days=RETENTION_DAYS)

             to_delete = []
             for job_id, job in jobs.items():
                 try:
@@ -374,76 +421,50 @@ async def cleanup_old_jobs():
                     to_delete.append(job_id)
                 except:
                     pass

             for job_id in to_delete:
                 del jobs[job_id]

             if to_delete:
                 save_jobs_to_disk()
                 print(f"Cleanup: deleted {len(to_delete)} old jobs")

         except Exception as e:
             print(f"Error in cleanup task: {e}")


+# ============================================================================
+# Startup
+# ============================================================================
+
 @app.on_event("startup")
 async def startup():
-    """Load model and jobs on startup"""
-    global model
+    global llm_model, llm_tokenizer, snac_model

     print("=" * 60)
-    print("OrpheusTail - Orpheus TTS Service Starting")
+    print("OrpheusTail v2 — transformers streaming engine")
     print(f"Model: {ORPHEUS_MODEL}")
-    print(f"Max Model Len: {MAX_MODEL_LEN}")
     print(f"Max Tokens: {MAX_TOKENS}")
     print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
     print(f"Default Voice: {DEFAULT_VOICE}")
     print("=" * 60)

-    # Import and load Orpheus model
-    print("Loading Orpheus model (this may take a moment)...")
-    from orpheus_tts import OrpheusModel
-    from vllm import AsyncLLMEngine
-    from vllm.engine.arg_utils import AsyncEngineArgs
-
-    # Monkey-patch OrpheusModel to use sync LLM (AsyncLLMEngine hangs on Jetson)
-    original_setup_engine = OrpheusModel._setup_engine
-    def patched_setup_engine(self):
-        model_name = self._map_model_params(self.model_name)
-        from vllm import LLM
-        return LLM(
-            model=model_name,
-            max_model_len=MAX_MODEL_LEN,
-            gpu_memory_utilization=0.85,
-            enforce_eager=False,
-        )
-    OrpheusModel._setup_engine = patched_setup_engine
-
-    # Sync token generation
-    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
-                                     temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
-                                     stop_token_ids=[49158], repetition_penalty=1.3):
-        from vllm import SamplingParams
-        import re
-        prompt_string = self._format_prompt(prompt, voice)
-        print(prompt)
-        sampling_params = SamplingParams(
-            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
-            stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
-        )
-        outputs = self.engine.generate([prompt_string], sampling_params)
-        for output in outputs:
-            text = output.outputs[0].text
-            print(f"Raw output (first 500 chars): {text[:500]}")
-            tokens = re.findall(r'<custom_token_\d+>', text)
-            print(f"Found {len(tokens)} tokens")
-            for token in tokens:
-                yield token
-    OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
-
-    model = OrpheusModel(model_name=ORPHEUS_MODEL)
-
-    print("✓ Orpheus model loaded successfully")
+    # Load LLM
+    print("Loading Orpheus LLM...")
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    llm_tokenizer = AutoTokenizer.from_pretrained(ORPHEUS_MODEL)
+    llm_model = AutoModelForCausalLM.from_pretrained(
+        ORPHEUS_MODEL,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    llm_model.eval()
+    print(f"✓ LLM loaded ({ORPHEUS_MODEL})")
+
+    # Load SNAC decoder
+    print("Loading SNAC audio codec...")
+    from snac import SNAC
+    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
+    if torch.cuda.is_available():
+        snac_model = snac_model.to("cuda")
+    print("✓ SNAC loaded")

     # Load jobs from disk
     load_jobs_from_disk()
@@ -451,29 +472,26 @@ async def startup():
     # Start cleanup task
     asyncio.create_task(cleanup_old_jobs())

+    print("✓ OrpheusTail v2 ready")

-# === Pydantic Models ===

+# ============================================================================
+# Pydantic Models
+# ============================================================================

 class TTSRequest(BaseModel):
-    """TTS job submission request"""
     text: str
     voice: str = DEFAULT_VOICE


 class TTSStreamRequest(BaseModel):
-    """TTS streaming request (for head playback)"""
     text: str
     voice: str = DEFAULT_VOICE


 class JobResponse(BaseModel):
-    """Job submission response"""
     job_id: str
     status: str


 class StatusResponse(BaseModel):
-    """Job status response"""
     job_id: str
     status: str
     progress: int
@@ -481,23 +499,23 @@ class StatusResponse(BaseModel):
     audio_url: Optional[str] = None
     error: Optional[str] = None


 class VoicesResponse(BaseModel):
-    """Available voices response"""
     builtin: List[str]
     custom: List[str]
     default: str
     emotion_tags: List[str]


-# === Endpoints ===
+# ============================================================================
+# Endpoints
+# ============================================================================

 @app.get("/")
 def root():
-    """Root endpoint"""
     return {
         "service": "OrpheusTail - Orpheus TTS Service",
-        "version": "1.0.0",
+        "version": "2.0.0",
+        "engine": "transformers (streaming)",
         "model": ORPHEUS_MODEL,
         "default_voice": DEFAULT_VOICE,
         "emotion_tags": EMOTION_TAGS,
@@ -505,28 +523,25 @@ def root():
             "/tts/submit": "POST - Submit TTS job",
             "/tts/status/{job_id}": "GET - Check job status",
             "/tts/audio/{job_id}": "GET - Download audio",
-            "/tts/stream": "POST - Stream audio (for head)",
+            "/tts/stream": "POST - Stream audio (true token-level)",
             "/voice/clone": "POST - Upload voice reference",
             "/voices": "GET - List available voices",
             "/health": "GET - Health check"
         }
     }


 @app.get("/health")
 def health():
-    """Health check"""
     return {
         "status": "healthy",
-        "model_loaded": model is not None,
+        "model_loaded": llm_model is not None and snac_model is not None,
+        "engine": "transformers",
         "cache_enabled": CACHE_ENABLED,
         "voices_available": len(BUILTIN_VOICES) + len(get_custom_voices())
     }


 @app.get("/voices", response_model=VoicesResponse)
 def list_voices():
-    """List all available voices"""
     return VoicesResponse(
         builtin=BUILTIN_VOICES,
         custom=get_custom_voices(),
@@ -535,197 +550,113 @@ def list_voices():
     )


+# --- TTS endpoints ---
+
 @app.post("/tts/submit", response_model=JobResponse)
 async def submit_tts_job(request: TTSRequest):
-    """Submit a TTS job for processing."""
     job_id = str(uuid.uuid4())

     job = JobInfo(
-        job_id=job_id,
-        text=request.text,
-        voice=request.voice,
-        status=JobStatus.PENDING,
-        progress=0,
+        job_id=job_id, text=request.text, voice=request.voice,
+        status=JobStatus.PENDING, progress=0,
         created_at=datetime.now().isoformat()
     )

     jobs[job_id] = job
     save_jobs_to_disk()
-
-    # Use asyncio.create_task for proper async execution
-    asyncio.create_task(
-        generate_speech_background(job_id, request.text, request.voice)
-    )
-
-    print(f"Job {job_id} submitted: '{request.text[:50]}...' with voice '{request.voice}'")
+    asyncio.create_task(generate_speech_background(job_id, request.text, request.voice))

     return JobResponse(job_id=job_id, status=JobStatus.PENDING)


 @app.get("/tts/status/{job_id}", response_model=StatusResponse)
 async def get_job_status(job_id: str):
-    """Get status of a TTS job."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")

     job = jobs[job_id]

     response = StatusResponse(
-        job_id=job_id,
-        status=job.status,
-        progress=job.progress,
-        cached=job.cached
-    )
+        job_id=job_id, status=job.status, progress=job.progress, cached=job.cached)

     if job.status == JobStatus.SUCCESS:
         response.audio_url = f"/tts/audio/{job_id}"
     elif job.status == JobStatus.FAILURE:
         response.error = job.error

     return response


 @app.get("/tts/audio/{job_id}")
 async def get_audio(job_id: str):
-    """Retrieve generated audio file."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")

     job = jobs[job_id]

     if job.status != JobStatus.SUCCESS:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Audio not ready. Job status: {job.status}"
-        )
+        raise HTTPException(status_code=400, detail=f"Audio not ready. Status: {job.status}")

     if not job.audio_path or not Path(job.audio_path).exists():
         raise HTTPException(status_code=404, detail="Audio file not found")

-    return FileResponse(
-        job.audio_path,
-        media_type="audio/wav",
-        filename=f"{job_id}.wav"
-    )
+    return FileResponse(job.audio_path, media_type="audio/wav", filename=f"{job_id}.wav")


 @app.post("/tts/stream")
 async def stream_tts(request: TTSStreamRequest):
     """
-    Stream TTS audio with sentence-level chunking.
-
-    Splits text into small chunks (sentences/clauses) and generates each
-    independently. First chunk's audio starts playing while later chunks
-    are still generating. Reduces perceived latency significantly.
+    Stream TTS audio with true token-level streaming.
+    First audio arrives after ~28 tokens (~1-2 seconds).
     """
-    global model
-
-    if model is None:
+    if llm_model is None or snac_model is None:
         raise HTTPException(status_code=503, detail="Model not loaded")

-    voice = request.voice
-    if voice not in BUILTIN_VOICES:
-        voice = DEFAULT_VOICE
-
-    def sync_audio_generator():
-        """Generate audio per-sentence, yielding as each finishes."""
-        try:
-            # Split into fine-grained chunks for faster first-audio
-            text_chunks = chunk_text_fine(request.text)
-            print(f"[stream] {len(text_chunks)} chunk(s): {[c[:40] for c in text_chunks]}")
-            for chunk_idx, chunk in enumerate(text_chunks):
-                print(f"  Generating chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:60]}...")
-                syn_tokens = model.generate_speech(
-                    prompt=chunk,
-                    voice=voice,
-                    max_tokens=MAX_TOKENS,
-                )
-                for audio_chunk in syn_tokens:
-                    yield audio_chunk
-        except Exception as e:
-            print(f"Stream error: {e}")
-            raise
-
-    return StreamingResponse(
-        sync_audio_generator(),
-        media_type="audio/pcm"
-    )
+    def audio_generator():
+        for pcm_chunk in generate_stream(request.text, request.voice):
+            yield pcm_chunk
+
+    return StreamingResponse(audio_generator(), media_type="audio/pcm")


+# --- Voice endpoints ---
+
 @app.post("/voice/clone")
-async def upload_voice_reference(
-    name: str,
-    audio: UploadFile = File(...),
-):
-    """
-    Upload a reference audio file for voice cloning.
-
-    Args:
-        name: Name for this custom voice
-        audio: WAV audio file (5-30 seconds recommended)
-    """
+async def upload_voice_reference(name: str, audio: UploadFile = File(...)):
+    """Upload reference audio for voice cloning (saved for future LoRA fine-tuning)."""
     if not name.isalnum():
         raise HTTPException(status_code=400, detail="Voice name must be alphanumeric")

     if name in BUILTIN_VOICES:
         raise HTTPException(status_code=400, detail="Cannot overwrite built-in voice")

-    # Save the reference audio
     voice_path = VOICES_DIR / f"{name}.wav"

     try:
         content = await audio.read()
         with open(voice_path, 'wb') as f:
             f.write(content)

         return {
-            "status": "success",
+            "status": "saved",
             "voice_name": name,
-            "message": f"Voice '{name}' saved. Use voice='{name}' in TTS requests."
+            "message": f"Voice '{name}' saved. Note: the finetuned model uses 8 built-in voices. "
+                       f"Custom voice cloning requires LoRA fine-tuning (coming soon).",
+            "builtin_voices": BUILTIN_VOICES,
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to save voice: {e}")


 @app.delete("/voice/{name}")
 async def delete_voice(name: str):
-    """Delete a custom voice."""
     if name in BUILTIN_VOICES:
         raise HTTPException(status_code=400, detail="Cannot delete built-in voice")

     voice_path = VOICES_DIR / f"{name}.wav"
     if not voice_path.exists():
         raise HTTPException(status_code=404, detail="Voice not found")

     voice_path.unlink()
     return {"status": "success", "message": f"Voice '{name}' deleted"}


 @app.delete("/tts/job/{job_id}")
 async def delete_job(job_id: str):
-    """Delete a job and its audio file."""
     if job_id not in jobs:
         raise HTTPException(status_code=404, detail="Job not found")

     job = jobs[job_id]

     if job.audio_path and Path(job.audio_path).exists():
         try:
             Path(job.audio_path).unlink()
         except:
             pass

     del jobs[job_id]
     save_jobs_to_disk()

     return {"message": f"Job {job_id} deleted"}


 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=8766,  # Same port as VoiceTail for drop-in replacement
-        reload=False
-    )
+    uvicorn.run("main:app", host="0.0.0.0", port=8766, reload=False)
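To make the 7-token SNAC grouping above concrete, a small worked example of redistribute_codes (illustrative values only, not part of the commit):

    # One flat group of 7 codes maps onto SNAC's 3 layers as
    # [L1, L2a, L2b, L3a, L3b, L3c, L3d]:
    codes = [10, 20, 21, 30, 31, 32, 33,   # group 1
             11, 22, 23, 34, 35, 36, 37]   # group 2
    layers = redistribute_codes(codes)
    assert layers == [
        [10, 11],                          # layer 1: 1 coarse code per group
        [20, 21, 22, 23],                  # layer 2: 2 codes per group
        [30, 31, 32, 33, 34, 35, 36, 37],  # layer 3: 4 codes per group
    ]

Per the in-code comments, each 7-code group decodes to ~2048 samples, i.e. ~85 ms at 24 kHz, so SNAC_INITIAL_GROUPS = 6 buffers roughly half a second of audio before the first decode.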