token limit and chunking

This commit is contained in:
Alex
2026-02-06 10:07:05 -06:00
parent 75a5fc0a95
commit 14af1d0600
3 changed files with 79 additions and 42 deletions

View File

@@ -42,7 +42,7 @@ ENV OUTPUT_DIR=/app/output
ENV VOICES_DIR=/app/voices
ENV ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
ENV DEFAULT_VOICE=tara
ENV MAX_MODEL_LEN=2048
ENV MAX_MODEL_LEN=8192
# Health check (longer start period - model loading takes time)
HEALTHCHECK --interval=30s --timeout=10s --start-period=180s --retries=3 \

View File

@@ -36,7 +36,7 @@ services:
environment:
- ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
- DEFAULT_VOICE=tara
- MAX_MODEL_LEN=1024
- MAX_MODEL_LEN=8192
- CACHE_ENABLED=true
- RETENTION_DAYS=10
- HF_TOKEN=${HF_TOKEN}  # SECURITY: a real Hugging Face token was hardcoded here and committed — revoke that token immediately and inject it via the environment or a secrets file instead

117
main.py
View File

@@ -22,6 +22,7 @@ Endpoints:
"""
import os
import re
import json
import hashlib
import asyncio
@@ -47,7 +48,8 @@ VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices")) # For cloned voice referen
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") # Orpheus default voice
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "2048"))
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
SAMPLE_RATE = 24000
# Ensure directories exist
@@ -164,6 +166,37 @@ def get_custom_voices() -> List[str]:
return voices
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """
    Split text into chunks at sentence boundaries for sequential TTS generation.

    Ensures no chunk exceeds max_chars (unless a single sentence is longer,
    in which case that sentence becomes its own oversized chunk).
    Preserves emotion tags within sentences, since splits only happen at
    sentence-ending punctuation.

    Args:
        text: The input text to split.
        max_chars: Soft upper bound on chunk length, in characters.

    Returns:
        A non-empty list of text chunks. If no usable sentences are found
        (e.g. empty or whitespace-only input), returns [text] unchanged.
    """
    # Split on sentence-ending punctuation followed by whitespace.
    # Filter out empty fragments so blank/whitespace-only input does not
    # produce an empty-string chunk that would be sent to the model.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
    chunks = []
    current_chunk = []
    current_len = 0
    for sentence in sentences:
        sentence_len = len(sentence)
        # +1 accounts for the joining space added between sentences.
        # If adding this sentence would exceed the limit, finalize current chunk.
        if current_chunk and current_len + 1 + sentence_len > max_chars:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_len = 0
        current_chunk.append(sentence)
        current_len += sentence_len + (1 if current_len > 0 else 0)
    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks if chunks else [text]
def generate_speech_sync(text: str, voice: str) -> bytes:
"""
Generate speech using Orpheus model (synchronous).
@@ -188,41 +221,43 @@ def generate_speech_sync(text: str, voice: str) -> bytes:
print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
voice = DEFAULT_VOICE
print(f"Generating: {text}")
audio_chunks = []
# Call model directly - it returns a generator
syn_tokens = model.generate_speech(
prompt=text,
voice=voice,
max_tokens=4000, # Increased from default 1200 for longer texts
)
print(f"Got generator: {type(syn_tokens)}")
# Iterate over generator
for i, audio_chunk in enumerate(syn_tokens):
print(f"Chunk {i}: {type(audio_chunk)}, shape: {audio_chunk.shape if hasattr(audio_chunk, 'shape') else 'N/A'}")
audio_chunks.append(audio_chunk)
print(f"Total chunks: {len(audio_chunks)}")
# Chunks are raw int16 bytes from SNAC decoder - just concatenate
if len(audio_chunks) == 0:
# Split long text into chunks for sequential generation
text_chunks = chunk_text(text)
print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
all_pcm = []
for chunk_idx, chunk in enumerate(text_chunks):
print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
audio_chunks = []
syn_tokens = model.generate_speech(
prompt=chunk,
voice=voice,
max_tokens=MAX_TOKENS,
)
for i, audio_chunk in enumerate(syn_tokens):
audio_chunks.append(audio_chunk)
print(f" -> {len(audio_chunks)} audio frames")
if audio_chunks:
all_pcm.append(b''.join(audio_chunks))
if not all_pcm:
raise ValueError("No audio chunks generated")
# Concatenate bytes directly
audio_bytes_raw = b''.join(audio_chunks)
# Convert to WAV bytes
# Concatenate raw PCM from all chunks, wrap in single WAV
audio_bytes_raw = b''.join(all_pcm)
buffer = io.BytesIO()
with wave.open(buffer, 'wb') as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(SAMPLE_RATE)
wf.writeframes(audio_bytes_raw)
print(f"Generated WAV: {len(buffer.getvalue())} bytes")
return buffer.getvalue()
@@ -328,6 +363,7 @@ async def startup():
print("OrpheusTail - Orpheus TTS Service Starting")
print(f"Model: {ORPHEUS_MODEL}")
print(f"Max Model Len: {MAX_MODEL_LEN}")
print(f"Max Tokens: {MAX_TOKENS}")
print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
print(f"Default Voice: {DEFAULT_VOICE}")
print("=" * 60)
@@ -354,7 +390,7 @@ async def startup():
OrpheusModel._setup_engine = patched_setup_engine
# Also patch generate_tokens_sync to work with sync LLM
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=4000, stop_token_ids=[49158], repetition_penalty=1.3):
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
from vllm import SamplingParams
import re
prompt_string = self._format_prompt(prompt, voice)
@@ -565,17 +601,18 @@ async def stream_tts(request: TTSStreamRequest):
voice = DEFAULT_VOICE
def sync_audio_generator():
"""Generate audio chunks (sync generator)"""
"""Generate audio chunks (sync generator), chunking long text."""
try:
syn_tokens = model.generate_speech(
prompt=request.text,
voice=voice,
max_tokens=4000, # Increased from default 1200 for longer texts
)
for audio_chunk in syn_tokens:
yield audio_chunk
text_chunks = chunk_text(request.text)
for chunk_idx, chunk in enumerate(text_chunks):
print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
syn_tokens = model.generate_speech(
prompt=chunk,
voice=voice,
max_tokens=MAX_TOKENS,
)
for audio_chunk in syn_tokens:
yield audio_chunk
except Exception as e:
print(f"Stream error: {e}")
raise