diff --git a/.gitignore b/.gitignore index 4bec6a1..b144e97 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ Hey-Vivi_*/ # ML models (downloaded separately) models/*.tflite +# Speaker voice database +voices.db + # Python __pycache__/ *.py[cod] diff --git a/headmic.py b/headmic.py index bb11c65..885f680 100644 --- a/headmic.py +++ b/headmic.py @@ -38,7 +38,7 @@ import numpy as np import httpx import pvporcupine import webrtcvad -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, UploadFile, File, Form from pydantic import BaseModel # Configure logging @@ -97,6 +97,14 @@ def leds_processing(): except: pass +def leds_enrolling(): + if LEDS_AVAILABLE: + try: + pixel_ring.set_color_palette(0xFF8C00, 0x000000) + pixel_ring.think() + except: pass + + def leds_off(): if LEDS_AVAILABLE: try: @@ -120,6 +128,10 @@ class ServiceState: self.error: Optional[str] = None self.audio_scene: Optional[dict] = None self.sound_classification_enabled: bool = False + self.recognized_speaker: Optional[str] = None + self.speaker_confidence: float = 0.0 + self.speaker_recognition_enabled: bool = False + self.enrolling: bool = False state = ServiceState() @@ -127,6 +139,11 @@ state = ServiceState() sound_classifier = None sound_ring_buffer = None # collections.deque, filled by listener_loop +# Speaker recognizer globals +speaker_recognizer = None +enrollment_buffer = None # list of frame bytes, set during enrollment +enrollment_name = None + # ============================================================================ # Audio Stream using ALSA directly (arecord) @@ -266,6 +283,10 @@ def listener_loop(): if sound_ring_buffer is not None: sound_ring_buffer.append(frame_data) + # Feed enrollment buffer if active + if enrollment_buffer is not None: + enrollment_buffer.append(frame_data) + # Check for wake word keyword_index = porcupine.process(pcm) @@ -355,7 +376,19 @@ def sound_classifier_loop(): frames = list(sound_ring_buffer) audio = 
np.frombuffer(b"".join(frames), dtype=np.int16) result = sound_classifier.classify(audio) + + # Strip audio_float32 before storing in state (not JSON-serializable) + audio_f32 = result.pop("audio_float32", None) state.audio_scene = result + + # Speaker identification: run when speech detected + if speaker_recognizer and result["category"] == "speech" and audio_f32 is not None: + try: + name, confidence = speaker_recognizer.identify(audio_f32) + state.recognized_speaker = name + state.speaker_confidence = confidence + except Exception as e: + logger.warning("Speaker identification error: %s", e) except Exception as e: logger.warning("Sound classification error: %s", e) @@ -372,7 +405,7 @@ app = FastAPI(title="HeadMic", description="Vixy's Ears 🦊👂") @app.on_event("startup") async def startup(): - global sound_classifier, sound_ring_buffer + global sound_classifier, sound_ring_buffer, speaker_recognizer state.running = True @@ -396,6 +429,16 @@ async def startup(): else: logger.info("Sound classification models not found, skipping") + # Init speaker recognizer (optional — graceful if resemblyzer not installed) + try: + from speaker_id import SpeakerRecognizer + db_path = Path(__file__).parent / "voices.db" + speaker_recognizer = SpeakerRecognizer(db_path=str(db_path)) + state.speaker_recognition_enabled = True + logger.info("Speaker recognition enabled (Resemblyzer)") + except Exception as e: + logger.warning("Speaker recognition unavailable: %s", e) + thread = threading.Thread(target=listener_loop, daemon=True) thread.start() logger.info("HeadMic started") @@ -425,6 +468,7 @@ async def health(): "processing": state.processing, "wake_count": state.wake_count, "sound_classification_enabled": state.sound_classification_enabled, + "speaker_recognition_enabled": state.speaker_recognition_enabled, "error": state.error } @@ -439,6 +483,7 @@ async def status(): "last_wake_time": state.last_wake_time, "wake_count": state.wake_count, "audio_scene": 
state.audio_scene["dominant_category"] if state.audio_scene else None, + "recognized_speaker": state.recognized_speaker, "error": state.error } @@ -457,8 +502,13 @@ async def sounds(): if not state.sound_classification_enabled: raise HTTPException(status_code=503, detail="Sound classification not available") if state.audio_scene is None: - return {"category": None, "top_classes": [], "dominant_category": None, "timestamp": None} - return state.audio_scene + return {"category": None, "top_classes": [], "dominant_category": None, "timestamp": None, + "recognized_speaker": None, "speaker_confidence": 0.0} + return { + **state.audio_scene, + "recognized_speaker": state.recognized_speaker, + "speaker_confidence": state.speaker_confidence, + } @@ -471,6 +521,96 @@ async def sounds_history(seconds: int = 30): return {"history": sound_classifier.get_history(seconds)} +# ============================================================================ +# Speaker Endpoints +# ============================================================================ + +@app.post("/speakers/enroll") +async def enroll_speaker(name: str = Form(...), audio: UploadFile = File(...)): + """Enroll a speaker from uploaded audio file.""" + if speaker_recognizer is None: + raise HTTPException(status_code=503, detail="Speaker recognition not available") + + audio_bytes = await audio.read() + # Convert to float32: try parsing as WAV first, fall back to raw int16 PCM + try: + import wave as _wave + wav_io = io.BytesIO(audio_bytes) + with _wave.open(wav_io, 'rb') as wf: + raw = wf.readframes(wf.getnframes()) + audio_f32 = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 + except Exception: + # Assume raw int16 PCM at 16kHz + audio_f32 = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 + + try: + speaker_recognizer.enroll(name, audio_f32, source="upload") + return {"enrolled": name, "speakers": speaker_recognizer.list_speakers()} + except ValueError as e: + 
raise HTTPException(status_code=400, detail=str(e)) + + +@app.post("/speakers/enroll-from-mic") +async def enroll_from_mic(name: str): + """Record from live mic for 5 seconds and enroll speaker.""" + global enrollment_buffer, enrollment_name, enrollment_event + + if speaker_recognizer is None: + raise HTTPException(status_code=503, detail="Speaker recognition not available") + if state.enrolling: + raise HTTPException(status_code=409, detail="Enrollment already in progress") + + state.enrolling = True + enrollment_buffer = [] + enrollment_name = name + + leds_enrolling() + logger.info("Enrollment started for '%s' — recording 5 seconds", name) + + # Wait 5 seconds for audio, non-blocking to the event loop + await asyncio.sleep(5.0) + + # Collect what we have + frames = enrollment_buffer + enrollment_buffer = None + enrollment_name = None + state.enrolling = False + leds_off() + + if not frames: + raise HTTPException(status_code=500, detail="No audio captured") + + audio_int16 = np.frombuffer(b"".join(frames), dtype=np.int16) + audio_f32 = audio_int16.astype(np.float32) / 32768.0 + logger.info("Enrollment audio: %.1f seconds", len(audio_f32) / SAMPLE_RATE) + + try: + speaker_recognizer.enroll(name, audio_f32, source="mic") + return {"enrolled": name, "seconds": round(len(audio_f32) / SAMPLE_RATE, 1), + "speakers": speaker_recognizer.list_speakers()} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.get("/speakers") +async def list_speakers(): + """List enrolled speakers.""" + if speaker_recognizer is None: + raise HTTPException(status_code=503, detail="Speaker recognition not available") + return {"speakers": speaker_recognizer.list_speakers()} + + +@app.delete("/speakers/{name}") +async def delete_speaker(name: str): + """Remove a speaker.""" + if speaker_recognizer is None: + raise HTTPException(status_code=503, detail="Speaker recognition not available") + removed = speaker_recognizer.delete_speaker(name) + if removed == 0: 
+ raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found") + return {"deleted": name, "samples_removed": removed} + + if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8446) diff --git a/sound_id.py b/sound_id.py index 92ec3a6..4b9ea40 100644 --- a/sound_id.py +++ b/sound_id.py @@ -176,6 +176,7 @@ class SoundClassifier: "top_classes": top_classes, "dominant_category": dominant, "timestamp": now, + "audio_float32": audio_f32, } def get_history(self, seconds=30): diff --git a/speaker_id.py b/speaker_id.py new file mode 100644 index 0000000..46d761c --- /dev/null +++ b/speaker_id.py @@ -0,0 +1,133 @@ +""" +Speaker Identification Module for HeadMic +Resemblyzer GE2E speaker encoder — 256-dim embeddings, cosine similarity matching. +Triggered when YAMNet detects speech. +""" + +import logging +import sqlite3 +import time +from pathlib import Path + +import numpy as np + +logger = logging.getLogger("speaker_id") +logger.setLevel(logging.INFO) + +SIMILARITY_THRESHOLD = 0.75 + + +class SpeakerRecognizer: + def __init__(self, db_path="voices.db"): + from resemblyzer import VoiceEncoder + + self._encoder = VoiceEncoder("cpu") + logger.info("Resemblyzer voice encoder loaded") + + self._db_path = str(db_path) + self._init_db() + self._cache = self._load_embeddings() + logger.info( + "Speaker DB ready: %d embeddings for %d speakers", + sum(len(v) for v in self._cache.values()), + len(self._cache), + ) + + def _init_db(self): + with sqlite3.connect(self._db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS voices ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + embedding BLOB NOT NULL, + enrolled_at REAL NOT NULL, + source TEXT + ) + """) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_voices_name ON voices(name)" + ) + + def _load_embeddings(self): + """Load all embeddings from DB into memory, grouped by name.""" + cache = {} + with sqlite3.connect(self._db_path) as conn: + rows = conn.execute("SELECT name, 
embedding FROM voices").fetchall() + for name, blob in rows: + emb = np.frombuffer(blob, dtype=np.float32).copy() + cache.setdefault(name, []).append(emb) + return cache + + def identify(self, audio_float32): + """Identify speaker from float32 audio at 16kHz. + + Returns: + (name, confidence) or (None, 0.0) if no match above threshold. + """ + if not self._cache: + return None, 0.0 + + try: + from resemblyzer import preprocess_wav + wav = preprocess_wav(audio_float32, source_sr=16000) + if len(wav) < 1600: # too short + return None, 0.0 + embedding = self._encoder.embed_utterance(wav) + except Exception as e: + logger.warning("Embedding computation failed: %s", e) + return None, 0.0 + + best_name = None + best_score = 0.0 + + for name, embeddings in self._cache.items(): + # Best score across all enrolled samples for this speaker + scores = [np.dot(embedding, emb) for emb in embeddings] + top = max(scores) + if top > best_score: + best_score = top + best_name = name + + if best_score >= SIMILARITY_THRESHOLD: + return best_name, round(float(best_score), 3) + return None, 0.0 + + def enroll(self, name, audio_float32, source="api"): + """Enroll a speaker from float32 audio at 16kHz. + + Returns: + The computed embedding (256-dim). 
+ """ + from resemblyzer import preprocess_wav + + wav = preprocess_wav(audio_float32, source_sr=16000) + if len(wav) < 1600: + raise ValueError("Audio too short for enrollment") + + embedding = self._encoder.embed_utterance(wav) + blob = embedding.astype(np.float32).tobytes() + now = time.time() + + with sqlite3.connect(self._db_path) as conn: + conn.execute( + "INSERT INTO voices (name, embedding, enrolled_at, source) VALUES (?, ?, ?, ?)", + (name, blob, now, source), + ) + + self._cache.setdefault(name, []).append(embedding) + logger.info("Enrolled speaker '%s' (source=%s, total=%d samples)", name, source, len(self._cache[name])) + return embedding + + def list_speakers(self): + """Return enrolled speaker names with sample counts.""" + return {name: len(embs) for name, embs in self._cache.items()} + + def delete_speaker(self, name): + """Remove all embeddings for a speaker.""" + with sqlite3.connect(self._db_path) as conn: + conn.execute("DELETE FROM voices WHERE name = ?", (name,)) + removed = self._cache.pop(name, None) + if removed: + logger.info("Deleted speaker '%s' (%d samples)", name, len(removed)) + return len(removed) + return 0