Fix SNAC code offset: subtract per-layer offset (position*4096)

SNAC has 3 codebook layers, each with 4096 entries. A token's position within
its group of 7 determines its layer (pos 0 = L1, pos 1-2 = L2, pos 3-6 = L3),
and each position carries its own offset of position * 4096 (0, 4096, ..., 24576).
Without subtracting it, codes exceeded 4095 and caused index-out-of-range errors in SNAC.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-13 16:04:54 -05:00
parent d650fd06b9
commit 16aa526656

20
main.py
View File

@@ -237,25 +237,35 @@ def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, Non
# ============================================================================ # ============================================================================
def extract_audio_codes(
    text_stream: Generator[str, None, None],
    *,
    base_offset: "int | None" = None,
) -> Generator[int, None, None]:
    """Extract offset-corrected SNAC audio codes from streamed text.

    Scans the concatenated stream for ``<custom_token_N>`` markers, including
    markers that are split across chunk boundaries. Token IDs below the base
    offset are ignored and do not advance the position counter. Every other
    token ID is converted with::

        code = token_id - base_offset - (position_in_group_of_7 * 4096)

    and yielded only when the result lies in ``0 <= code < 4096``.

    Args:
        text_stream: Iterable of text chunks produced by the model.
        base_offset: Optional override for the ID of the first audio code.
            When ``None`` (the default) the module-level ``CODE_TOKEN_OFFSET``
            is used, which matches the previous behaviour.

    Yields:
        Integer codes in the range 0-4095, in stream order.
    """
    token_pattern = re.compile(r'<custom_token_(\d+)>')
    codes_per_frame = 7
    codebook_size = 4096
    offset = CODE_TOKEN_OFFSET if base_offset is None else base_offset

    buffer = ""
    count = 0  # audio tokens seen so far; count % 7 is the slot within a frame
    for chunk in text_stream:
        buffer += chunk
        consumed = 0
        for match in token_pattern.finditer(buffer):
            consumed = match.end()
            token_id = int(match.group(1))
            if token_id < offset:
                # Not an audio code (e.g. a control token): ignore it entirely.
                continue
            code = token_id - offset - (count % codes_per_frame) * codebook_size
            # The slot advances whether or not the code turns out to be valid.
            count += 1
            if 0 <= code < codebook_size:
                yield code
            # NOTE(review): an out-of-range code is dropped here, so the
            # yielded stream is one element short for that frame. If the
            # consumer regroups the flat stream in fixed frames of 7, later
            # codes would land in the wrong slot -- confirm whether the whole
            # frame should be discarded instead.

        # A future match can only begin at the last '<' of the unconsumed
        # tail, because the token pattern contains no second '<'. Everything
        # before it is dead text; dropping it keeps the buffer from growing
        # without bound on streams that carry non-token text.
        tail_start = buffer.rfind('<', consumed)
        buffer = buffer[tail_start:] if tail_start != -1 else ""
def redistribute_codes(codes: list) -> list: def redistribute_codes(codes: list) -> list:
"""Redistribute flat code list into SNAC's 3 hierarchical layers. """Redistribute flat code list into SNAC's 3 hierarchical layers.
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]""" Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]
Codes are already offset-corrected to 0-4095 range per layer."""
layer1, layer2, layer3 = [], [], [] layer1, layer2, layer3 = [], [], []
for i in range(0, len(codes), SNAC_CHUNK_SIZE): for i in range(0, len(codes), SNAC_CHUNK_SIZE):
group = codes[i:i + SNAC_CHUNK_SIZE] group = codes[i:i + SNAC_CHUNK_SIZE]