diff --git a/main.py b/main.py
index 15a1ca6..39d2d0a 100644
--- a/main.py
+++ b/main.py
@@ -197,6 +197,38 @@ def chunk_text(text: str, max_chars: int = 800) -> List[str]:
     return chunks if chunks else [text]
 
 
+def chunk_text_fine(text: str, max_chars: int = 200) -> List[str]:
+    """
+    Split text into fine-grained chunks for streaming — every sentence or clause.
+    Smaller chunks = faster first-audio, slight quality tradeoff at boundaries.
+    """
+    # Split on sentence boundaries AND commas/semicolons with reasonable length
+    parts = re.split(r'(?<=[.!?;])\s+', text.strip())
+
+    # Further split long parts on commas
+    chunks = []
+    for part in parts:
+        if len(part) <= max_chars:
+            chunks.append(part)
+        else:
+            # Split on commas
+            sub = re.split(r',\s+', part)
+            current = []
+            current_len = 0
+            for s in sub:
+                if current and current_len + len(s) > max_chars:
+                    chunks.append(', '.join(current))
+                    current = []
+                    current_len = 0
+                current.append(s)
+                current_len += len(s)
+            if current:
+                chunks.append(', '.join(current))
+
+    # Filter empty chunks
+    return [c.strip() for c in chunks if c.strip()]
+
+
 def generate_speech_sync(text: str, voice: str) -> bytes:
     """
     Generate speech using Orpheus model (synchronous).
@@ -374,21 +406,20 @@ async def startup():
     from vllm import AsyncLLMEngine
     from vllm.engine.arg_utils import AsyncEngineArgs
 
-    # Monkey-patch OrpheusModel to use AsyncLLMEngine for true streaming
+    # Monkey-patch OrpheusModel to use sync LLM (AsyncLLMEngine hangs on Jetson)
     original_setup_engine = OrpheusModel._setup_engine
     def patched_setup_engine(self):
         model_name = self._map_model_params(self.model_name)
-        engine_args = AsyncEngineArgs(
+        from vllm import LLM
+        return LLM(
             model=model_name,
             max_model_len=MAX_MODEL_LEN,
             gpu_memory_utilization=0.85,
             enforce_eager=False,
         )
-        return AsyncLLMEngine.from_engine_args(engine_args)
     OrpheusModel._setup_engine = patched_setup_engine
 
-    # Sync token generation (for background jobs)
-    # Uses the async engine but collects all results synchronously
+    # Sync token generation
     def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
                                      temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
                                      stop_token_ids=[49158], repetition_penalty=1.3):
@@ -400,58 +431,15 @@ async def startup():
             temperature=temperature, top_p=top_p, max_tokens=max_tokens,
             stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
         )
-        req_id = f"sync-{uuid.uuid4().hex[:8]}"
-
-        # Collect from async engine — use the running loop if available, else create one
-        async def _collect_all():
-            final = None
-            async for output in self.engine.generate(prompt_string, sampling_params, req_id):
-                final = output
-            return final
-
-        try:
-            loop = asyncio.get_running_loop()
-            # We're in an async context (background task) — use asyncio.ensure_future
-            import concurrent.futures
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                final_output = pool.submit(lambda: asyncio.run(_collect_all())).result()
-        except RuntimeError:
-            # No running loop — safe to use asyncio.run
-            final_output = asyncio.run(_collect_all())
-
-        if final_output:
-            text = final_output.outputs[0].text
+        outputs = self.engine.generate([prompt_string], sampling_params)
+        for output in outputs:
+            text = output.outputs[0].text
             print(f"Raw output (first 500 chars): {text[:500]}")
             tokens = re.findall(r'<custom_token_\d+>', text)
             print(f"Found {len(tokens)} tokens")
             for token in tokens:
                 yield token
 
     OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
-
-    # Async streaming token generation (for /tts/stream — yields tokens as produced)
-    async def async_generate_tokens(model_instance, prompt, voice=None,
-                                    temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
-                                    stop_token_ids=[49158], repetition_penalty=1.3):
-        """Async generator: yields tokens incrementally as vLLM produces them."""
-        from vllm import SamplingParams
-        prompt_string = model_instance._format_prompt(prompt, voice)
-        print(f"[streaming] {prompt}")
-        sampling_params = SamplingParams(
-            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
-            stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
-        )
-        request_id = f"stream-{uuid.uuid4().hex[:8]}"
-        prev_text_len = 0
-
-        async for output in model_instance.engine.generate(prompt_string, sampling_params, request_id):
-            # output.outputs[0].text grows incrementally
-            text = output.outputs[0].text
-            new_text = text[prev_text_len:]
-            prev_text_len = len(text)
-            # Extract new tokens from the incremental text
-            new_tokens = re.findall(r'<custom_token_\d+>', new_text)
-            for token in new_tokens:
-                yield token
 
     model = OrpheusModel(model_name=ORPHEUS_MODEL)
@@ -624,11 +612,11 @@ async def get_audio(job_id: str):
 @app.post("/tts/stream")
 async def stream_tts(request: TTSStreamRequest):
     """
-    Stream TTS audio in real-time with true token-level streaming.
+    Stream TTS audio with sentence-level chunking.
 
-    Audio starts playing within ~1-2s instead of waiting for full generation.
-    Uses vLLM AsyncLLMEngine to yield tokens incrementally, decoded to audio
-    via orpheus_tts.decoder.tokens_decoder every 7 tokens.
+    Splits text into small chunks (sentences/clauses) and generates each
+    independently. First chunk's audio starts playing while later chunks
+    are still generating. Reduces perceived latency significantly.
     """
     global model
 
@@ -639,25 +627,27 @@ async def stream_tts(request: TTSStreamRequest):
     if voice not in BUILTIN_VOICES:
         voice = DEFAULT_VOICE
 
-    async def streaming_audio_generator():
-        """Async generator: vLLM tokens → SNAC decoder → PCM audio chunks."""
-        from orpheus_tts.decoder import tokens_decoder
+    def sync_audio_generator():
+        """Generate audio per-sentence, yielding as each finishes."""
        try:
-            text_chunks = chunk_text(request.text)
+            # Split into fine-grained chunks for faster first-audio
+            text_chunks = chunk_text_fine(request.text)
+            print(f"[stream] {len(text_chunks)} chunk(s): {[c[:40] for c in text_chunks]}")
             for chunk_idx, chunk in enumerate(text_chunks):
-                print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-
-                # async_generate_tokens yields tokens as vLLM produces them
-                # tokens_decoder converts them to audio every 7 tokens
-                token_gen = async_generate_tokens(model, chunk, voice=voice)
-                async for audio_chunk in tokens_decoder(token_gen):
+                print(f"  Generating chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:60]}...")
+                syn_tokens = model.generate_speech(
+                    prompt=chunk,
+                    voice=voice,
+                    max_tokens=MAX_TOKENS,
+                )
+                for audio_chunk in syn_tokens:
                     yield audio_chunk
         except Exception as e:
             print(f"Stream error: {e}")
             raise
 
     return StreamingResponse(
-        streaming_audio_generator(),
+        sync_audio_generator(),
         media_type="audio/pcm"
     )
 
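
As a quick sanity check of the chunk_text_fine helper added above, the sketch below exercises it on a made-up sentence. The sample text, the max_chars value of 60, and the expectation comments are illustrative assumptions, not part of the change.

    # Assumes main.py is importable from the working directory.
    from main import chunk_text_fine

    sample = ("The quick brown fox jumps over the lazy dog, then pauses; "
              "afterwards it trots away. A short second sentence follows!")
    for i, chunk in enumerate(chunk_text_fine(sample, max_chars=60)):
        print(i, repr(chunk))
    # Expected shape: a few sentence/clause-sized chunks with no empty entries;
    # the first sentence stays whole because it already fits within max_chars,
    # while anything longer would be re-split at comma boundaries.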
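
A minimal client-side sketch for consuming the reworked /tts/stream endpoint follows. Everything deployment-specific here is an assumption rather than something stated in the diff: the server URL, the JSON field names ("text" and "voice", inferred from TTSStreamRequest and the voice check above), the voice name, and the PCM format (16-bit mono at 24 kHz, the usual Orpheus output); verify these against the actual server before relying on them.

    import requests
    import wave

    # Hypothetical local deployment; adjust the URL and payload for your setup.
    resp = requests.post(
        "http://localhost:8000/tts/stream",
        json={"text": "Hello there. This is a streaming test.", "voice": "tara"},
        stream=True,
    )
    resp.raise_for_status()

    # Wrap the raw PCM chunks in a WAV container as they arrive.
    with wave.open("out.wav", "wb") as wav:
        wav.setnchannels(1)      # assumed mono
        wav.setsampwidth(2)      # assumed 16-bit samples
        wav.setframerate(24000)  # assumed 24 kHz sample rate
        for pcm_chunk in resp.iter_content(chunk_size=4096):
            if pcm_chunk:
                wav.writeframes(pcm_chunk)

Because the endpoint now synthesizes chunk by chunk, the first writeframes call should land well before the full text has finished generating, which is the latency win this change is after.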