diff --git a/Dockerfile b/Dockerfile
index bda7673..4a6bda6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -42,7 +42,7 @@ ENV OUTPUT_DIR=/app/output
 ENV VOICES_DIR=/app/voices
 ENV ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
 ENV DEFAULT_VOICE=tara
-ENV MAX_MODEL_LEN=2048
+ENV MAX_MODEL_LEN=8192
 
 # Health check (longer start period - model loading takes time)
 HEALTHCHECK --interval=30s --timeout=10s --start-period=180s --retries=3 \
diff --git a/docker-compose.yml b/docker-compose.yml
index 2c933d1..53d52ea 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -36,7 +36,7 @@ services:
     environment:
       - ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
       - DEFAULT_VOICE=tara
-      - MAX_MODEL_LEN=1024
+      - MAX_MODEL_LEN=8192
       - CACHE_ENABLED=true
       - RETENTION_DAYS=10
-      - HF_TOKEN=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
+      - HF_TOKEN=${HF_TOKEN}
diff --git a/main.py b/main.py
index 6996fca..367b2e1 100644
--- a/main.py
+++ b/main.py
@@ -22,6 +22,7 @@ Endpoints:
 """
 
 import os
+import re
 import json
 import hashlib
 import asyncio
@@ -47,7 +48,8 @@ VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices"))  # For cloned voice referen
 RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
 CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
 DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara")  # Orpheus default voice
-MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "2048"))
+MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
 SAMPLE_RATE = 24000
 
 # Ensure directories exist
@@ -164,6 +166,37 @@ def get_custom_voices() -> List[str]:
     return voices
 
+
+def chunk_text(text: str, max_chars: int = 800) -> List[str]:
+    """
+    Split text into chunks at sentence boundaries for sequential TTS generation.
+
+    Ensures no chunk exceeds max_chars (unless a single sentence is longer).
+    Preserves emotion tags within sentences.
+    """
+    # Split on sentence-ending punctuation followed by whitespace
+    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
+
+    chunks = []
+    current_chunk = []
+    current_len = 0
+
+    for sentence in sentences:
+        sentence_len = len(sentence)
+        # If adding this sentence would exceed the limit, finalize current chunk
+        if current_chunk and current_len + 1 + sentence_len > max_chars:
+            chunks.append(' '.join(current_chunk))
+            current_chunk = []
+            current_len = 0
+        current_chunk.append(sentence)
+        current_len += sentence_len + (1 if current_len > 0 else 0)
+
+    # Don't forget the last chunk
+    if current_chunk:
+        chunks.append(' '.join(current_chunk))
+
+    return chunks if chunks else [text]
 
 
 def generate_speech_sync(text: str, voice: str) -> bytes:
     """
     Generate speech using Orpheus model (synchronous).
@@ -188,41 +221,43 @@
         print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
         voice = DEFAULT_VOICE
 
 
-    print(f"Generating: {text}")
-
-    audio_chunks = []
-
-    # Call model directly - it returns a generator
-    syn_tokens = model.generate_speech(
-        prompt=text,
-        voice=voice,
-        max_tokens=4000,  # Increased from default 1200 for longer texts
-    )
-
-    print(f"Got generator: {type(syn_tokens)}")
-
-    # Iterate over generator
-    for i, audio_chunk in enumerate(syn_tokens):
-        print(f"Chunk {i}: {type(audio_chunk)}, shape: {audio_chunk.shape if hasattr(audio_chunk, 'shape') else 'N/A'}")
-        audio_chunks.append(audio_chunk)
-
-    print(f"Total chunks: {len(audio_chunks)}")
-
-    # Chunks are raw int16 bytes from SNAC decoder - just concatenate
-    if len(audio_chunks) == 0:
+    # Split long text into chunks for sequential generation
+    text_chunks = chunk_text(text)
+    print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
+
+    all_pcm = []
+
+    for chunk_idx, chunk in enumerate(text_chunks):
+        print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
+        audio_chunks = []
+
+        syn_tokens = model.generate_speech(
+            prompt=chunk,
+            voice=voice,
+            max_tokens=MAX_TOKENS,
+        )
+
+        for i, audio_chunk in enumerate(syn_tokens):
+            audio_chunks.append(audio_chunk)
+
+        print(f"    -> {len(audio_chunks)} audio frames")
+
+        if audio_chunks:
+            all_pcm.append(b''.join(audio_chunks))
+
+    if not all_pcm:
         raise ValueError("No audio chunks generated")
-
-    # Concatenate bytes directly
-    audio_bytes_raw = b''.join(audio_chunks)
-
-    # Convert to WAV bytes
+
+    # Concatenate raw PCM from all chunks, wrap in single WAV
+    audio_bytes_raw = b''.join(all_pcm)
+
     buffer = io.BytesIO()
     with wave.open(buffer, 'wb') as wf:
         wf.setnchannels(1)
         wf.setsampwidth(2)  # 16-bit
         wf.setframerate(SAMPLE_RATE)
         wf.writeframes(audio_bytes_raw)
-
+
     print(f"Generated WAV: {len(buffer.getvalue())} bytes")
     return buffer.getvalue()
@@ -328,6 +363,7 @@ async def startup():
     print("OrpheusTail - Orpheus TTS Service Starting")
     print(f"Model: {ORPHEUS_MODEL}")
     print(f"Max Model Len: {MAX_MODEL_LEN}")
+    print(f"Max Tokens: {MAX_TOKENS}")
     print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
     print(f"Default Voice: {DEFAULT_VOICE}")
     print("=" * 60)
@@ -354,7 +390,7 @@
     OrpheusModel._setup_engine = patched_setup_engine
 
     # Also patch generate_tokens_sync to work with sync LLM
-    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=4000, stop_token_ids=[49158], repetition_penalty=1.3):
+    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
         from vllm import SamplingParams
         import re
         prompt_string = self._format_prompt(prompt, voice)
@@ -565,17 +601,18 @@
         voice = DEFAULT_VOICE
 
     def sync_audio_generator():
-        """Generate audio chunks (sync generator)"""
+        """Generate audio chunks (sync generator), chunking long text."""
         try:
-            syn_tokens = model.generate_speech(
-                prompt=request.text,
-                voice=voice,
-                max_tokens=4000,  # Increased from default 1200 for longer texts
-            )
-
-            for audio_chunk in syn_tokens:
-                yield audio_chunk
-
+            text_chunks = chunk_text(request.text)
+            for chunk_idx, chunk in enumerate(text_chunks):
+                print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
+                syn_tokens = model.generate_speech(
+                    prompt=chunk,
+                    voice=voice,
+                    max_tokens=MAX_TOKENS,
+                )
+                for audio_chunk in syn_tokens:
+                    yield audio_chunk
         except Exception as e:
             print(f"Stream error: {e}")
             raise