Fix token extraction - use regex to find custom_token patterns

This commit is contained in:
2026-01-11 19:33:31 -06:00
parent af35dc46d5
commit fe43eda6bd

View File

@@ -357,6 +357,7 @@ async def startup():
 # Also patch generate_tokens_sync to work with sync LLM
 def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
     from vllm import SamplingParams
+    import re
     prompt_string = self._format_prompt(prompt, voice)
     print(prompt)
     sampling_params = SamplingParams(
@@ -371,8 +372,11 @@ async def startup():
     # Yield individual tokens from the output text
     for output in outputs:
         text = output.outputs[0].text
-        # Tokens are space-separated custom_token_XXX
-        for token in text.split():
+        print(f"Raw output (first 500 chars): {text[:500]}")
+        # Extract all <custom_token_XXXX> patterns
+        tokens = re.findall(r'<custom_token_\d+>', text)
+        print(f"Found {len(tokens)} tokens")
+        for token in tokens:
             yield token
 OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync