True streaming TTS: AsyncLLMEngine + incremental token decoding
Replaced the sync vLLM LLM with AsyncLLMEngine for real streaming. Tokens now flow incrementally: vLLM → async_generate_tokens → orpheus_tts tokens_decoder → audio chunks → StreamingResponse.

The first audio chunk arrives after ~28 tokens (SNAC codec warmup) instead of waiting for all ~2000+ tokens to complete. Expected: first-byte latency drops from ~15s to ~1-2s.

Background jobs (submit/async) still work via a sync wrapper that collects all tokens from the async engine.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Changed file: main.py (117)
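The commit message claims a first-byte latency of roughly 1-2 s. A minimal client sketch for checking that against a running instance is shown below; it is not part of the commit. The base URL is a placeholder, httpx is an arbitrary HTTP client choice, and the voice name "tara" is only an assumed example; the JSON fields match the TTSStreamRequest model used in the diff (text, voice).

    import time
    import httpx

    URL = "http://localhost:8000/tts/stream"  # placeholder host/port, adjust to your deployment
    payload = {"text": "Streaming latency test.", "voice": "tara"}  # "tara" is an assumed voice name

    start = time.monotonic()
    first_chunk_at = None
    total_bytes = 0

    # Stream the response instead of buffering it, so the first PCM chunk
    # can be timed as soon as the server flushes it.
    with httpx.stream("POST", URL, json=payload, timeout=None) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_bytes():
            if first_chunk_at is None:
                first_chunk_at = time.monotonic() - start
            total_bytes += len(chunk)

    if first_chunk_at is not None:
        print(f"first chunk after {first_chunk_at:.2f}s, {total_bytes} bytes of raw PCM")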
@@ -374,46 +374,84 @@ async def startup():
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs

-# Monkey-patch OrpheusModel to use sync LLM for proper sync context
+# Monkey-patch OrpheusModel to use AsyncLLMEngine for true streaming
 original_setup_engine = OrpheusModel._setup_engine
 def patched_setup_engine(self):
-    # Get the mapped model name (handles "medium-3b" -> full path)
     model_name = self._map_model_params(self.model_name)
-    # Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
-    from vllm import LLM
-    return LLM(
+    engine_args = AsyncEngineArgs(
         model=model_name,
-        max_model_len=MAX_MODEL_LEN,  # Our custom limit!
-        gpu_memory_utilization=0.85,  # Leave some headroom
+        max_model_len=MAX_MODEL_LEN,
+        gpu_memory_utilization=0.85,
         enforce_eager=False,
     )
+    return AsyncLLMEngine.from_engine_args(engine_args)

 OrpheusModel._setup_engine = patched_setup_engine

-# Also patch generate_tokens_sync to work with sync LLM
-def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS, stop_token_ids=[49158], repetition_penalty=1.3):
+# Sync token generation (for background jobs)
+# Uses the async engine but collects all results synchronously
+def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
+                                 temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                 stop_token_ids=[49158], repetition_penalty=1.3):
     from vllm import SamplingParams
     import re
     prompt_string = self._format_prompt(prompt, voice)
     print(prompt)
     sampling_params = SamplingParams(
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=max_tokens,
-        stop_token_ids=stop_token_ids,
-        repetition_penalty=repetition_penalty,
+        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
     )
-    # Use sync generate - yields full output
-    outputs = self.engine.generate([prompt_string], sampling_params)
-    # Yield individual tokens from the output text
-    for output in outputs:
-        text = output.outputs[0].text
+    req_id = f"sync-{uuid.uuid4().hex[:8]}"
+
+    # Drain the async engine's generator to its final (cumulative) output
+    async def _collect_all():
+        final = None
+        async for output in self.engine.generate(prompt_string, sampling_params, req_id):
+            final = output
+        return final
+
+    try:
+        loop = asyncio.get_running_loop()
+        # A loop is already running (background task): run the collector in a
+        # worker thread with its own event loop
+        import concurrent.futures
+        with concurrent.futures.ThreadPoolExecutor() as pool:
+            final_output = pool.submit(lambda: asyncio.run(_collect_all())).result()
+    except RuntimeError:
+        # No running loop, so asyncio.run is safe here
+        final_output = asyncio.run(_collect_all())
+
+    if final_output:
+        text = final_output.outputs[0].text
         print(f"Raw output (first 500 chars): {text[:500]}")
-        # Extract all <custom_token_XXXX> patterns
         tokens = re.findall(r'<custom_token_\d+>', text)
         print(f"Found {len(tokens)} tokens")
         for token in tokens:
             yield token

 OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync

+# Async streaming token generation (for /tts/stream — yields tokens as produced)
+async def async_generate_tokens(model_instance, prompt, voice=None,
+                                temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
+                                stop_token_ids=[49158], repetition_penalty=1.3):
+    """Async generator: yields tokens incrementally as vLLM produces them."""
+    from vllm import SamplingParams
+    import re
+    prompt_string = model_instance._format_prompt(prompt, voice)
+    print(f"[streaming] {prompt}")
+    sampling_params = SamplingParams(
+        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
+    )
+    request_id = f"stream-{uuid.uuid4().hex[:8]}"
+    prev_text_len = 0
+
+    async for output in model_instance.engine.generate(prompt_string, sampling_params, request_id):
+        # output.outputs[0].text grows incrementally
+        text = output.outputs[0].text
+        new_text = text[prev_text_len:]
+        prev_text_len = len(text)
+        # Extract new tokens from the incremental text
+        new_tokens = re.findall(r'<custom_token_\d+>', new_text)
+        for token in new_tokens:
+            yield token

 model = OrpheusModel(model_name=ORPHEUS_MODEL)

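The new async_generate_tokens generator above is what the /tts/stream handler (second hunk) feeds into orpheus_tts.decoder.tokens_decoder. As a rough sketch outside this diff, the two can also be wired together directly to measure how long the ~28-token SNAC warmup takes in-process. This assumes it runs after startup() has built the global model, and that tokens_decoder takes an async token generator and yields audio byte chunks, exactly as the handler below uses it; the voice name "tara" is an assumed example.

    import time
    from orpheus_tts.decoder import tokens_decoder  # same decoder the /tts/stream handler uses

    async def time_to_first_audio(text: str, voice: str = "tara"):
        """Time how long the pipeline needs before the first audio chunk appears."""
        start = time.monotonic()
        token_gen = async_generate_tokens(model, text, voice=voice)
        async for audio_chunk in tokens_decoder(token_gen):
            # tokens_decoder holds back output until ~28 tokens have arrived (SNAC warmup),
            # so this prints at roughly the first-byte latency the commit message cites.
            print(f"first audio after {time.monotonic() - start:.2f}s ({len(audio_chunk)} bytes)")
            break

    # e.g. from an async context after the model is loaded:
    #     await time_to_first_audio("Hello there.")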
@@ -586,40 +624,41 @@ async def get_audio(job_id: str):
 @app.post("/tts/stream")
 async def stream_tts(request: TTSStreamRequest):
     """
-    Stream TTS audio in real-time.
+    Stream TTS audio in real-time with true token-level streaming.

-    For head-vixy to stream directly without waiting for full generation.
-    Returns audio chunks as they're generated.
+    Audio starts playing within ~1-2s instead of waiting for full generation.
+    Uses vLLM AsyncLLMEngine to yield tokens incrementally, decoded to audio
+    via orpheus_tts.decoder.tokens_decoder every 7 tokens.
     """
     global model

     if model is None:
         raise HTTPException(status_code=503, detail="Model not loaded")

     voice = request.voice
     if voice not in BUILTIN_VOICES:
         voice = DEFAULT_VOICE

-    def sync_audio_generator():
-        """Generate audio chunks (sync generator), chunking long text."""
+    async def streaming_audio_generator():
+        """Async generator: vLLM tokens → SNAC decoder → PCM audio chunks."""
+        from orpheus_tts.decoder import tokens_decoder
         try:
             text_chunks = chunk_text(request.text)
             for chunk_idx, chunk in enumerate(text_chunks):
                 print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
-                syn_tokens = model.generate_speech(
-                    prompt=chunk,
-                    voice=voice,
-                    max_tokens=MAX_TOKENS,
-                )
-                for audio_chunk in syn_tokens:
+                # async_generate_tokens yields tokens as vLLM produces them;
+                # tokens_decoder converts them to audio every 7 tokens
+                token_gen = async_generate_tokens(model, chunk, voice=voice)
+                async for audio_chunk in tokens_decoder(token_gen):
                     yield audio_chunk
         except Exception as e:
             print(f"Stream error: {e}")
             raise

     return StreamingResponse(
-        sync_audio_generator(),
-        media_type="audio/wav"
+        streaming_audio_generator(),
+        media_type="audio/pcm"
     )
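Note that the stream's media type changed from audio/wav to audio/pcm, so the response is now raw samples with no container header; a consumer that wants a playable file must wrap the bytes itself. A small sketch with Python's wave module, again with a placeholder URL and voice, and assuming 24 kHz, 16-bit, mono output (the usual Orpheus/SNAC format, not stated in this diff):

    import wave
    import httpx

    URL = "http://localhost:8000/tts/stream"  # placeholder host/port

    with httpx.stream("POST", URL, json={"text": "Save me to disk.", "voice": "tara"}, timeout=None) as resp:
        resp.raise_for_status()
        with wave.open("out.wav", "wb") as wav:
            wav.setnchannels(1)       # mono (assumed)
            wav.setsampwidth(2)       # 16-bit samples (assumed)
            wav.setframerate(24000)   # 24 kHz (assumed Orpheus/SNAC rate)
            for chunk in resp.iter_bytes():
                wav.writeframes(chunk)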