Use sync vllm.LLM instead of AsyncLLMEngine to avoid event loop conflicts
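For context, here is a minimal sketch (not part of this commit) of the conflict the title refers to: AsyncLLMEngine.generate() is an async generator, so a synchronous caller has to drive an event loop of its own, and that fails inside a server process whose loop is already running (e.g. under uvicorn). The engine, prompt, and sampling parameters below are placeholders, not values from this repository.

# Sketch only: why AsyncLLMEngine is awkward to drive from sync code.
# `engine`, `prompt`, and `sampling_params` are placeholders.
import asyncio

async def collect(engine, prompt, sampling_params, request_id):
    # AsyncLLMEngine.generate() yields results asynchronously.
    outputs = []
    async for output in engine.generate(prompt, sampling_params, request_id):
        outputs.append(output)
    return outputs

def sync_caller(engine, prompt, sampling_params):
    # In a process that already runs an event loop (uvicorn/FastAPI startup),
    # this raises "asyncio.run() cannot be called from a running event loop".
    return asyncio.run(collect(engine, prompt, sampling_params, "req-001"))

The synchronous vllm.LLM avoids this entirely, since LLM.generate() simply blocks until the request finishes.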
main.py
@@ -339,20 +339,43 @@ async def startup():
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 
-# Monkey-patch OrpheusModel to support max_model_len (PyPI version doesn't)
+# Monkey-patch OrpheusModel to use sync LLM for proper sync context
 original_setup_engine = OrpheusModel._setup_engine
 def patched_setup_engine(self):
     # Get the mapped model name (handles "medium-3b" -> full path)
     model_name = self._map_model_params(self.model_name)
-    engine_args = AsyncEngineArgs(
+    # Use LLM (sync) instead of AsyncLLMEngine to avoid event loop conflicts
+    from vllm import LLM
+    return LLM(
         model=model_name,
         max_model_len=MAX_MODEL_LEN,  # Our custom limit!
         gpu_memory_utilization=0.85,  # Leave some headroom
         enforce_eager=False,
     )
-    return AsyncLLMEngine.from_engine_args(engine_args)
 OrpheusModel._setup_engine = patched_setup_engine
 
+# Also patch generate_tokens_sync to work with sync LLM
+def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
+    from vllm import SamplingParams
+    prompt_string = self._format_prompt(prompt, voice)
+    print(prompt)
+    sampling_params = SamplingParams(
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids,
+        repetition_penalty=repetition_penalty,
+    )
+    # Use sync generate - yields full output
+    outputs = self.engine.generate([prompt_string], sampling_params)
+    # Yield individual tokens from the output text
+    for output in outputs:
+        text = output.outputs[0].text
+        # Tokens are space-separated custom_token_XXX
+        for token in text.split():
+            yield token
+OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
+
 model = OrpheusModel(model_name=ORPHEUS_MODEL)
 
 print("✓ Orpheus model loaded successfully")
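A hedged usage sketch of the patched sync path, assuming the monkey-patches above have been applied at startup and that OrpheusModel and ORPHEUS_MODEL are available as elsewhere in main.py; the voice name and prompt text are illustrative only.

# Usage sketch (assumes the patches above are installed; OrpheusModel and
# ORPHEUS_MODEL come from main.py; "tara" and the prompt are placeholders).
model = OrpheusModel(model_name=ORPHEUS_MODEL)

# The patched generate_tokens_sync() blocks on the sync LLM.generate() call,
# then yields the space-separated custom_token_XXX strings one at a time.
tokens = list(model.generate_tokens_sync("Hello there!", voice="tara"))
print(f"got {len(tokens)} tokens, e.g. {tokens[:3]}")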