Fix token extraction: use a regex to find `<custom_token_NNNN>` patterns in the raw model output instead of naively splitting on whitespace
This commit is contained in:
8
main.py
8
main.py
@@ -357,6 +357,7 @@ async def startup():
|
|||||||
# Also patch generate_tokens_sync to work with sync LLM
|
# Also patch generate_tokens_sync to work with sync LLM
|
||||||
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
|
def patched_generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids=[49158], repetition_penalty=1.3):
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
import re
|
||||||
prompt_string = self._format_prompt(prompt, voice)
|
prompt_string = self._format_prompt(prompt, voice)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@@ -371,8 +372,11 @@ async def startup():
|
|||||||
# Yield individual tokens from the output text
|
# Yield individual tokens from the output text
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
text = output.outputs[0].text
|
text = output.outputs[0].text
|
||||||
# Tokens are space-separated custom_token_XXX
|
print(f"Raw output (first 500 chars): {text[:500]}")
|
||||||
for token in text.split():
|
# Extract all <custom_token_XXXX> patterns
|
||||||
|
tokens = re.findall(r'<custom_token_\d+>', text)
|
||||||
|
print(f"Found {len(tokens)} tokens")
|
||||||
|
for token in tokens:
|
||||||
yield token
|
yield token
|
||||||
OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
|
OrpheusModel.generate_tokens_sync = patched_generate_tokens_sync
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user