Fix SNAC code offset: subtract per-layer offset (position*4096)
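SNAC has 3 codebook layers, each 4096 entries. Token position within the group of 7 determines which layer: pos 0 = L1, pos 1-2 = L2, pos 3-6 = L3. Each position's token IDs are additionally shifted by position*4096 (pos 0 = 0, pos 1 = 4096, ... pos 6 = 24576), so that per-position offset is subtracted to recover a 0-4095 codebook index. Without this, codes exceeded the 0-4095 range and caused index-out-of-range in SNAC. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>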
This commit is contained in:
20
main.py
20
main.py
@@ -237,25 +237,35 @@ def generate_tokens_streaming(text: str, voice: str) -> Generator[str, None, Non
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
|
def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[int, None, None]:
|
||||||
"""Extract SNAC audio codes from streamed text. Handles partial tokens across chunks."""
|
"""Extract SNAC audio codes from streamed text. Handles partial tokens across chunks.
|
||||||
|
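Yields codes already reduced to the 0-4095 range (CODE_TOKEN_OFFSET and the per-position offset of position*4096 are subtracted here); tokens that fall outside that range are skipped but still advance the position counter."""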
|
||||||
buffer = ""
|
buffer = ""
|
||||||
|
count = 0
|
||||||
for chunk in text_stream:
|
for chunk in text_stream:
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
# Extract all complete <custom_token_XXXX> patterns
|
|
||||||
while True:
|
while True:
|
||||||
match = re.search(r'<custom_token_(\d+)>', buffer)
|
match = re.search(r'<custom_token_(\d+)>', buffer)
|
||||||
if not match:
|
if not match:
|
||||||
break
|
break
|
||||||
token_id = int(match.group(1))
|
token_id = int(match.group(1))
|
||||||
buffer = buffer[match.end():]
|
buffer = buffer[match.end():]
|
||||||
# Only yield actual audio codes (>= CODE_TOKEN_OFFSET)
|
|
||||||
if token_id >= CODE_TOKEN_OFFSET:
|
if token_id >= CODE_TOKEN_OFFSET:
|
||||||
yield token_id - CODE_TOKEN_OFFSET
|
# Subtract base offset + per-position layer offset (4096 per layer)
|
||||||
|
# Position in group of 7: determines which SNAC layer
|
||||||
|
pos_in_group = count % 7
|
||||||
|
code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096)
|
||||||
|
if 0 <= code < 4096:
|
||||||
|
count += 1
|
||||||
|
yield code
|
||||||
|
else:
|
||||||
|
# Out of range — skip but still count position
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
|
||||||
def redistribute_codes(codes: list) -> list:
|
def redistribute_codes(codes: list) -> list:
|
||||||
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
"""Redistribute flat code list into SNAC's 3 hierarchical layers.
|
||||||
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]"""
|
Each group of 7: [L1, L2a, L2b, L3a, L3b, L3c, L3d]
|
||||||
|
Codes are already offset-corrected to 0-4095 range per layer."""
|
||||||
layer1, layer2, layer3 = [], [], []
|
layer1, layer2, layer3 = [], [], []
|
||||||
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
for i in range(0, len(codes), SNAC_CHUNK_SIZE):
|
||||||
group = codes[i:i + SNAC_CHUNK_SIZE]
|
group = codes[i:i + SNAC_CHUNK_SIZE]
|
||||||
|
|||||||
Reference in New Issue
Block a user