Fix SNAC decoding: correct token offset + device attribute

- CODE_TOKEN_OFFSET is 10 in decoded text (not 128266 in token ID space):
  tokenizer.decode() renders audio tokens as <custom_token_N> with
  N = token_id - 128256, so token 128266 appears as <custom_token_10>
- Fixed "'SNAC' object has no attribute 'device'" — use explicit SNAC_DEVICE
- Added debug logging for pipeline visibility
- Audio now generates correctly: 442KB for "Hello world"

True streaming pipeline verified: text → TextIteratorStreamer →
regex extraction → SNAC decode → PCM bytes. The bottleneck is
Jetson inference speed (~12s for first 42 tokens on a 3B model),
not the streaming infrastructure.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-13 16:41:14 -05:00
parent 16aa526656
commit 57a2e24101

83
main.py
View File

@@ -68,7 +68,9 @@ EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groa
SPECIAL_TOKEN_START = 128259 SPECIAL_TOKEN_START = 128259
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257] SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
EOS_TOKEN_ID = 128258 EOS_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266 # audio code tokens start here # When decoded by tokenizer, audio codes appear as <custom_token_N> where N = token_id - 128256
# Audio codes start at token_id 128266, which decodes as <custom_token_10>
CODE_TOKEN_OFFSET = 10 # in decoded text space (token_id 128266 → custom_token_10)
CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio
# SNAC streaming parameters # SNAC streaming parameters
@@ -241,6 +243,7 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets.""" Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets."""
buffer = "" buffer = ""
count = 0 count = 0
skipped = 0
for chunk in text_stream: for chunk in text_stream:
buffer += chunk buffer += chunk
while True: while True:
@@ -250,16 +253,19 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
token_id = int(match.group(1)) token_id = int(match.group(1))
buffer = buffer[match.end():] buffer = buffer[match.end():]
if token_id >= CODE_TOKEN_OFFSET: if token_id >= CODE_TOKEN_OFFSET:
# Subtract base offset + per-position layer offset (4096 per layer)
# Position in group of 7: determines which SNAC layer
pos_in_group = count % 7 pos_in_group = count % 7
# token_id is already decoded (e.g., 2061 for custom_token_2061)
# Subtract the base offset (10) and per-layer offset (pos * 4096)
code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096) code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096)
if count < 14:
print(f"[codes] custom_token_{token_id} pos={pos_in_group} code={code}")
if 0 <= code < 4096: if 0 <= code < 4096:
count += 1 count += 1
yield code yield code
else: else:
# Out of range — skip but still count position skipped += 1
count += 1 count += 1
print(f"[codes] Total: {count} extracted, {skipped} skipped")
def redistribute_codes(codes: list) -> list: def redistribute_codes(codes: list) -> list:
@@ -277,47 +283,57 @@ def redistribute_codes(codes: list) -> list:
return [layer1, layer2, layer3] return [layer1, layer2, layer3]
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def snac_decode(codes: list) -> Optional[bytes]: def snac_decode(codes: list) -> Optional[bytes]:
"""Decode SNAC codes to PCM audio bytes.""" """Decode SNAC codes to PCM audio bytes."""
layers = redistribute_codes(codes) layers = redistribute_codes(codes)
if not layers[0]: if not layers[0]:
return None return None
with torch.no_grad(): try:
codes_tensor = [ with torch.no_grad():
torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0) codes_tensor = [
for layer in layers torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
] for layer in layers
audio_hat = snac_model.decode(codes_tensor) ]
audio_hat = snac_model.decode(codes_tensor)
audio_np = audio_hat.squeeze().cpu().numpy() audio_np = audio_hat.squeeze().cpu().numpy()
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16) audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
return audio_int16.tobytes() return audio_int16.tobytes()
except Exception as e:
print(f"[snac] Decode error: {e}")
return None
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]: def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
"""Convert streaming SNAC codes to PCM audio chunks.""" """Convert streaming SNAC codes to PCM audio chunks."""
buffer = [] buffer = []
group_count = 0
total_codes = 0 total_codes = 0
for code in code_stream: for code in code_stream:
buffer.append(code) buffer.append(code)
total_codes += 1 total_codes += 1
if total_codes % SNAC_CHUNK_SIZE == 0: if total_codes == 1:
group_count += 1 print(f"[snac] First code received: {code}")
if group_count >= SNAC_INITIAL_GROUPS and group_count % 1 == 0: # Decode every SNAC_INITIAL_GROUPS groups (sliding window)
# Decode the last N groups if total_codes % SNAC_CHUNK_SIZE == 0:
groups = total_codes // SNAC_CHUNK_SIZE
if groups >= SNAC_INITIAL_GROUPS:
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
codes_to_decode = buffer[-decode_size:] codes_to_decode = buffer[-decode_size:]
if groups == SNAC_INITIAL_GROUPS:
print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}")
audio = snac_decode(codes_to_decode) audio = snac_decode(codes_to_decode)
if audio: if audio:
# Yield only the NEW audio (avoid overlap) yield audio
# Each group of 7 codes produces ~2048 samples elif groups == SNAC_INITIAL_GROUPS:
new_samples = SNAC_CHUNK_SIZE * 293 # ~293 samples per code at 24kHz print(f"[snac] WARNING: decode returned None")
yield audio[-new_samples * 2:] # *2 for 16-bit (2 bytes per sample)
print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}")
# ============================================================================ # ============================================================================
@@ -327,18 +343,25 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]: def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
"""Full streaming pipeline: text → tokens → audio chunks.""" """Full streaming pipeline: text → tokens → audio chunks."""
voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
total_bytes = 0
for chunk in chunk_text(text): for chunk in chunk_text(text):
print(f"[stream] Generating: {chunk[:80]}...") print(f"[stream] Generating: {chunk[:80]}...")
t0 = time.time() t0 = time.time()
first_audio = True first_audio = True
text_stream = generate_tokens_streaming(chunk, voice) try:
code_stream = extract_audio_codes(text_stream) text_stream = generate_tokens_streaming(chunk, voice)
for pcm_chunk in decode_audio_stream(code_stream): code_stream = extract_audio_codes(text_stream)
if first_audio: for pcm_chunk in decode_audio_stream(code_stream):
print(f"[stream] First audio in {time.time() - t0:.2f}s") if first_audio:
first_audio = False print(f"[stream] First audio in {time.time() - t0:.2f}s ({len(pcm_chunk)} bytes)")
yield pcm_chunk first_audio = False
print("[stream] Done") total_bytes += len(pcm_chunk)
yield pcm_chunk
except Exception as e:
print(f"[stream] ERROR in pipeline: {e}")
import traceback
traceback.print_exc()
print(f"[stream] Done — {total_bytes} bytes total")
def generate_speech_sync(text: str, voice: str) -> bytes: def generate_speech_sync(text: str, voice: str) -> bytes: