Monkey-patch OrpheusModel to support max_model_len on Jetson
This commit is contained in:
main.py — 21 lines changed (16 additions, 5 deletions)
@@ -325,12 +325,23 @@ async def startup():
     # Import and load Orpheus model
     print("Loading Orpheus model (this may take a moment)...")
     from orpheus_tts import OrpheusModel
+    from vllm import AsyncLLMEngine
+    from vllm.engine.arg_utils import AsyncEngineArgs

-    # GitHub version supports max_model_len for memory control
-    model = OrpheusModel(
-        model_name=ORPHEUS_MODEL,
-        max_model_len=MAX_MODEL_LEN
-    )
+    # Monkey-patch OrpheusModel to support max_model_len (PyPI version doesn't)
+    original_setup_engine = OrpheusModel._setup_engine
+    def patched_setup_engine(self):
+        model_name = self._map_model_params()
+        engine_args = AsyncEngineArgs(
+            model=model_name,
+            max_model_len=MAX_MODEL_LEN,  # Our custom limit!
+            gpu_memory_utilization=0.85,  # Leave some headroom
+            enforce_eager=False,
+        )
+        return AsyncLLMEngine.from_engine_args(engine_args)
+    OrpheusModel._setup_engine = patched_setup_engine
+
+    model = OrpheusModel(model_name=ORPHEUS_MODEL)

    print("✓ Orpheus model loaded successfully")

Reference in New Issue
Block a user