From 25ed6625aa0131a60a78184e577b5801a608efe7 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sun, 12 Apr 2026 23:36:24 -0500
Subject: [PATCH] True streaming TTS: AsyncLLMEngine + incremental token decoding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replaced the sync vLLM LLM with AsyncLLMEngine for real streaming.
Tokens now flow incrementally: vLLM → async_generate_tokens →
orpheus_tts tokens_decoder → audio chunks → StreamingResponse.

First audio chunk arrives after ~28 tokens (SNAC codec warmup)
instead of waiting for all ~2000+ tokens to complete.
Expected: first-byte latency drops from ~15s to ~1-2s.

Background jobs (submit/async) still work via a sync wrapper that
collects all tokens from the async engine.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 main.py | 117 +++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 78 insertions(+), 39 deletions(-)

diff --git a/main.py b/main.py
index 367b2e1..15a1ca6 100644
--- a/main.py
+++ b/main.py
@@ -374,46 +374,84 @@ async def startup():
     from vllm import AsyncLLMEngine
     from vllm.engine.arg_utils import AsyncEngineArgs
 
-    # Monkey-patch OrpheusModel to use sync LLM for proper sync context
+    # Monkey-patch OrpheusModel to use AsyncLLMEngine for true streaming
     original_setup_engine = OrpheusModel._setup_engine
 
     def patched_setup_engine(self):
-        # Get the mapped model name (handles "medium-3b" -> full path)
         model_name = self._map_model_params(self.model_name)
-        # Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
-        from vllm import LLM
-        return LLM(
+        engine_args = AsyncEngineArgs(
             model=model_name,
-            max_model_len=MAX_MODEL_LEN,  # Our custom limit!
-            gpu_memory_utilization=0.85,  # Leave some headroom
+            max_model_len=MAX_MODEL_LEN,
+            gpu_memory_utilization=0.85,
             enforce_eager=False,
         )
+        return AsyncLLMEngine.from_engine_args(engine_args)
 
     OrpheusModel._setup_engine = patched_setup_engine
-
-    # Also patch generate_tokens_sync to work with sync LLM
-    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
+
+    # Sync token generation (for background jobs)
+    # Uses the async engine but collects all results synchronously
+    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
+                                     temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                     stop_token_ids=[49158], repetition_penalty=1.3):
         from vllm import SamplingParams
         import re
         prompt_string = self._format_prompt(prompt, voice)
         print(prompt)
         sampling_params = SamplingParams(
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-            stop_token_ids=stop_token_ids,
-            repetition_penalty=repetition_penalty,
+            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+            stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
         )
-        # Use sync generate - yields full output
-        outputs = self.engine.generate([prompt_string], sampling_params)
-        # Yield individual tokens from the output text
-        for output in outputs:
-            text = output.outputs[0].text
+        req_id = f"sync-{uuid.uuid4().hex[:8]}"
+
+        # Drain the async engine to completion; how we run the coroutine
+        # depends on whether an event loop is already running
+        async def _collect_all():
+            final = None
+            async for output in self.engine.generate(prompt_string, sampling_params, req_id):
+                final = output
+            return final
+
+        try:
+            asyncio.get_running_loop()
+            # A loop is already running (background task context), so run the
+            # collector on its own event loop in a worker thread
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                final_output = pool.submit(lambda: asyncio.run(_collect_all())).result()
+        except RuntimeError:
+            # No running loop — safe to use asyncio.run
+            final_output = asyncio.run(_collect_all())
+
+        if final_output:
+            text = final_output.outputs[0].text
             print(f"Raw output (first 500 chars): {text[:500]}")
-            # Extract all <custom_token_N> patterns
             tokens = re.findall(r'<custom_token_\d+>', text)
             print(f"Found {len(tokens)} tokens")
             for token in tokens:
                 yield token
 
     OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
+
+    # Async streaming token generation (for /tts/stream — yields tokens as produced)
+    async def async_generate_tokens(model_instance, prompt, voice=None,
+                                    temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                    stop_token_ids=[49158], repetition_penalty=1.3):
+        """Async generator: yields tokens incrementally as vLLM produces them."""
+        from vllm import SamplingParams
+        import re
+        prompt_string = model_instance._format_prompt(prompt, voice)
+        print(f"[streaming] {prompt}")
+        sampling_params = SamplingParams(
+            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+            stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
+        )
+        request_id = f"stream-{uuid.uuid4().hex[:8]}"
+        prev_text_len = 0
+
+        async for output in model_instance.engine.generate(prompt_string, sampling_params, request_id):
+            # output.outputs[0].text grows incrementally
+            text = output.outputs[0].text
+            new_text = text[prev_text_len:]
+            prev_text_len = len(text)
+            # Extract only the newly produced <custom_token_N> tokens
+            new_tokens = re.findall(r'<custom_token_\d+>', new_text)
+            for token in new_tokens:
+                yield token
 
     model = OrpheusModel(model_name=ORPHEUS_MODEL)
 
@@ -586,40 +624,41 @@ async def get_audio(job_id: str):
 
 @app.post("/tts/stream")
 async def stream_tts(request: TTSStreamRequest):
     """
-    Stream TTS audio in real-time.
-
-    For head-vixy to stream directly without waiting for full generation.
-    Returns audio chunks as they're generated.
+    Stream TTS audio in real-time with true token-level streaming.
+
+    Audio starts playing within ~1-2s instead of waiting for full generation.
+    Uses vLLM AsyncLLMEngine to yield tokens incrementally, decoded to audio
+    via orpheus_tts.decoder.tokens_decoder every 7 tokens.
     """
     global model
-
+
     if model is None:
         raise HTTPException(status_code=503, detail="Model not loaded")
-
+
     voice = request.voice
     if voice not in BUILTIN_VOICES:
         voice = DEFAULT_VOICE
-
-    def sync_audio_generator():
-        """Generate audio chunks (sync generator), chunking long text."""
+
+    async def streaming_audio_generator():
+        """Async generator: vLLM tokens → SNAC decoder → PCM audio chunks."""
+        from orpheus_tts.decoder import tokens_decoder
         try:
             text_chunks = chunk_text(request.text)
             for chunk_idx, chunk in enumerate(text_chunks):
                 print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-                syn_tokens = model.generate_speech(
-                    prompt=chunk,
-                    voice=voice,
-                    max_tokens=MAX_TOKENS,
-                )
-                for audio_chunk in syn_tokens:
+
+                # async_generate_tokens yields tokens as vLLM produces them;
+                # tokens_decoder converts them to audio every 7 tokens
+                token_gen = async_generate_tokens(model, chunk, voice=voice)
+                async for audio_chunk in tokens_decoder(token_gen):
                     yield audio_chunk
         except Exception as e:
             print(f"Stream error: {e}")
             raise
-
+
     return StreamingResponse(
-        sync_audio_generator(),
-        media_type="audio/wav"
+        streaming_audio_generator(),
+        media_type="audio/pcm"
     )
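
Reviewer note (not part of the patch): /tts/stream now returns raw PCM (media_type="audio/pcm") rather than WAV, so callers must know the sample format themselves. Below is a minimal client sketch showing one way to consume the stream and wrap it in a WAV file. The base URL, the voice name, and the 24 kHz / 16-bit / mono parameters are assumptions about this deployment and about the SNAC output of orpheus_tts, not values confirmed by the patch; verify them against the decoder before relying on this.

# Hypothetical client for the new /tts/stream endpoint (not part of main.py).
# Assumes: server at BASE_URL, a TTSStreamRequest body of {"text", "voice"},
# and 24 kHz / 16-bit / mono PCM from tokens_decoder, all unverified here.
import wave

import requests

BASE_URL = "http://localhost:8000"  # assumption: adjust to the real deployment

def stream_to_wav(text: str, voice: str, out_path: str = "out.wav") -> None:
    resp = requests.post(
        f"{BASE_URL}/tts/stream",
        json={"text": text, "voice": voice},
        stream=True,  # keep the connection open and read chunks as they arrive
    )
    resp.raise_for_status()
    with wave.open(out_path, "wb") as wav:
        wav.setnchannels(1)      # mono (assumed)
        wav.setsampwidth(2)      # 16-bit samples (assumed)
        wav.setframerate(24000)  # SNAC sample rate (assumed)
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                wav.writeframes(chunk)

if __name__ == "__main__":
    stream_to_wav("Hello from the streaming endpoint.", "tara")  # "tara" is a placeholder voice name

A real player would feed the chunks straight to an audio device instead of a file; the WAV wrapper here just makes the first-byte latency easy to inspect with standard tools.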