Revert to sync LLM + sentence-level streaming
AsyncLLMEngine hangs on Jetson during model loading. Reverted to sync LLM but added fine-grained text chunking (chunk_text_fine, ~200 chars) for the stream endpoint. Each sentence/clause generates independently, so first audio plays after ~2-4s instead of waiting for the full text. Not true token-level streaming, but a significant latency reduction for multi-sentence utterances without AsyncLLMEngine dependency. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
122
main.py
122
main.py
@@ -197,6 +197,38 @@ def chunk_text(text: str, max_chars: int = 800) -> List[str]:
|
||||
return chunks if chunks else [text]
|
||||
|
||||
|
||||
def chunk_text_fine(text: str, max_chars: int = 200) -> List[str]:
|
||||
"""
|
||||
Split text into fine-grained chunks for streaming — every sentence or clause.
|
||||
Smaller chunks = faster first-audio, slight quality tradeoff at boundaries.
|
||||
"""
|
||||
# Split on sentence boundaries AND commas/semicolons with reasonable length
|
||||
parts = re.split(r'(?<=[.!?;])\s+', text.strip())
|
||||
|
||||
# Further split long parts on commas
|
||||
chunks = []
|
||||
for part in parts:
|
||||
if len(part) <= max_chars:
|
||||
chunks.append(part)
|
||||
else:
|
||||
# Split on commas
|
||||
sub = re.split(r',\s+', part)
|
||||
current = []
|
||||
current_len = 0
|
||||
for s in sub:
|
||||
if current and current_len + len(s) > max_chars:
|
||||
chunks.append(', '.join(current))
|
||||
current = []
|
||||
current_len = 0
|
||||
current.append(s)
|
||||
current_len += len(s)
|
||||
if current:
|
||||
chunks.append(', '.join(current))
|
||||
|
||||
# Filter empty chunks
|
||||
return [c.strip() for c in chunks if c.strip()]
|
||||
|
||||
|
||||
def generate_speech_sync(text: str, voice: str) -> bytes:
|
||||
"""
|
||||
Generate speech using Orpheus model (synchronous).
|
||||
@@ -374,21 +406,20 @@ async def startup():
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
# Monkey-patch OrpheusModel to use AsyncLLMEngine for true streaming
|
||||
# Monkey-patch OrpheusModel to use sync LLM (AsyncLLMEngine hangs on Jetson)
|
||||
original_setup_engine = OrpheusModel._setup_engine
|
||||
def patched_setup_engine(self):
|
||||
model_name = self._map_model_params(self.model_name)
|
||||
engine_args = AsyncEngineArgs(
|
||||
from vllm import LLM
|
||||
return LLM(
|
||||
model=model_name,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
gpu_memory_utilization=0.85,
|
||||
enforce_eager=False,
|
||||
)
|
||||
return AsyncLLMEngine.from_engine_args(engine_args)
|
||||
OrpheusModel._setup_engine = patched_setup_engine
|
||||
|
||||
# Sync token generation (for background jobs)
|
||||
# Uses the async engine but collects all results synchronously
|
||||
# Sync token generation
|
||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
|
||||
temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[49158], repetition_penalty=1.3):
|
||||
@@ -400,27 +431,9 @@ async def startup():
|
||||
temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
|
||||
)
|
||||
req_id = f"sync-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Collect from async engine — use the running loop if available, else create one
|
||||
async def _collect_all():
|
||||
final = None
|
||||
async for output in self.engine.generate(prompt_string, sampling_params, req_id):
|
||||
final = output
|
||||
return final
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# We're in an async context (background task) — use asyncio.ensure_future
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
final_output = pool.submit(lambda: asyncio.run(_collect_all())).result()
|
||||
except RuntimeError:
|
||||
# No running loop — safe to use asyncio.run
|
||||
final_output = asyncio.run(_collect_all())
|
||||
|
||||
if final_output:
|
||||
text = final_output.outputs[0].text
|
||||
outputs = self.engine.generate([prompt_string], sampling_params)
|
||||
for output in outputs:
|
||||
text = output.outputs[0].text
|
||||
print(f"Raw output (first 500 chars): {text[:500]}")
|
||||
tokens = re.findall(r'<custom_token_\d+>', text)
|
||||
print(f"Found {len(tokens)} tokens")
|
||||
@@ -428,31 +441,6 @@ async def startup():
|
||||
yield token
|
||||
OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
|
||||
|
||||
# Async streaming token generation (for /tts/stream — yields tokens as produced)
|
||||
async def async_generate_tokens(model_instance, prompt, voice=None,
|
||||
temperature=0.6, top_p=0.8, max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[49158], repetition_penalty=1.3):
|
||||
"""Async generator: yields tokens incrementally as vLLM produces them."""
|
||||
from vllm import SamplingParams
|
||||
prompt_string = model_instance._format_prompt(prompt, voice)
|
||||
print(f"[streaming] {prompt}")
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids, repetition_penalty=repetition_penalty,
|
||||
)
|
||||
request_id = f"stream-{uuid.uuid4().hex[:8]}"
|
||||
prev_text_len = 0
|
||||
|
||||
async for output in model_instance.engine.generate(prompt_string, sampling_params, request_id):
|
||||
# output.outputs[0].text grows incrementally
|
||||
text = output.outputs[0].text
|
||||
new_text = text[prev_text_len:]
|
||||
prev_text_len = len(text)
|
||||
# Extract new tokens from the incremental text
|
||||
new_tokens = re.findall(r'<custom_token_\d+>', new_text)
|
||||
for token in new_tokens:
|
||||
yield token
|
||||
|
||||
model = OrpheusModel(model_name=ORPHEUS_MODEL)
|
||||
|
||||
print("✓ Orpheus model loaded successfully")
|
||||
@@ -624,11 +612,11 @@ async def get_audio(job_id: str):
|
||||
@app.post("/tts/stream")
|
||||
async def stream_tts(request: TTSStreamRequest):
|
||||
"""
|
||||
Stream TTS audio in real-time with true token-level streaming.
|
||||
Stream TTS audio with sentence-level chunking.
|
||||
|
||||
Audio starts playing within ~1-2s instead of waiting for full generation.
|
||||
Uses vLLM AsyncLLMEngine to yield tokens incrementally, decoded to audio
|
||||
via orpheus_tts.decoder.tokens_decoder every 7 tokens.
|
||||
Splits text into small chunks (sentences/clauses) and generates each
|
||||
independently. First chunk's audio starts playing while later chunks
|
||||
are still generating. Reduces perceived latency significantly.
|
||||
"""
|
||||
global model
|
||||
|
||||
@@ -639,25 +627,27 @@ async def stream_tts(request: TTSStreamRequest):
|
||||
if voice not in BUILTIN_VOICES:
|
||||
voice = DEFAULT_VOICE
|
||||
|
||||
async def streaming_audio_generator():
|
||||
"""Async generator: vLLM tokens → SNAC decoder → PCM audio chunks."""
|
||||
from orpheus_tts.decoder import tokens_decoder
|
||||
def sync_audio_generator():
|
||||
"""Generate audio per-sentence, yielding as each finishes."""
|
||||
try:
|
||||
text_chunks = chunk_text(request.text)
|
||||
# Split into fine-grained chunks for faster first-audio
|
||||
text_chunks = chunk_text_fine(request.text)
|
||||
print(f"[stream] {len(text_chunks)} chunk(s): {[c[:40] for c in text_chunks]}")
|
||||
for chunk_idx, chunk in enumerate(text_chunks):
|
||||
print(f"Stream chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:80]}...")
|
||||
|
||||
# async_generate_tokens yields tokens as vLLM produces them
|
||||
# tokens_decoder converts them to audio every 7 tokens
|
||||
token_gen = async_generate_tokens(model, chunk, voice=voice)
|
||||
async for audio_chunk in tokens_decoder(token_gen):
|
||||
print(f" Generating chunk {chunk_idx + 1}/{len(text_chunks)}: {chunk[:60]}...")
|
||||
syn_tokens = model.generate_speech(
|
||||
prompt=chunk,
|
||||
voice=voice,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
for audio_chunk in syn_tokens:
|
||||
yield audio_chunk
|
||||
except Exception as e:
|
||||
print(f"Stream error: {e}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
streaming_audio_generator(),
|
||||
sync_audio_generator(),
|
||||
media_type="audio/pcm"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user