Fix SNAC code offset: subtract per-position offset (position*4096)
SNAC has 3 codebook layers, each with 4096 entries. Each of the 7 token positions in a group is encoded in its own 4096-wide ID range, so the offset to subtract is position * 4096 (pos 0 → 0, pos 1 → 4096, …, pos 6 → 24576). The position within the group of 7 also determines the SNAC layer: pos 0 = L1, pos 1-2 = L2, pos 3-6 = L3. Without this subtraction, codes exceeded 4095 and caused index-out-of-range errors in SNAC. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
20
main.py
20
main.py
@@ -237,25 +237,35 @@ def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, Non
|
||||
# ============================================================================
|
||||
|
||||
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
    """Extract SNAC audio codes from streamed text.

    Scans the incoming text for ``<custom_token_N>`` markers, including
    markers that are split across two chunks, and converts each audio token
    into a codebook index.

    The yielded values are already offset-corrected: both the base
    ``CODE_TOKEN_OFFSET`` and the per-position offset (position within the
    group of 7, times 4096) are subtracted here, so every yielded code is in
    the range 0-4095.

    Args:
        text_stream: Iterable of text chunks produced by the model.

    Yields:
        Codebook indices in ``range(4096)``, in stream order.
    """
    token_pattern = re.compile(r'<custom_token_(\d+)>')
    group_size = 7        # tokens per group; the position selects the offset
    codebook_size = 4096  # width of the ID range used by each position

    buffer = ""
    count = 0
    for chunk in text_stream:
        buffer += chunk
        consumed = 0
        # Walk every complete token currently in the buffer in one pass.
        for match in token_pattern.finditer(buffer):
            consumed = match.end()
            token_id = int(match.group(1))
            # Only IDs at or above CODE_TOKEN_OFFSET are audio codes.
            if token_id < CODE_TOKEN_OFFSET:
                continue
            pos_in_group = count % group_size
            # The position advances for every audio token, valid or not.
            count += 1
            code = token_id - CODE_TOKEN_OFFSET - pos_in_group * codebook_size
            if 0 <= code < codebook_size:
                yield code
            # Out of range: skipped, but the position was still counted.
            # NOTE(review): nothing is yielded in place of a skipped code, so
            # a consumer that regroups the output by 7 will see later codes
            # shifted by one - confirm this is the intended recovery.

        # Keep only text that can still become a token. A partial token must
        # start at the last '<' after the consumed part, because the pattern
        # contains no other '<'. Everything before it can be dropped, which
        # also stops the buffer growing when the stream has non-token text.
        tail_start = buffer.rfind('<', consumed)
        buffer = buffer[tail_start:] if tail_start != -1 else ""
|
||||
|
||||
|
||||
def redistribute_codes(codes: list) -> list:
|
||||
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
||||
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]"""
|
||||
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]
|
||||
Codes are already offset-corrected to 0-4095 range per layer."""
|
||||
layer1, layer2, layer3 = [], [], []
|
||||
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
||||
group = codes[i:i + SNAC_CHUNK_SIZE]
|
||||
|
||||
Reference in New Issue
Block a user