Fix audio quality: use original orpheus_tts convert_to_audio decoder

Our custom SNAC redistribution had wrong layer mapping (positions
1,2 vs 1,4 for layer 2) and incorrect audio slicing. Switched to
importing convert_to_audio directly from orpheus_tts.decoder which
handles the sliding window, layer redistribution, and 2048:4096
audio slice correctly.

Audio now sounds clean with only a subtle boundary artifact on the
first token group (inherent to SNAC streaming, not our code).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-14 01:07:29 -05:00
parent 57a2e24101
commit 4989e0a7e8

65
main.py
View File

@@ -270,45 +270,35 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
def redistribute_codes(codes: list) -> list: def redistribute_codes(codes: list) -> list:
"""Redistribute flat code list into SNAC's 3 hierarchical layers. """Redistribute flat code list into SNAC's 3 hierarchical layers.
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d] Each group of 7 maps as: [0]=L1, [1]=L2, [2]=L3, [3]=L3, [4]=L2, [5]=L3, [6]=L3
Codes are already offset-corrected to 0-4095 range per layer.""" (from orpheus_tts.decoder.convert_to_audio)"""
layer1, layer2, layer3 = [], [], [] layer1, layer2, layer3 = [], [], []
for i in range(0, len(codes), SNAC_CHUNK_SIZE): for i in range(0, len(codes), SNAC_CHUNK_SIZE):
group = codes[i:i + SNAC_CHUNK_SIZE] group = codes[i:i + SNAC_CHUNK_SIZE]
if len(group) < SNAC_CHUNK_SIZE: if len(group) < SNAC_CHUNK_SIZE:
break break
layer1.append(group[0]) layer1.append(group[0])
layer2.extend(group[1:3]) layer2.append(group[1])
layer3.extend(group[3:7]) layer2.append(group[4])
layer3.append(group[2])
layer3.append(group[3])
layer3.append(group[5])
layer3.append(group[6])
return [layer1, layer2, layer3] return [layer1, layer2, layer3]
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def snac_decode(codes: list) -> Optional[bytes]:
"""Decode SNAC codes to PCM audio bytes."""
layers = redistribute_codes(codes)
if not layers[0]:
return None
try:
with torch.no_grad():
codes_tensor = [
torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
for layer in layers
]
audio_hat = snac_model.decode(codes_tensor)
audio_np = audio_hat.squeeze().cpu().numpy()
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
return audio_int16.tobytes()
except Exception as e:
print(f"[snac] Decode error: {e}")
return None
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]: def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
"""Convert streaming SNAC codes to PCM audio chunks.""" """Convert streaming SNAC codes to PCM audio chunks.
Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio:
accumulate all codes, decode the last 28 every 7 new codes,
slice audio_hat[:,:,2048:4096] for the non-overlapping portion.
"""
from orpheus_tts.decoder import convert_to_audio as _original_convert
buffer = [] buffer = []
total_codes = 0 total_codes = 0
@@ -319,21 +309,16 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by
if total_codes == 1: if total_codes == 1:
print(f"[snac] First code received: {code}") print(f"[snac] First code received: {code}")
# Decode every SNAC_INITIAL_GROUPS groups (sliding window) # The original decoder triggers every 7 codes after 28 minimum
if total_codes % SNAC_CHUNK_SIZE == 0: if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27:
groups = total_codes // SNAC_CHUNK_SIZE # Pass the last 28 codes, matching the original exactly
if groups >= SNAC_INITIAL_GROUPS: audio_bytes = _original_convert(buffer[-28:], total_codes)
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE if audio_bytes is not None:
codes_to_decode = buffer[-decode_size:] if total_codes == 28:
if groups == SNAC_INITIAL_GROUPS: print(f"[snac] First audio at {total_codes} codes")
print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}") yield audio_bytes
audio = snac_decode(codes_to_decode)
if audio:
yield audio
elif groups == SNAC_INITIAL_GROUPS:
print(f"[snac] WARNING: decode returned None")
print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}") print(f"[snac] Total codes: {total_codes}")
# ============================================================================ # ============================================================================