From af35dc46d55dad3a4a108b1411b6cf79a9190ff1 Mon Sep 17 00:00:00 2001
From: vixy
Date: Sun, 11 Jan 2026 18:58:12 -0600
Subject: [PATCH] Use sync vllm.LLM instead of AsyncLLMEngine to avoid event loop conflicts

---
 main.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index abb3c28..8264009 100644
--- a/main.py
+++ b/main.py
@@ -339,20 +339,43 @@ async def startup():
     from vllm import AsyncLLMEngine
     from vllm.engine.arg_utils import AsyncEngineArgs

-    # Monkey-patch OrpheusModel to support max_model_len (PyPI version doesn't)
+    # Monkey-patch OrpheusModel to use sync LLM for proper sync context
     original_setup_engine = OrpheusModel._setup_engine

     def patched_setup_engine(self):
         # Get the mapped model name (handles "medium-3b" -> full path)
         model_name = self._map_model_params(self.model_name)
-        engine_args = AsyncEngineArgs(
+        # Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
+        from vllm import LLM
+        return LLM(
             model=model_name,
             max_model_len=MAX_MODEL_LEN,  # Our custom limit!
             gpu_memory_utilization=0.85,  # Leave some headroom
             enforce_eager=False,
         )
-        return AsyncLLMEngine.from_engine_args(engine_args)

     OrpheusModel._setup_engine = patched_setup_engine

+    # Also patch generate_tokens_sync to work with sync LLM
+    def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
+        from vllm import SamplingParams
+        prompt_string = self._format_prompt(prompt, voice)
+        print(prompt)
+        sampling_params = SamplingParams(
+            temperature=temperature,
+            top_p=top_p,
+            max_tokens=max_tokens,
+            stop_token_ids=stop_token_ids,
+            repetition_penalty=repetition_penalty,
+        )
+        # Use sync generate - yields full output
+        outputs = self.engine.generate([prompt_string], sampling_params)
+        # Yield individual tokens from the output text
+        for output in outputs:
+            text = output.outputs[0].text
+            # Tokens are space-separated custom_token_XXX
+            for token in text.split():
+                yield token
+
+    OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
+
     model = OrpheusModel(model_name=ORPHEUS_MODEL)
     print("✓ Orpheus model loaded successfully")
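
Note: a minimal sketch of how the patched synchronous path could be driven once startup() has run. It assumes `model` is the OrpheusModel instance created above; the helper name and the voice "tara" are illustrative, not part of the patch.

# Sketch: consume the patched generate_tokens_sync from a plain (non-async) code path.
def synthesize_tokens(model, text, voice="tara"):
    # After the patch, generate_tokens_sync is an ordinary generator backed by
    # vLLM's sync LLM, so no event loop is needed to iterate it.
    tokens = []
    for token in model.generate_tokens_sync(text, voice=voice):
        tokens.append(token)  # each item is a "custom_token_XXX" string
    return tokens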