True streaming TTS: AsyncLLMEngine + incremental token decoding

Replaced sync vLLM LLM with AsyncLLMEngine for real streaming.
Tokens now flow incrementally: vLLM → async_generate_tokens →
orpheus_tts tokens_decoder → audio chunks → StreamingResponse.
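
Condensed, the per-chunk hot path now composes to roughly this (a sketch using the
names introduced in the diff below):

    # inside the /tts/stream response generator, for each text chunk:
    token_gen = async_generate_tokens(model, chunk, voice=voice)  # vLLM tokens, as produced
    async for audio_chunk in tokens_decoder(token_gen):           # SNAC decode every 7 tokens
        yield audio_chunk                                         # pushed out via StreamingResponse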

First audio chunk arrives after ~28 tokens (SNAC codec warmup)
instead of waiting for all ~2000+ tokens to complete.
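
The ~28-token figure is just the decoder's frame size times its startup buffer:
tokens_decoder emits audio every 7 tokens, and it appears to buffer about 4 such
frames before the first chunk (the buffer depth is an assumption about orpheus_tts
internals, not something this commit sets):

    TOKENS_PER_FRAME = 7   # tokens_decoder emits audio every 7 tokens (see docstring below)
    WARMUP_FRAMES = 4      # assumed: frames buffered before the first audio chunk
    print(TOKENS_PER_FRAME * WARMUP_FRAMES)  # -> 28, the "~28 tokens" above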

Expected: first-byte latency drops from ~15s to ~1-2s.
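
A quick client-side check of the first-byte number (a sketch; the host/port and the
exact TTSStreamRequest fields, text and voice, are assumptions not fixed by this commit):

    import time, requests

    t0 = time.time()
    with requests.post(
        "http://localhost:8000/tts/stream",   # assumed host/port
        json={"text": "Hello from the streaming path.", "voice": "tara"},  # "tara": any BUILTIN_VOICES entry
        stream=True,
    ) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=4096):
            # time until the first PCM bytes arrive, i.e. first-byte latency
            print(f"first audio bytes after {time.time() - t0:.2f}s")
            break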

Background jobs (submit/async) still work via a sync wrapper that
collects all tokens from the async engine.
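
The job path keeps the same call shape. A minimal sketch of a non-streaming consumer,
assuming generate_speech still wraps generate_tokens_sync plus the SNAC decoder as in
upstream orpheus_tts (the function name here is illustrative, not from this commit):

    def render_job_audio(text: str, voice: str) -> bytes:
        # generate_speech drives the patched generate_tokens_sync under the hood
        # and yields PCM byte chunks; a background job just collects them all.
        return b"".join(
            model.generate_speech(prompt=text, voice=voice, max_tokens=MAX_TOKENS)
        )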

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: Alex
Date:   2026-04-12 23:36:24 -05:00
parent  14af1d0600
commit  25ed6625aa

main.py

@@ -374,47 +374,85 @@ async def startup():
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs

-# Monkey-patch OrpheusModel to use sync LLM for proper sync context
+# Monkey-patch OrpheusModel to use AsyncLLMEngine for true streaming
 original_setup_engine = OrpheusModel._setup_engine

 def patched_setup_engine(self):
-    # Get the mapped model name (handles "medium-3b" -> full path)
     model_name = self._map_model_params(self.model_name)
-    # Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
-    from vllm import LLM
-    return LLM(
+    engine_args = AsyncEngineArgs(
         model=model_name,
-        max_model_len=MAX_MODEL_LEN,  # Our custom limit!
-        gpu_memory_utilization=0.85,  # Leave some headroom
+        max_model_len=MAX_MODEL_LEN,
+        gpu_memory_utilization=0.85,
         enforce_eager=False,
     )
+    return AsyncLLMEngine.from_engine_args(engine_args)

 OrpheusModel._setup_engine = patched_setup_engine

-# Also patch generate_tokens_sync to work with sync LLM
-def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
+# Sync token generation (for background jobs)
+# Uses the async engine but collects all results synchronously
+def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
+                                 temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                 stop_token_ids=[49158], repetition_penalty=1.3):
     from vllm import SamplingParams
     import re

     prompt_string = self._format_prompt(prompt, voice)
     print(prompt)

     sampling_params = SamplingParams(
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=max_tokens,
-        stop_token_ids=stop_token_ids,
-        repetition_penalty=repetition_penalty,
+        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
     )

-    # Use sync generate - yields full output
-    outputs = self.engine.generate([prompt_string], sampling_params)
+    req_id = f"sync-{uuid.uuid4().hex[:8]}"

-    # Yield individual tokens from the output text
-    for output in outputs:
-        text = output.outputs[0].text
+    # Collect from async engine — use the running loop if available, else create one
+    async def _collect_all():
+        final = None
+        async for output in self.engine.generate(prompt_string, sampling_params, req_id):
+            final = output
+        return final
+
+    try:
+        loop = asyncio.get_running_loop()
+        # We're in an async context (background task): run the collection in a
+        # fresh event loop on a worker thread so we neither block nor re-enter
+        # the loop that is already running
+        import concurrent.futures
+        with concurrent.futures.ThreadPoolExecutor() as pool:
+            final_output = pool.submit(lambda: asyncio.run(_collect_all())).result()
+    except RuntimeError:
+        # No running loop: safe to use asyncio.run directly
+        final_output = asyncio.run(_collect_all())
+
+    if final_output:
+        text = final_output.outputs[0].text
         print(f"Raw output (first 500 chars): {text[:500]}")

-        # Extract all <custom_token_XXXX> patterns
         tokens = re.findall(r'<custom_token_\d+>', text)
         print(f"Found {len(tokens)} tokens")

         for token in tokens:
             yield token

 OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync

+# Async streaming token generation (for /tts/stream — yields tokens as produced)
+async def async_generate_tokens(model_instance, prompt, voice=None,
+                                temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                stop_token_ids=[49158], repetition_penalty=1.3):
+    """Async generator: yields tokens incrementally as vLLM produces them."""
+    from vllm import SamplingParams
+    import re
+
+    prompt_string = model_instance._format_prompt(prompt, voice)
+    print(f"[streaming] {prompt}")
+
+    sampling_params = SamplingParams(
+        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
+    )
+
+    request_id = f"stream-{uuid.uuid4().hex[:8]}"
+    prev_text_len = 0
+
+    async for output in model_instance.engine.generate(prompt_string, sampling_params, request_id):
+        # output.outputs[0].text grows incrementally
+        text = output.outputs[0].text
+        new_text = text[prev_text_len:]
+        prev_text_len = len(text)
+
+        # Extract new tokens from the incremental text
+        new_tokens = re.findall(r'<custom_token_\d+>', new_text)
+        for token in new_tokens:
+            yield token
+
 model = OrpheusModel(model_name=ORPHEUS_MODEL)
 print("✓ Orpheus model loaded successfully")
@@ -586,10 +624,11 @@ async def get_audio(job_id: str):
 @app.post("/tts/stream")
 async def stream_tts(request: TTSStreamRequest):
     """
-    Stream TTS audio in real-time.
-    For head-vixy to stream directly without waiting for full generation.
-    Returns audio chunks as they're generated.
+    Stream TTS audio in real-time with true token-level streaming.
+    Audio starts playing within ~1-2s instead of waiting for full generation.
+    Uses vLLM AsyncLLMEngine to yield tokens incrementally, decoded to audio
+    via orpheus_tts.decoder.tokens_decoder every 7 tokens.
     """

     global model
@@ -600,26 +639,26 @@ async def stream_tts(request: TTSStreamRequest):
     if voice not in BUILTIN_VOICES:
         voice = DEFAULT_VOICE

-    def sync_audio_generator():
-        """Generate audio chunks (sync generator), chunking long text."""
+    async def streaming_audio_generator():
+        """Async generator: vLLM tokens → SNAC decoder → PCM audio chunks."""
+        from orpheus_tts.decoder import tokens_decoder
         try:
             text_chunks = chunk_text(request.text)
             for chunk_idx, chunk in enumerate(text_chunks):
                 print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-                syn_tokens = model.generate_speech(
-                    prompt=chunk,
-                    voice=voice,
-                    max_tokens=MAX_TOKENS,
-                )
-                for audio_chunk in syn_tokens:
+                # async_generate_tokens yields tokens as vLLM produces them
+                # tokens_decoder converts them to audio every 7 tokens
+                token_gen = async_generate_tokens(model, chunk, voice=voice)
+                async for audio_chunk in tokens_decoder(token_gen):
                     yield audio_chunk
         except Exception as e:
             print(f"Stream error: {e}")
             raise

     return StreamingResponse(
-        sync_audio_generator(),
-        media_type="audio/wav"
+        streaming_audio_generator(),
+        media_type="audio/pcm"
     )
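
One consequence of the switch to media_type="audio/pcm": clients now receive raw PCM
samples with no WAV header, so anything that previously saved the response as a .wav
needs to add the header itself. A sketch, assuming Orpheus/SNAC's usual 24 kHz, 16-bit,
mono output (sample rate and sample width are assumptions, not set by this commit):

    import wave

    def pcm_to_wav(pcm: bytes, path: str, rate: int = 24000) -> None:
        # Wrap raw 16-bit mono PCM in a minimal WAV container.
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)   # 16-bit samples
            wf.setframerate(rate)
            wf.writeframes(pcm)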