Use sync vllm.LLM instead of AsyncLLMEngine to avoid event loop conflicts
This commit is contained in:
29
main.py
29
main.py
@@ -339,20 +339,43 @@ async def startup():
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
# Monkey-patch OrpheusModel to support max_model_len (PyPI version doesn't)
|
||||
# Monkey-patch OrpheusModel to use sync LLM for proper sync context
|
||||
original_setup_engine = OrpheusModel._setup_engine
|
||||
def patched_setup_engine(self):
|
||||
# Get the mapped model name (handles "medium-3b" -> full path)
|
||||
model_name = self._map_model_params(self.model_name)
|
||||
engine_args = AsyncEngineArgs(
|
||||
# Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
|
||||
from vllm import LLM
|
||||
return LLM(
|
||||
model=model_name,
|
||||
max_model_len=MAX_MODEL_LEN, # Our custom limit!
|
||||
gpu_memory_utilization=0.85, # Leave some headroom
|
||||
enforce_eager=False,
|
||||
)
|
||||
return AsyncLLMEngine.from_engine_args(engine_args)
|
||||
OrpheusModel._setup_engine = patched_setup_engine
|
||||
|
||||
# Also patch generate_tokens_sync to work with sync LLM
|
||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
|
||||
from vllm import SamplingParams
|
||||
prompt_string = self._format_prompt(prompt, voice)
|
||||
print(prompt)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
# Use sync generate - yields full output
|
||||
outputs = self.engine.generate([prompt_string], sampling_params)
|
||||
# Yield individual tokens from the output text
|
||||
for output in outputs:
|
||||
text = output.outputs[0].text
|
||||
# Tokens are space-separated custom_token_XXX
|
||||
for token in text.split():
|
||||
yield token
|
||||
OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
|
||||
|
||||
model = OrpheusModel(model_name=ORPHEUS_MODEL)
|
||||
|
||||
print("✓ Orpheus model loaded successfully")
|
||||
|
||||
Reference in New Issue
Block a user