Fix audio quality: use original orpheus_tts convert_to_audio decoder
Our custom SNAC redistribution had wrong layer mapping (positions 1,2 vs 1,4 for layer 2) and incorrect audio slicing. Switched to importing convert_to_audio directly from orpheus_tts.decoder which handles the sliding window, layer redistribution, and 2048:4096 audio slice correctly. Audio now sounds clean with only a subtle boundary artifact on the first token group (inherent to SNAC streaming, not our code). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
65
main.py
65
main.py
@@ -270,45 +270,35 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
|
|||||||
|
|
||||||
def redistribute_codes(codes: list) -> list:
|
def redistribute_codes(codes: list) -> list:
|
||||||
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
||||||
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]
|
Each group of 7 maps as: [0]=L1, [1]=L2, [2]=L3, [3]=L3, [4]=L2, [5]=L3, [6]=L3
|
||||||
Codes are already offset-corrected to 0-4095 range per layer."""
|
(from orpheus_tts.decoder.convert_to_audio)"""
|
||||||
layer1, layer2, layer3 = [], [], []
|
layer1, layer2, layer3 = [], [], []
|
||||||
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
||||||
group = codes[i:i + SNAC_CHUNK_SIZE]
|
group = codes[i:i + SNAC_CHUNK_SIZE]
|
||||||
if len(group) < SNAC_CHUNK_SIZE:
|
if len(group) < SNAC_CHUNK_SIZE:
|
||||||
break
|
break
|
||||||
layer1.append(group[0])
|
layer1.append(group[0])
|
||||||
layer2.extend(group[1:3])
|
layer2.append(group[1])
|
||||||
layer3.extend(group[3:7])
|
layer2.append(group[4])
|
||||||
|
layer3.append(group[2])
|
||||||
|
layer3.append(group[3])
|
||||||
|
layer3.append(group[5])
|
||||||
|
layer3.append(group[6])
|
||||||
return [layer1, layer2, layer3]
|
return [layer1, layer2, layer3]
|
||||||
|
|
||||||
|
|
||||||
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
def snac_decode(codes: list) -> Optional[bytes]:
|
|
||||||
"""Decode SNAC codes to PCM audio bytes."""
|
|
||||||
layers = redistribute_codes(codes)
|
|
||||||
if not layers[0]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
codes_tensor = [
|
|
||||||
torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
|
|
||||||
for layer in layers
|
|
||||||
]
|
|
||||||
audio_hat = snac_model.decode(codes_tensor)
|
|
||||||
|
|
||||||
audio_np = audio_hat.squeeze().cpu().numpy()
|
|
||||||
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
|
|
||||||
return audio_int16.tobytes()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[snac] Decode error: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
|
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
|
||||||
"""Convert streaming SNAC codes to PCM audio chunks."""
|
"""Convert streaming SNAC codes to PCM audio chunks.
|
||||||
|
|
||||||
|
Uses the exact same decode logic as orpheus_tts.decoder.convert_to_audio:
|
||||||
|
accumulate all codes, decode the last 28 every 7 new codes,
|
||||||
|
slice audio_hat[:,:,2048:4096] for the non-overlapping portion.
|
||||||
|
"""
|
||||||
|
from orpheus_tts.decoder import convert_to_audio as _original_convert
|
||||||
|
|
||||||
buffer = []
|
buffer = []
|
||||||
total_codes = 0
|
total_codes = 0
|
||||||
|
|
||||||
@@ -319,21 +309,16 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by
|
|||||||
if total_codes == 1:
|
if total_codes == 1:
|
||||||
print(f"[snac] First code received: {code}")
|
print(f"[snac] First code received: {code}")
|
||||||
|
|
||||||
# Decode every SNAC_INITIAL_GROUPS groups (sliding window)
|
# The original decoder triggers every 7 codes after 28 minimum
|
||||||
if total_codes % SNAC_CHUNK_SIZE == 0:
|
if total_codes % SNAC_CHUNK_SIZE == 0 and total_codes > 27:
|
||||||
groups = total_codes // SNAC_CHUNK_SIZE
|
# Pass the last 28 codes, matching the original exactly
|
||||||
if groups >= SNAC_INITIAL_GROUPS:
|
audio_bytes = _original_convert(buffer[-28:], total_codes)
|
||||||
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
|
if audio_bytes is not None:
|
||||||
codes_to_decode = buffer[-decode_size:]
|
if total_codes == 28:
|
||||||
if groups == SNAC_INITIAL_GROUPS:
|
print(f"[snac] First audio at {total_codes} codes")
|
||||||
print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}")
|
yield audio_bytes
|
||||||
audio = snac_decode(codes_to_decode)
|
|
||||||
if audio:
|
|
||||||
yield audio
|
|
||||||
elif groups == SNAC_INITIAL_GROUPS:
|
|
||||||
print(f"[snac] WARNING: decode returned None")
|
|
||||||
|
|
||||||
print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}")
|
print(f"[snac] Total codes: {total_codes}")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user