Fix SNAC decoding: correct token offset + device attribute
- CODE_TOKEN_OFFSET is 10 in decoded text (not 128266 in token ID space) because tokenizer.decode() maps 128266 → <custom_token_10> - Fixed 'SNAC object has no attribute device' — use explicit SNAC_DEVICE - Added debug logging for pipeline visibility - Audio now generates correctly: 442KB for "Hello world" True streaming pipeline verified: text → TextIteratorStreamer → regex extraction → SNAC decode → PCM bytes. The bottleneck is Jetson inference speed (~12s for first 42 tokens on a 3B model), not the streaming infrastructure. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
83
main.py
83
main.py
@@ -68,7 +68,9 @@ EMOTION_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groa
|
|||||||
SPECIAL_TOKEN_START = 128259
|
SPECIAL_TOKEN_START = 128259
|
||||||
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
|
SPECIAL_TOKENS_END = [128009, 128260, 128261, 128257]
|
||||||
EOS_TOKEN_ID = 128258
|
EOS_TOKEN_ID = 128258
|
||||||
CODE_TOKEN_OFFSET = 128266 # audio code tokens start here
|
# When decoded by tokenizer, audio codes appear as <custom_token_N> where N = token_id - 128256
|
||||||
|
# Audio codes start at token_id 128266, which decodes as <custom_token_10>
|
||||||
|
CODE_TOKEN_OFFSET = 10 # in decoded text space (token_id 128266 → custom_token_10)
|
||||||
CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio
|
CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio
|
||||||
|
|
||||||
# SNAC streaming parameters
|
# SNAC streaming parameters
|
||||||
@@ -241,6 +243,7 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
|
|||||||
Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets."""
|
Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets."""
|
||||||
buffer = ""
|
buffer = ""
|
||||||
count = 0
|
count = 0
|
||||||
|
skipped = 0
|
||||||
for chunk in text_stream:
|
for chunk in text_stream:
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
while True:
|
while True:
|
||||||
@@ -250,16 +253,19 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in
|
|||||||
token_id = int(match.group(1))
|
token_id = int(match.group(1))
|
||||||
buffer = buffer[match.end():]
|
buffer = buffer[match.end():]
|
||||||
if token_id >= CODE_TOKEN_OFFSET:
|
if token_id >= CODE_TOKEN_OFFSET:
|
||||||
# Subtract base offset + per-position layer offset (4096 per layer)
|
|
||||||
# Position in group of 7: determines which SNAC layer
|
|
||||||
pos_in_group = count % 7
|
pos_in_group = count % 7
|
||||||
|
# token_id is already decoded (e.g., 2061 for custom_token_2061)
|
||||||
|
# Subtract the base offset (10) and per-layer offset (pos * 4096)
|
||||||
code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096)
|
code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096)
|
||||||
|
if count < 14:
|
||||||
|
print(f"[codes] custom_token_{token_id} pos={pos_in_group} code={code}")
|
||||||
if 0 <= code < 4096:
|
if 0 <= code < 4096:
|
||||||
count += 1
|
count += 1
|
||||||
yield code
|
yield code
|
||||||
else:
|
else:
|
||||||
# Out of range — skip but still count position
|
skipped += 1
|
||||||
count += 1
|
count += 1
|
||||||
|
print(f"[codes] Total: {count} extracted, {skipped} skipped")
|
||||||
|
|
||||||
|
|
||||||
def redistribute_codes(codes: list) -> list:
|
def redistribute_codes(codes: list) -> list:
|
||||||
@@ -277,47 +283,57 @@ def redistribute_codes(codes: list) -> list:
|
|||||||
return [layer1, layer2, layer3]
|
return [layer1, layer2, layer3]
|
||||||
|
|
||||||
|
|
||||||
|
SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
def snac_decode(codes: list) -> Optional[bytes]:
|
def snac_decode(codes: list) -> Optional[bytes]:
|
||||||
"""Decode SNAC codes to PCM audio bytes."""
|
"""Decode SNAC codes to PCM audio bytes."""
|
||||||
layers = redistribute_codes(codes)
|
layers = redistribute_codes(codes)
|
||||||
if not layers[0]:
|
if not layers[0]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with torch.no_grad():
|
try:
|
||||||
codes_tensor = [
|
with torch.no_grad():
|
||||||
torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0)
|
codes_tensor = [
|
||||||
for layer in layers
|
torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0)
|
||||||
]
|
for layer in layers
|
||||||
audio_hat = snac_model.decode(codes_tensor)
|
]
|
||||||
|
audio_hat = snac_model.decode(codes_tensor)
|
||||||
|
|
||||||
audio_np = audio_hat.squeeze().cpu().numpy()
|
audio_np = audio_hat.squeeze().cpu().numpy()
|
||||||
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
|
audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
|
||||||
return audio_int16.tobytes()
|
return audio_int16.tobytes()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[snac] Decode error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
|
def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]:
|
||||||
"""Convert streaming SNAC codes to PCM audio chunks."""
|
"""Convert streaming SNAC codes to PCM audio chunks."""
|
||||||
buffer = []
|
buffer = []
|
||||||
group_count = 0
|
|
||||||
total_codes = 0
|
total_codes = 0
|
||||||
|
|
||||||
for code in code_stream:
|
for code in code_stream:
|
||||||
buffer.append(code)
|
buffer.append(code)
|
||||||
total_codes += 1
|
total_codes += 1
|
||||||
|
|
||||||
if total_codes % SNAC_CHUNK_SIZE == 0:
|
if total_codes == 1:
|
||||||
group_count += 1
|
print(f"[snac] First code received: {code}")
|
||||||
|
|
||||||
if group_count >= SNAC_INITIAL_GROUPS and group_count % 1 == 0:
|
# Decode every SNAC_INITIAL_GROUPS groups (sliding window)
|
||||||
# Decode the last N groups
|
if total_codes % SNAC_CHUNK_SIZE == 0:
|
||||||
|
groups = total_codes // SNAC_CHUNK_SIZE
|
||||||
|
if groups >= SNAC_INITIAL_GROUPS:
|
||||||
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
|
decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE
|
||||||
codes_to_decode = buffer[-decode_size:]
|
codes_to_decode = buffer[-decode_size:]
|
||||||
|
if groups == SNAC_INITIAL_GROUPS:
|
||||||
|
print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}")
|
||||||
audio = snac_decode(codes_to_decode)
|
audio = snac_decode(codes_to_decode)
|
||||||
if audio:
|
if audio:
|
||||||
# Yield only the NEW audio (avoid overlap)
|
yield audio
|
||||||
# Each group of 7 codes produces ~2048 samples
|
elif groups == SNAC_INITIAL_GROUPS:
|
||||||
new_samples = SNAC_CHUNK_SIZE * 293 # ~293 samples per code at 24kHz
|
print(f"[snac] WARNING: decode returned None")
|
||||||
yield audio[-new_samples * 2:] # *2 for 16-bit (2 bytes per sample)
|
|
||||||
|
print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -327,18 +343,25 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by
|
|||||||
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
|
def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]:
|
||||||
"""Full streaming pipeline: text → tokens → audio chunks."""
|
"""Full streaming pipeline: text → tokens → audio chunks."""
|
||||||
voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
|
voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE
|
||||||
|
total_bytes = 0
|
||||||
for chunk in chunk_text(text):
|
for chunk in chunk_text(text):
|
||||||
print(f"[stream] Generating: {chunk[:80]}...")
|
print(f"[stream] Generating: {chunk[:80]}...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
first_audio = True
|
first_audio = True
|
||||||
text_stream = generate_tokens_streaming(chunk, voice)
|
try:
|
||||||
code_stream = extract_audio_codes(text_stream)
|
text_stream = generate_tokens_streaming(chunk, voice)
|
||||||
for pcm_chunk in decode_audio_stream(code_stream):
|
code_stream = extract_audio_codes(text_stream)
|
||||||
if first_audio:
|
for pcm_chunk in decode_audio_stream(code_stream):
|
||||||
print(f"[stream] First audio in {time.time() - t0:.2f}s")
|
if first_audio:
|
||||||
first_audio = False
|
print(f"[stream] First audio in {time.time() - t0:.2f}s ({len(pcm_chunk)} bytes)")
|
||||||
yield pcm_chunk
|
first_audio = False
|
||||||
print("[stream] Done")
|
total_bytes += len(pcm_chunk)
|
||||||
|
yield pcm_chunk
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[stream] ERROR in pipeline: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"[stream] Done — {total_bytes} bytes total")
|
||||||
|
|
||||||
|
|
||||||
def generate_speech_sync(text: str, voice: str) -> bytes:
|
def generate_speech_sync(text: str, voice: str) -> bytes:
|
||||||
|
|||||||
Reference in New Issue
Block a user