Monkey-patch OrpheusModel to support max_model_len on Jetson
This commit is contained in:
@@ -21,9 +21,9 @@ COPY requirements.txt /app/
 # Install Python dependencies (FastAPI, etc - but NOT torch/vllm)
 RUN pip3 install --no-cache-dir -r requirements.txt
 
-# Install orpheus-speech from GitHub repo (supports max_model_len) WITHOUT dependencies
+# Install orpheus-speech from regular PyPI WITHOUT dependencies
 # to avoid overwriting vllm/torch. Then install snac audio codec.
-RUN pip3 install --no-cache-dir --no-deps git+https://github.com/canopyai/Orpheus-TTS.git#subdirectory=orpheus_tts_pypi && \
+RUN pip3 install --no-cache-dir --no-deps --index-url https://pypi.org/simple/ orpheus-speech && \
     pip3 install --no-cache-dir --index-url https://pypi.org/simple/ snac
 
 # Copy application code
|||||||
main.py (19 lines changed)
@@ -325,12 +325,23 @@ async def startup():
     # Import and load Orpheus model
     print("Loading Orpheus model (this may take a moment)...")
     from orpheus_tts import OrpheusModel
+    from vllm import AsyncLLMEngine
+    from vllm.engine.arg_utils import AsyncEngineArgs
 
-    # GitHub version supports max_model_len for memory control
-    model = OrpheusModel(
-        model_name=ORPHEUS_MODEL,
-        max_model_len=MAX_MODEL_LEN
-    )
+    # Monkey-patch OrpheusModel to support max_model_len (PyPI version doesn't)
+    original_setup_engine = OrpheusModel._setup_engine
+    def patched_setup_engine(self):
+        model_name = self._map_model_params()
+        engine_args = AsyncEngineArgs(
+            model=model_name,
+            max_model_len=MAX_MODEL_LEN,  # Our custom limit!
+            gpu_memory_utilization=0.85,  # Leave some headroom
+            enforce_eager=False,
+        )
+        return AsyncLLMEngine.from_engine_args(engine_args)
+    OrpheusModel._setup_engine = patched_setup_engine
+
+    model = OrpheusModel(model_name=ORPHEUS_MODEL)
 
     print("✓ Orpheus model loaded successfully")
|||||||
Reference in New Issue
Block a user