Fix audio quality: use original orpheus_tts convert_to_audio decoder
Our custom SNAC redistribution had wrong layer mapping (positions 1,2 vs 1,4 for layer 2) and incorrect audio slicing. Switched to importing convert_to_audio directly from orpheus_tts.decoder which handles the sliding window, layer redistribution, and 2048:4096 audio slice correctly. Audio now sounds clean with only a subtle boundary artifact on the first token group (inherent to SNAC streaming, not our code). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
65
main.py
65
main.py
@@ -270,45 +270,35 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
|
||||
|
||||
def redistribute_codes(codes: list) -> list:
|
||||
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
||||
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]
|
||||
Codes are already offset-corrected to 0-4095 range per layer."""
|
||||
Each group of 7 maps as: [0]=L1, [1]=L2, [2]=L3, [3]=L3, [4]=L2, [5]=L3, [6]=L3
|
||||
(from orpheus_tts.decoder.convert_to_audio)"""
|
||||
layer1, layer2, layer3 = [], [], []
|
||||
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
||||
group = codes[i:i + SNAC_CHUNK_SIZE]
|
||||
if len(group) < SNAC_CHUNK_SIZE:
|
||||
break
|
||||
layer1.append(group[0])
|
||||
layer2.extend(group[1:3])
|
||||
layer3.extend(group[3:7])
|
||||
layer2.append(group[1])
|
||||
layer2.append(group[4])
|
||||
layer3.append(group[2])
|
||||
layer3.append(group[3])
|
||||
layer3.append(group[5])
|
||||
layer3.append(group[6])
|
||||
return [layer1, layer2, layer3]
|
||||
|
||||
|
||||
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def snac_decode(codes: list) -> Optional[bytes]:
|
||||
"""Decode SNAC codes to PCM audio bytes."""
|
||||
layers = redistribute_codes(codes)
|
||||
if not layers[0]:
|
||||
return None
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
codes_tensor = [
|
||||
torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
|
||||
for layer in layers
|
||||
]
|
||||
audio_hat = snac_model.decode(codes_tensor)
|
||||
|
||||
audio_np = audio_hat.squeeze().cpu().numpy()
|
||||
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
|
||||
return audio_int16.tobytes()
|
||||
except Exception as e:
|
||||
print(f"[snac] Decode error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
|
||||
"""Convert streaming SNAC codes to PCM audio chunks."""
|
||||
"""Convert streaming SNAC codes to PCM audio chunks.
|
||||
|
||||
Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio:
|
||||
accumulate all codes, decode the last 28 every 7 new codes,
|
||||
slice audio_hat[:,:,2048:4096] for the non-overlapping portion.
|
||||
"""
|
||||
from orpheus_tts.decoder import convert_to_audio as _original_convert
|
||||
|
||||
buffer = []
|
||||
total_codes = 0
|
||||
|
||||
@@ -319,21 +309,16 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by
|
||||
if total_codes == 1:
|
||||
print(f"[snac] First code received: {code}")
|
||||
|
||||
# Decode every SNAC_INITIAL_GROUPS groups (sliding window)
|
||||
if total_codes % SNAC_CHUNK_SIZE == 0:
|
||||
groups = total_codes // SNAC_CHUNK_SIZE
|
||||
if groups >= SNAC_INITIAL_GROUPS:
|
||||
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
|
||||
codes_to_decode = buffer[-decode_size:]
|
||||
if groups == SNAC_INITIAL_GROUPS:
|
||||
print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}")
|
||||
audio = snac_decode(codes_to_decode)
|
||||
if audio:
|
||||
yield audio
|
||||
elif groups == SNAC_INITIAL_GROUPS:
|
||||
print(f"[snac] WARNING: decode returned None")
|
||||
# The original decoder triggers every 7 codes after 28 minimum
|
||||
if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27:
|
||||
# Pass the last 28 codes, matching the original exactly
|
||||
audio_bytes = _original_convert(buffer[-28:], total_codes)
|
||||
if audio_bytes is not None:
|
||||
if total_codes == 28:
|
||||
print(f"[snac] First audio at {total_codes} codes")
|
||||
yield audio_bytes
|
||||
|
||||
print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}")
|
||||
print(f"[snac] Total codes: {total_codes}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user