From 4989e0a7e8b24a1f2e85a3564d2b0ff372ff0f88 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Apr 2026 01:07:29 -0500 Subject: [PATCH] Fix audio quality: use original orpheus_tts convert_to_audio decoder Our custom SNAC redistribution had wrong layer mapping (positions 1,2 vs 1,4 for layer 2) and incorrect audio slicing. Switched to importing convert_to_audio directly from orpheus_tts.decoder which handles the sliding window, layer redistribution, and 2048:4096 audio slice correctly. Audio now sounds clean with only a subtle boundary artifact on the first token group (inherent to SNAC streaming, not our code). Co-Authored-By: Claude Opus 4.6 (1M context) --- main.py | 65 ++++++++++++++++++++++----------------------------------- 1 file changed, 25 insertions(+), 40 deletions(-) diff --git a/main.py b/main.py index 524c5d5..d3219cf 100644 --- a/main.py +++ b/main.py @@ -270,45 +270,35 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in def redistribute_codes(codes: list) -> list: """Redistribute flat code list into SNAC's 3 hierarchical layers. - Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d] - Codes are already offset-corrected to 0-4095 range per layer.""" + Each group of 7 maps as: [0]=L1, [1]=L2, [2]=L3, [3]=L3, [4]=L2, [5]=L3, [6]=L3 + (from orpheus_tts.decoder.convert_to_audio)""" layer1, layer2, layer3 = [], [], [] for i in range(0, len(codes), SNAC_CHUNK_SIZE): group = codes[i:i + SNAC_CHUNK_SIZE] if len(group) < SNAC_CHUNK_SIZE: break layer1.append(group[0]) - layer2.extend(group[1:3]) - layer3.extend(group[3:7]) + layer2.append(group[1]) + layer2.append(group[4]) + layer3.append(group[2]) + layer3.append(group[3]) + layer3.append(group[5]) + layer3.append(group[6]) return [layer1, layer2, layer3] SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -def snac_decode(codes: list) -> Optional[bytes]: - """Decode SNAC codes to PCM audio bytes.""" - layers = redistribute_codes(codes) - if not layers[0]: - return None - - try: - with torch.no_grad(): - codes_tensor = [ - torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0) - for layer in layers - ] - audio_hat = snac_model.decode(codes_tensor) - - audio_np = audio_hat.squeeze().cpu().numpy() - audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16) - return audio_int16.tobytes() - except Exception as e: - print(f"[snac] Decode error: {e}") - return None - def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]: - """Convert streaming SNAC codes to PCM audio chunks.""" + """Convert streaming SNAC codes to PCM audio chunks. + + Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio: + accumulate all codes, decode the last 28 every 7 new codes, + slice audio_hat[:,:,2048:4096] for the non-overlapping portion. + """ + from orpheus_tts.decoder import convert_to_audio as _original_convert + buffer = [] total_codes = 0 @@ -319,21 +309,16 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by if total_codes == 1: print(f"[snac] First code received: {code}") - # Decode every SNAC_INITIAL_GROUPS groups (sliding window) - if total_codes % SNAC_CHUNK_SIZE == 0: - groups = total_codes // SNAC_CHUNK_SIZE - if groups >= SNAC_INITIAL_GROUPS: - decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE - codes_to_decode = buffer[-decode_size:] - if groups == SNAC_INITIAL_GROUPS: - print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}") - audio = snac_decode(codes_to_decode) - if audio: - yield audio - elif groups == SNAC_INITIAL_GROUPS: - print(f"[snac] WARNING: decode returned None") + # The original decoder triggers every 7 codes after 28 minimum + if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27: + # Pass the last 28 codes, matching the original exactly + audio_bytes = _original_convert(buffer[-28:], total_codes) + if audio_bytes is not None: + if total_codes == 28: + print(f"[snac] First audio at {total_codes} codes") + yield audio_bytes - print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}") + print(f"[snac] Total codes: {total_codes}") # ============================================================================