diff --git a/main.py b/main.py index 49a1872..524c5d5 100644 --- a/main.py +++ b/main.py @@ -68,7 +68,9 @@ EMOTION_TAGS = ["", "", "", "", "", " where N = token_id - 128256 +# Audio codes start at token_id 128266, which decodes as +CODE_TOKEN_OFFSET = 10 # in decoded text space (token_id 128266 → custom_token_10) CODE_REMOVE_TOKEN_ID = 128258 # this token signals end, not audio # SNAC streaming parameters @@ -241,6 +243,7 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in Yields raw token IDs (before layer offset subtraction) — redistribution handles offsets.""" buffer = "" count = 0 + skipped = 0 for chunk in text_stream: buffer += chunk while True: @@ -250,16 +253,19 @@ def extract_audio_codes(text_stream: Generator[str, None, None]) -> Generator[in token_id = int(match.group(1)) buffer = buffer[match.end():] if token_id >= CODE_TOKEN_OFFSET: - # Subtract base offset + per-position layer offset (4096 per layer) - # Position in group of 7: determines which SNAC layer pos_in_group = count % 7 + # token_id is already decoded (e.g., 2061 for custom_token_2061) + # Subtract the base offset (10) and per-layer offset (pos * 4096) code = token_id - CODE_TOKEN_OFFSET - (pos_in_group * 4096) + if count < 14: + print(f"[codes] custom_token_{token_id} pos={pos_in_group} code={code}") if 0 <= code < 4096: count += 1 yield code else: - # Out of range — skip but still count position + skipped += 1 count += 1 + print(f"[codes] Total: {count} extracted, {skipped} skipped") def redistribute_codes(codes: list) -> list: @@ -277,47 +283,57 @@ def redistribute_codes(codes: list) -> list: return [layer1, layer2, layer3] +SNAC_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + def snac_decode(codes: list) -> Optional[bytes]: """Decode SNAC codes to PCM audio bytes.""" layers = redistribute_codes(codes) if not layers[0]: return None - with torch.no_grad(): - codes_tensor = [ - torch.tensor(layer, device=snac_model.device, dtype=torch.long).unsqueeze(0) - for layer in layers - ] - audio_hat = snac_model.decode(codes_tensor) + try: + with torch.no_grad(): + codes_tensor = [ + torch.tensor(layer, device=SNAC_DEVICE, dtype=torch.long).unsqueeze(0) + for layer in layers + ] + audio_hat = snac_model.decode(codes_tensor) - audio_np = audio_hat.squeeze().cpu().numpy() - audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16) - return audio_int16.tobytes() + audio_np = audio_hat.squeeze().cpu().numpy() + audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16) + return audio_int16.tobytes() + except Exception as e: + print(f"[snac] Decode error: {e}") + return None def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[bytes, None, None]: """Convert streaming SNAC codes to PCM audio chunks.""" buffer = [] - group_count = 0 total_codes = 0 for code in code_stream: buffer.append(code) total_codes += 1 - if total_codes % SNAC_CHUNK_SIZE == 0: - group_count += 1 + if total_codes == 1: + print(f"[snac] First code received: {code}") - if group_count >= SNAC_INITIAL_GROUPS and group_count % 1 == 0: - # Decode the last N groups + # Decode every SNAC_INITIAL_GROUPS groups (sliding window) + if total_codes % SNAC_CHUNK_SIZE == 0: + groups = total_codes // SNAC_CHUNK_SIZE + if groups >= SNAC_INITIAL_GROUPS: decode_size = SNAC_INITIAL_GROUPS * SNAC_CHUNK_SIZE codes_to_decode = buffer[-decode_size:] + if groups == SNAC_INITIAL_GROUPS: + print(f"[snac] First decode at {total_codes} codes, values: {codes_to_decode[:7]}") audio = snac_decode(codes_to_decode) if audio: - # Yield only the NEW audio (avoid overlap) - # Each group of 7 codes produces ~2048 samples - new_samples = SNAC_CHUNK_SIZE * 293 # ~293 samples per code at 24kHz - yield audio[-new_samples * 2:] # *2 for 16-bit (2 bytes per sample) + yield audio + elif groups == SNAC_INITIAL_GROUPS: + print(f"[snac] WARNING: decode returned None") + + print(f"[snac] Total codes: {total_codes}, groups: {total_codes // SNAC_CHUNK_SIZE}") # ============================================================================ @@ -327,18 +343,25 @@ def decode_audio_stream(code_stream: Generator[int, None, None]) -> Generator[by def generate_stream(text: str, voice: str) -> Generator[bytes, None, None]: """Full streaming pipeline: text → tokens → audio chunks.""" voice = voice if voice in BUILTIN_VOICES else DEFAULT_VOICE + total_bytes = 0 for chunk in chunk_text(text): print(f"[stream] Generating: {chunk[:80]}...") t0 = time.time() first_audio = True - text_stream = generate_tokens_streaming(chunk, voice) - code_stream = extract_audio_codes(text_stream) - for pcm_chunk in decode_audio_stream(code_stream): - if first_audio: - print(f"[stream] First audio in {time.time() - t0:.2f}s") - first_audio = False - yield pcm_chunk - print("[stream] Done") + try: + text_stream = generate_tokens_streaming(chunk, voice) + code_stream = extract_audio_codes(text_stream) + for pcm_chunk in decode_audio_stream(code_stream): + if first_audio: + print(f"[stream] First audio in {time.time() - t0:.2f}s ({len(pcm_chunk)} bytes)") + first_audio = False + total_bytes += len(pcm_chunk) + yield pcm_chunk + except Exception as e: + print(f"[stream] ERROR in pipeline: {e}") + import traceback + traceback.print_exc() + print(f"[stream] Done — {total_bytes} bytes total") def generate_speech_sync(text: str, voice: str) -> bytes: