token limit and chunking
This commit is contained in:
@@ -42,7 +42,7 @@ ENV OUTPUT_DIR=/app/output
|
||||
ENV VOICES_DIR=/app/voices
|
||||
ENV ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
||||
ENV DEFAULT_VOICE=tara
|
||||
ENV MAX_MODEL_LEN=2048
|
||||
ENV MAX_MODEL_LEN=8192
|
||||
|
||||
# Health check (longer start period - model loading takes time)
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=180s --retries=3 \
|
||||
|
||||
@@ -36,7 +36,7 @@ services:
|
||||
environment:
|
||||
- ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
||||
- DEFAULT_VOICE=tara
|
||||
- MAX_MODEL_LEN=1024
|
||||
- MAX_MODEL_LEN=8192
|
||||
- CACHE_ENABLED=true
|
||||
- RETENTION_DAYS=10
|
||||
- HF_TOKEN=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
|
||||
|
||||
79
main.py
79
main.py
@@ -22,6 +22,7 @@ Endpoints:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import hashlib
|
||||
import asyncio
|
||||
@@ -47,7 +48,8 @@ VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices")) # For cloned voice referen
|
||||
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
|
||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
|
||||
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") # Orpheus default voice
|
||||
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "2048"))
|
||||
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
|
||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
# Ensure directories exist
|
||||
@@ -164,6 +166,37 @@ def get_custom_voices() -> List[str]:
|
||||
return voices
|
||||
|
||||
|
||||
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
|
||||
"""
|
||||
Split text into chunks at sentence boundaries for sequential TTS generation.
|
||||
|
||||
Ensures no chunk exceeds max_chars (unless a single sentence is longer).
|
||||
Preserves emotion tags within sentences.
|
||||
"""
|
||||
# Split on sentence-ending punctuation followed by whitespace
|
||||
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
||||
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_len = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_len = len(sentence)
|
||||
# If adding this sentence would exceed the limit, finalize current chunk
|
||||
if current_chunk and current_len + 1 + sentence_len > max_chars:
|
||||
chunks.append(' '.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_len = 0
|
||||
current_chunk.append(sentence)
|
||||
current_len += sentence_len + (1 if current_len > 0 else 0)
|
||||
|
||||
# Don't forget the last chunk
|
||||
if current_chunk:
|
||||
chunks.append(' '.join(current_chunk))
|
||||
|
||||
return chunks if chunks else [text]
|
||||
|
||||
|
||||
def generate_speech_sync(text: str, voice: str) -> bytes:
|
||||
"""
|
||||
Generate speech using Orpheus model (synchronous).
|
||||
@@ -188,34 +221,36 @@ def generate_speech_sync(text: str, voice: str) -> bytes:
|
||||
print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
|
||||
voice = DEFAULT_VOICE
|
||||
|
||||
print(f"Generating: {text}")
|
||||
# Split long text into chunks for sequential generation
|
||||
text_chunks = chunk_text(text)
|
||||
print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
|
||||
|
||||
all_pcm = []
|
||||
|
||||
for chunk_idx, chunk in enumerate(text_chunks):
|
||||
print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
|
||||
audio_chunks = []
|
||||
|
||||
# Call model directly - it returns a generator
|
||||
syn_tokens = model.generate_speech(
|
||||
prompt=text,
|
||||
prompt=chunk,
|
||||
voice=voice,
|
||||
max_tokens=4000, # Increased from default 1200 for longer texts
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
print(f"Got generator: {type(syn_tokens)}")
|
||||
|
||||
# Iterate over generator
|
||||
for i, audio_chunk in enumerate(syn_tokens):
|
||||
print(f"Chunk {i}: {type(audio_chunk)}, shape: {audio_chunk.shape if hasattr(audio_chunk, 'shape') else 'N/A'}")
|
||||
audio_chunks.append(audio_chunk)
|
||||
|
||||
print(f"Total chunks: {len(audio_chunks)}")
|
||||
print(f" -> {len(audio_chunks)} audio frames")
|
||||
|
||||
# Chunks are raw int16 bytes from SNAC decoder - just concatenate
|
||||
if len(audio_chunks) == 0:
|
||||
if audio_chunks:
|
||||
all_pcm.append(b''.join(audio_chunks))
|
||||
|
||||
if not all_pcm:
|
||||
raise ValueError("No audio chunks generated")
|
||||
|
||||
# Concatenate bytes directly
|
||||
audio_bytes_raw = b''.join(audio_chunks)
|
||||
# Concatenate raw PCM from all chunks, wrap in single WAV
|
||||
audio_bytes_raw = b''.join(all_pcm)
|
||||
|
||||
# Convert to WAV bytes
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, 'wb') as wf:
|
||||
wf.setnchannels(1)
|
||||
@@ -328,6 +363,7 @@ async def startup():
|
||||
print("OrpheusTail - Orpheus TTS Service Starting")
|
||||
print(f"Model: {ORPHEUS_MODEL}")
|
||||
print(f"Max Model Len: {MAX_MODEL_LEN}")
|
||||
print(f"Max Tokens: {MAX_TOKENS}")
|
||||
print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
|
||||
print(f"Default Voice: {DEFAULT_VOICE}")
|
||||
print("=" * 60)
|
||||
@@ -354,7 +390,7 @@ async def startup():
|
||||
OrpheusModel._setup_engine = patched_setup_engine
|
||||
|
||||
# Also patch generate_tokens_sync to work with sync LLM
|
||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=4000, stop_token_ids=[49158], repetition_penalty=1.3):
|
||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
|
||||
from vllm import SamplingParams
|
||||
import re
|
||||
prompt_string = self._format_prompt(prompt, voice)
|
||||
@@ -565,17 +601,18 @@ async def stream_tts(request: TTSStreamRequest):
|
||||
voice = DEFAULT_VOICE
|
||||
|
||||
def sync_audio_generator():
|
||||
"""Generate audio chunks (sync generator)"""
|
||||
"""Generate audio chunks (sync generator), chunking long text."""
|
||||
try:
|
||||
text_chunks = chunk_text(request.text)
|
||||
for chunk_idx, chunk in enumerate(text_chunks):
|
||||
print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
|
||||
syn_tokens = model.generate_speech(
|
||||
prompt=request.text,
|
||||
prompt=chunk,
|
||||
voice=voice,
|
||||
max_tokens=4000, # Increased from default 1200 for longer texts
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
for audio_chunk in syn_tokens:
|
||||
yield audio_chunk
|
||||
|
||||
except Exception as e:
|
||||
print(f"Stream error: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user