Raise token limits and add sentence-boundary text chunking for long TTS inputs
This commit is contained in:
@@ -42,7 +42,7 @@ ENV OUTPUT_DIR=/app/output
|
|||||||
ENV VOICES_DIR=/app/voices
|
ENV VOICES_DIR=/app/voices
|
||||||
ENV ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
ENV ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
||||||
ENV DEFAULT_VOICE=tara
|
ENV DEFAULT_VOICE=tara
|
||||||
ENV MAX_MODEL_LEN=2048
|
ENV MAX_MODEL_LEN=8192
|
||||||
|
|
||||||
# Health check (longer start period - model loading takes time)
|
# Health check (longer start period - model loading takes time)
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=180s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=180s --retries=3 \
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
- ORPHEUS_MODEL=canopylabs/orpheus-tts-0.1-finetune-prod
|
||||||
- DEFAULT_VOICE=tara
|
- DEFAULT_VOICE=tara
|
||||||
- MAX_MODEL_LEN=1024
|
- MAX_MODEL_LEN=8192
|
||||||
- CACHE_ENABLED=true
|
- CACHE_ENABLED=true
|
||||||
- RETENTION_DAYS=10
|
- RETENTION_DAYS=10
|
||||||
      - HF_TOKEN=${HF_TOKEN}  # SECURITY: a real token was previously committed here — revoke it on huggingface.co and supply via the environment instead
|
      - HF_TOKEN=${HF_TOKEN}
|
||||||
|
|||||||
99
main.py
99
main.py
@@ -22,6 +22,7 @@ Endpoints:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -47,7 +48,8 @@ VOICES_DIR = Path(os.getenv("VOICES_DIR", "voices")) # For cloned voice referen
|
|||||||
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
|
RETENTION_DAYS = int(os.getenv("RETENTION_DAYS", "10"))
|
||||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
|
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "1"))
|
||||||
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") # Orpheus default voice
|
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "tara") # Orpheus default voice
|
||||||
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "2048"))
|
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
|
||||||
|
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "8000"))
|
||||||
SAMPLE_RATE = 24000
|
SAMPLE_RATE = 24000
|
||||||
|
|
||||||
# Ensure directories exist
|
# Ensure directories exist
|
||||||
@@ -164,6 +166,37 @@ def get_custom_voices() -> List[str]:
|
|||||||
return voices
|
return voices
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text(text: str, max_chars: int = 800) -> List[str]:
    """
    Split text into chunks at sentence boundaries for sequential TTS generation.

    Ensures no chunk exceeds max_chars (unless a single sentence is longer).
    Preserves emotion tags within sentences.
    """
    # Sentence-ending punctuation followed by whitespace marks a boundary.
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())

    result: List[str] = []
    buffer: List[str] = []
    buffered_len = 0

    for piece in pieces:
        # Flush the buffer when appending (plus one joining space) would overflow.
        if buffer and buffered_len + 1 + len(piece) > max_chars:
            result.append(' '.join(buffer))
            buffer = []
            buffered_len = 0
        buffer.append(piece)
        buffered_len += len(piece) + (1 if buffered_len else 0)

    # Emit whatever remains after the last sentence.
    if buffer:
        result.append(' '.join(buffer))

    return result or [text]
|
||||||
|
|
||||||
|
|
||||||
def generate_speech_sync(text: str, voice: str) -> bytes:
|
def generate_speech_sync(text: str, voice: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Generate speech using Orpheus model (synchronous).
|
Generate speech using Orpheus model (synchronous).
|
||||||
@@ -188,34 +221,36 @@ def generate_speech_sync(text: str, voice: str) -> bytes:
|
|||||||
print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
|
print(f"Unknown voice '{voice}', using default '{DEFAULT_VOICE}'")
|
||||||
voice = DEFAULT_VOICE
|
voice = DEFAULT_VOICE
|
||||||
|
|
||||||
print(f"Generating: {text}")
|
# Split long text into chunks for sequential generation
|
||||||
|
text_chunks = chunk_text(text)
|
||||||
|
print(f"Generating: {text} ({len(text_chunks)} chunk(s))")
|
||||||
|
|
||||||
audio_chunks = []
|
all_pcm = []
|
||||||
|
|
||||||
# Call model directly - it returns a generator
|
for chunk_idx, chunk in enumerate(text_chunks):
|
||||||
syn_tokens = model.generate_speech(
|
print(f"Chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
|
||||||
prompt=text,
|
audio_chunks = []
|
||||||
voice=voice,
|
|
||||||
max_tokens=4000, # Increased from default 1200 for longer texts
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Got generator: {type(syn_tokens)}")
|
syn_tokens = model.generate_speech(
|
||||||
|
prompt=chunk,
|
||||||
|
voice=voice,
|
||||||
|
max_tokens=MAX_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
# Iterate over generator
|
for i, audio_chunk in enumerate(syn_tokens):
|
||||||
for i, audio_chunk in enumerate(syn_tokens):
|
audio_chunks.append(audio_chunk)
|
||||||
print(f"Chunk {i}: {type(audio_chunk)}, shape: {audio_chunk.shape if hasattr(audio_chunk, 'shape') else 'N/A'}")
|
|
||||||
audio_chunks.append(audio_chunk)
|
|
||||||
|
|
||||||
print(f"Total chunks: {len(audio_chunks)}")
|
print(f" -> {len(audio_chunks)} audio frames")
|
||||||
|
|
||||||
# Chunks are raw int16 bytes from SNAC decoder - just concatenate
|
if audio_chunks:
|
||||||
if len(audio_chunks) == 0:
|
all_pcm.append(b''.join(audio_chunks))
|
||||||
|
|
||||||
|
if not all_pcm:
|
||||||
raise ValueError("No audio chunks generated")
|
raise ValueError("No audio chunks generated")
|
||||||
|
|
||||||
# Concatenate bytes directly
|
# Concatenate raw PCM from all chunks, wrap in single WAV
|
||||||
audio_bytes_raw = b''.join(audio_chunks)
|
audio_bytes_raw = b''.join(all_pcm)
|
||||||
|
|
||||||
# Convert to WAV bytes
|
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
with wave.open(buffer, 'wb') as wf:
|
with wave.open(buffer, 'wb') as wf:
|
||||||
wf.setnchannels(1)
|
wf.setnchannels(1)
|
||||||
@@ -328,6 +363,7 @@ async def startup():
|
|||||||
print("OrpheusTail - Orpheus TTS Service Starting")
|
print("OrpheusTail - Orpheus TTS Service Starting")
|
||||||
print(f"Model: {ORPHEUS_MODEL}")
|
print(f"Model: {ORPHEUS_MODEL}")
|
||||||
print(f"Max Model Len: {MAX_MODEL_LEN}")
|
print(f"Max Model Len: {MAX_MODEL_LEN}")
|
||||||
|
print(f"Max Tokens: {MAX_TOKENS}")
|
||||||
print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
|
print(f"Cache: {'Enabled' if CACHE_ENABLED else 'Disabled'}")
|
||||||
print(f"Default Voice: {DEFAULT_VOICE}")
|
print(f"Default Voice: {DEFAULT_VOICE}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
@@ -354,7 +390,7 @@ async def startup():
|
|||||||
OrpheusModel._setup_engine = patched_setup_engine
|
OrpheusModel._setup_engine = patched_setup_engine
|
||||||
|
|
||||||
# Also patch generate_tokens_sync to work with sync LLM
|
# Also patch generate_tokens_sync to work with sync LLM
|
||||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=4000, stop_token_ids=[49158], repetition_penalty=1.3):
|
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
import re
|
import re
|
||||||
prompt_string = self._format_prompt(prompt, voice)
|
prompt_string = self._format_prompt(prompt, voice)
|
||||||
@@ -565,17 +601,18 @@ async def stream_tts(request: TTSStreamRequest):
|
|||||||
voice = DEFAULT_VOICE
|
voice = DEFAULT_VOICE
|
||||||
|
|
||||||
def sync_audio_generator():
    """Generate audio chunks (sync generator), chunking long text.

    Splits request.text at sentence boundaries via chunk_text(), then feeds
    each text chunk to the model sequentially, yielding raw audio frames as
    the model produces them. Closes over `request`, `voice`, and `model`
    from the enclosing endpoint.
    """
    try:
        text_chunks = chunk_text(request.text)
        for chunk_idx, chunk in enumerate(text_chunks):
            print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
            syn_tokens = model.generate_speech(
                prompt=chunk,
                voice=voice,
                max_tokens=MAX_TOKENS,
            )
            # Forward frames as soon as the model yields them so playback
            # can start before the full text is synthesized.
            for audio_chunk in syn_tokens:
                yield audio_chunk
    except Exception as e:
        # Log and re-raise so the streaming response is aborted, not truncated silently.
        print(f"Stream error: {e}")
        raise
|
||||||
|
|||||||
Reference in New Issue
Block a user