Add speaker identification with Resemblyzer
Adds voice-based speaker ID triggered by YAMNet speech detection.
New speaker_id.py module with SQLite-backed voice enrollment and
cosine similarity matching. Endpoints: POST /speakers/enroll,
POST /speakers/enroll-from-mic, GET /speakers, DELETE /speakers/{name}.
Orange LED animation during enrollment.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,6 +5,9 @@ Hey-Vivi_*/
|
|||||||
# ML models (downloaded separately)
|
# ML models (downloaded separately)
|
||||||
models/*.tflite
|
models/*.tflite
|
||||||
|
|
||||||
|
# Speaker voice database
|
||||||
|
voices.db
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
148
headmic.py
148
headmic.py
@@ -38,7 +38,7 @@ import numpy as np
|
|||||||
import httpx
|
import httpx
|
||||||
import pvporcupine
|
import pvporcupine
|
||||||
import webrtcvad
|
import webrtcvad
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
@@ -97,6 +97,14 @@ def leds_processing():
|
|||||||
except: pass
|
except: pass
|
||||||
|
|
||||||
|
|
||||||
def leds_enrolling():
    """Show the enrollment LED animation: orange 'think' spin on the pixel ring."""
    if LEDS_AVAILABLE:
        try:
            # Orange palette signals "recording for enrollment".
            pixel_ring.set_color_palette(0xFF8C00, 0x000000)
            pixel_ring.think()
        except Exception:
            # Best-effort: LED failures must never break audio handling.
            # (Bare `except:` would also swallow SystemExit/KeyboardInterrupt.)
            pass
||||||
def leds_off():
|
def leds_off():
|
||||||
if LEDS_AVAILABLE:
|
if LEDS_AVAILABLE:
|
||||||
try:
|
try:
|
||||||
@@ -120,6 +128,10 @@ class ServiceState:
|
|||||||
self.error: Optional[str] = None
|
self.error: Optional[str] = None
|
||||||
self.audio_scene: Optional[dict] = None
|
self.audio_scene: Optional[dict] = None
|
||||||
self.sound_classification_enabled: bool = False
|
self.sound_classification_enabled: bool = False
|
||||||
|
self.recognized_speaker: Optional[str] = None
|
||||||
|
self.speaker_confidence: float = 0.0
|
||||||
|
self.speaker_recognition_enabled: bool = False
|
||||||
|
self.enrolling: bool = False
|
||||||
|
|
||||||
state = ServiceState()
|
state = ServiceState()
|
||||||
|
|
||||||
@@ -127,6 +139,11 @@ state = ServiceState()
|
|||||||
sound_classifier = None
|
sound_classifier = None
|
||||||
sound_ring_buffer = None # collections.deque, filled by listener_loop
|
sound_ring_buffer = None # collections.deque, filled by listener_loop
|
||||||
|
|
||||||
|
# Speaker recognizer globals
|
||||||
|
speaker_recognizer = None
|
||||||
|
enrollment_buffer = None # list of frame bytes, set during enrollment
|
||||||
|
enrollment_name = None
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Audio Stream using ALSA directly (arecord)
|
# Audio Stream using ALSA directly (arecord)
|
||||||
@@ -266,6 +283,10 @@ def listener_loop():
|
|||||||
if sound_ring_buffer is not None:
|
if sound_ring_buffer is not None:
|
||||||
sound_ring_buffer.append(frame_data)
|
sound_ring_buffer.append(frame_data)
|
||||||
|
|
||||||
|
# Feed enrollment buffer if active
|
||||||
|
if enrollment_buffer is not None:
|
||||||
|
enrollment_buffer.append(frame_data)
|
||||||
|
|
||||||
# Check for wake word
|
# Check for wake word
|
||||||
keyword_index = porcupine.process(pcm)
|
keyword_index = porcupine.process(pcm)
|
||||||
|
|
||||||
@@ -355,7 +376,19 @@ def sound_classifier_loop():
|
|||||||
frames = list(sound_ring_buffer)
|
frames = list(sound_ring_buffer)
|
||||||
audio = np.frombuffer(b"".join(frames), dtype=np.int16)
|
audio = np.frombuffer(b"".join(frames), dtype=np.int16)
|
||||||
result = sound_classifier.classify(audio)
|
result = sound_classifier.classify(audio)
|
||||||
|
|
||||||
|
# Strip audio_float32 before storing in state (not JSON-serializable)
|
||||||
|
audio_f32 = result.pop("audio_float32", None)
|
||||||
state.audio_scene = result
|
state.audio_scene = result
|
||||||
|
|
||||||
|
# Speaker identification: run when speech detected
|
||||||
|
if speaker_recognizer and result["category"] == "speech" and audio_f32 is not None:
|
||||||
|
try:
|
||||||
|
name, confidence = speaker_recognizer.identify(audio_f32)
|
||||||
|
state.recognized_speaker = name
|
||||||
|
state.speaker_confidence = confidence
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Speaker identification error: %s", e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Sound classification error: %s", e)
|
logger.warning("Sound classification error: %s", e)
|
||||||
|
|
||||||
@@ -372,7 +405,7 @@ app = FastAPI(title="HeadMic", description="Vixy's Ears 🦊👂")
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
global sound_classifier, sound_ring_buffer
|
global sound_classifier, sound_ring_buffer, speaker_recognizer
|
||||||
|
|
||||||
state.running = True
|
state.running = True
|
||||||
|
|
||||||
@@ -396,6 +429,16 @@ async def startup():
|
|||||||
else:
|
else:
|
||||||
logger.info("Sound classification models not found, skipping")
|
logger.info("Sound classification models not found, skipping")
|
||||||
|
|
||||||
|
# Init speaker recognizer (optional — graceful if resemblyzer not installed)
|
||||||
|
try:
|
||||||
|
from speaker_id import SpeakerRecognizer
|
||||||
|
db_path = Path(__file__).parent / "voices.db"
|
||||||
|
speaker_recognizer = SpeakerRecognizer(db_path=str(db_path))
|
||||||
|
state.speaker_recognition_enabled = True
|
||||||
|
logger.info("Speaker recognition enabled (Resemblyzer)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Speaker recognition unavailable: %s", e)
|
||||||
|
|
||||||
thread = threading.Thread(target=listener_loop, daemon=True)
|
thread = threading.Thread(target=listener_loop, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
logger.info("HeadMic started")
|
logger.info("HeadMic started")
|
||||||
@@ -425,6 +468,7 @@ async def health():
|
|||||||
"processing": state.processing,
|
"processing": state.processing,
|
||||||
"wake_count": state.wake_count,
|
"wake_count": state.wake_count,
|
||||||
"sound_classification_enabled": state.sound_classification_enabled,
|
"sound_classification_enabled": state.sound_classification_enabled,
|
||||||
|
"speaker_recognition_enabled": state.speaker_recognition_enabled,
|
||||||
"error": state.error
|
"error": state.error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -439,6 +483,7 @@ async def status():
|
|||||||
"last_wake_time": state.last_wake_time,
|
"last_wake_time": state.last_wake_time,
|
||||||
"wake_count": state.wake_count,
|
"wake_count": state.wake_count,
|
||||||
"audio_scene": state.audio_scene["dominant_category"] if state.audio_scene else None,
|
"audio_scene": state.audio_scene["dominant_category"] if state.audio_scene else None,
|
||||||
|
"recognized_speaker": state.recognized_speaker,
|
||||||
"error": state.error
|
"error": state.error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,8 +502,13 @@ async def sounds():
|
|||||||
if not state.sound_classification_enabled:
|
if not state.sound_classification_enabled:
|
||||||
raise HTTPException(status_code=503, detail="Sound classification not available")
|
raise HTTPException(status_code=503, detail="Sound classification not available")
|
||||||
if state.audio_scene is None:
|
if state.audio_scene is None:
|
||||||
return {"category": None, "top_classes": [], "dominant_category": None, "timestamp": None}
|
return {"category": None, "top_classes": [], "dominant_category": None, "timestamp": None,
|
||||||
return state.audio_scene
|
"recognized_speaker": None, "speaker_confidence": 0.0}
|
||||||
|
return {
|
||||||
|
**state.audio_scene,
|
||||||
|
"recognized_speaker": state.recognized_speaker,
|
||||||
|
"speaker_confidence": state.speaker_confidence,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sounds/history")
|
@app.get("/sounds/history")
|
||||||
@@ -471,6 +521,96 @@ async def sounds_history(seconds: int = 30):
|
|||||||
return {"history": sound_classifier.get_history(seconds)}
|
return {"history": sound_classifier.get_history(seconds)}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Speaker Endpoints
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
@app.post("/speakers/enroll")
async def enroll_speaker(name: str = Form(...), audio: UploadFile = File(...)):
    """Enroll a speaker from an uploaded audio file.

    Accepts either a WAV container or raw little-endian int16 PCM. In both
    cases the samples are treated as mono 16 kHz 16-bit — a stereo or
    resampled upload would be mis-decoded (TODO confirm against callers).

    Raises:
        HTTPException: 503 if speaker recognition is unavailable,
            400 if the recognizer rejects the audio (e.g. too short).
    """
    # Local import: the module header does not import `io`, so keep the
    # dependency inside the endpoint to stay self-contained.
    import io
    import wave as _wave

    if speaker_recognizer is None:
        raise HTTPException(status_code=503, detail="Speaker recognition not available")

    audio_bytes = await audio.read()
    # Convert to float32: try WAV container first, fall back to raw int16.
    try:
        with _wave.open(io.BytesIO(audio_bytes), 'rb') as wf:
            raw = wf.readframes(wf.getnframes())
        audio_f32 = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    except Exception:
        # Assume raw int16 PCM at 16kHz
        audio_f32 = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0

    try:
        speaker_recognizer.enroll(name, audio_f32, source="upload")
        return {"enrolled": name, "speakers": speaker_recognizer.list_speakers()}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
||||||
|
|
||||||
@app.post("/speakers/enroll-from-mic")
async def enroll_from_mic(name: str):
    """Record ~5 seconds from the live mic and enroll the speaker.

    Publishes `enrollment_buffer` so listener_loop feeds it raw frames,
    waits 5 seconds without blocking the event loop, then embeds and
    stores the captured audio.

    Raises:
        HTTPException: 503 if recognition is unavailable, 409 if an
            enrollment is already running, 500 if no audio was captured,
            400 if the recognizer rejects the audio.
    """
    # NOTE: dropped the unused 'enrollment_event' from the global statement —
    # it is never defined or referenced anywhere in the diff.
    global enrollment_buffer, enrollment_name

    if speaker_recognizer is None:
        raise HTTPException(status_code=503, detail="Speaker recognition not available")
    if state.enrolling:
        raise HTTPException(status_code=409, detail="Enrollment already in progress")

    state.enrolling = True
    enrollment_buffer = []
    enrollment_name = name

    leds_enrolling()
    logger.info("Enrollment started for '%s' — recording 5 seconds", name)

    try:
        # Wait 5 seconds for audio, non-blocking to the event loop.
        await asyncio.sleep(5.0)
        frames = enrollment_buffer
    finally:
        # Always reset shared state — even if the request is cancelled
        # mid-sleep — so a failed run cannot wedge future enrollments
        # or leave the LEDs spinning.
        enrollment_buffer = None
        enrollment_name = None
        state.enrolling = False
        leds_off()

    if not frames:
        raise HTTPException(status_code=500, detail="No audio captured")

    audio_int16 = np.frombuffer(b"".join(frames), dtype=np.int16)
    audio_f32 = audio_int16.astype(np.float32) / 32768.0
    logger.info("Enrollment audio: %.1f seconds", len(audio_f32) / SAMPLE_RATE)

    try:
        speaker_recognizer.enroll(name, audio_f32, source="mic")
        return {"enrolled": name, "seconds": round(len(audio_f32) / SAMPLE_RATE, 1),
                "speakers": speaker_recognizer.list_speakers()}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
||||||
|
|
||||||
@app.get("/speakers")
async def list_speakers():
    """Return enrolled speaker names with their sample counts."""
    if speaker_recognizer is not None:
        return {"speakers": speaker_recognizer.list_speakers()}
    raise HTTPException(status_code=503, detail="Speaker recognition not available")
||||||
|
|
||||||
@app.delete("/speakers/{name}")
async def delete_speaker(name: str):
    """Remove a speaker and all of their enrolled voice samples."""
    if speaker_recognizer is None:
        raise HTTPException(status_code=503, detail="Speaker recognition not available")
    sample_count = speaker_recognizer.delete_speaker(name)
    if not sample_count:
        raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")
    return {"deleted": name, "samples_removed": sample_count}
||||||
|
|
||||||
if __name__ == "__main__":
    # Standalone entry point: serve the FastAPI app on all interfaces.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8446)
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ class SoundClassifier:
|
|||||||
"top_classes": top_classes,
|
"top_classes": top_classes,
|
||||||
"dominant_category": dominant,
|
"dominant_category": dominant,
|
||||||
"timestamp": now,
|
"timestamp": now,
|
||||||
|
"audio_float32": audio_f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_history(self, seconds=30):
|
def get_history(self, seconds=30):
|
||||||
|
|||||||
133
speaker_id.py
Normal file
133
speaker_id.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
Speaker Identification Module for HeadMic
|
||||||
|
Resemblyzer GE2E speaker encoder — 256-dim embeddings, cosine similarity matching.
|
||||||
|
Triggered when YAMNet detects speech.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger("speaker_id")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
SIMILARITY_THRESHOLD = 0.75
|
||||||
|
|
||||||
|
|
||||||
class SpeakerRecognizer:
    """Voice-based speaker identification backed by SQLite.

    Uses the Resemblyzer GE2E encoder (256-dim embeddings). Each enrollment
    stores one embedding row in the `voices` table; identification compares
    a query embedding against every cached embedding via np.dot (GE2E
    embeddings are L2-normalized, so the dot product equals cosine
    similarity — TODO confirm for the installed resemblyzer version) and
    returns the best speaker at or above SIMILARITY_THRESHOLD.

    NOTE(review): the in-memory cache is mutated by enroll()/delete_speaker()
    while identify() may run on the classifier thread — assumed benign under
    CPython's GIL for these dict operations; confirm if access patterns change.
    """

    def __init__(self, db_path="voices.db"):
        """Load the voice encoder and warm the in-memory embedding cache.

        Args:
            db_path: Path to the SQLite database file (created if missing).

        Raises:
            ImportError: if resemblyzer is not installed.
        """
        from resemblyzer import VoiceEncoder

        self._encoder = VoiceEncoder("cpu")
        logger.info("Resemblyzer voice encoder loaded")

        self._db_path = str(db_path)
        self._init_db()
        # name -> list of float32 embeddings; kept in sync with the DB.
        self._cache = self._load_embeddings()
        logger.info(
            "Speaker DB ready: %d embeddings for %d speakers",
            sum(len(v) for v in self._cache.values()),
            len(self._cache),
        )

    def _init_db(self):
        """Create the voices table and name index if they do not exist."""
        with sqlite3.connect(self._db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS voices (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    embedding BLOB NOT NULL,
                    enrolled_at REAL NOT NULL,
                    source TEXT
                )
            """)
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_voices_name ON voices(name)"
            )

    def _load_embeddings(self):
        """Load all embeddings from DB into memory, grouped by name."""
        cache = {}
        with sqlite3.connect(self._db_path) as conn:
            rows = conn.execute("SELECT name, embedding FROM voices").fetchall()
        for name, blob in rows:
            # .copy() detaches from the read-only buffer sqlite hands back.
            emb = np.frombuffer(blob, dtype=np.float32).copy()
            cache.setdefault(name, []).append(emb)
        return cache

    def identify(self, audio_float32):
        """Identify the speaker from float32 mono audio at 16kHz.

        Returns:
            (name, confidence) — confidence rounded to 3 decimals — or
            (None, 0.0) if no enrolled match scores above the threshold.
        """
        if not self._cache:
            return None, 0.0

        try:
            from resemblyzer import preprocess_wav
            wav = preprocess_wav(audio_float32, source_sr=16000)
            if len(wav) < 1600:  # < 0.1 s after trimming — too short to embed
                return None, 0.0
            embedding = self._encoder.embed_utterance(wav)
        except Exception as e:
            logger.warning("Embedding computation failed: %s", e)
            return None, 0.0

        best_name = None
        best_score = 0.0
        for name, embeddings in self._cache.items():
            # Best score across all enrolled samples for this speaker.
            top = max(float(np.dot(embedding, emb)) for emb in embeddings)
            if top > best_score:
                best_score = top
                best_name = name

        if best_score >= SIMILARITY_THRESHOLD:
            return best_name, round(float(best_score), 3)
        return None, 0.0

    def enroll(self, name, audio_float32, source="api"):
        """Enroll a speaker from float32 mono audio at 16kHz.

        Args:
            name: Speaker name (multiple samples per name are allowed).
            audio_float32: Audio samples in [-1, 1].
            source: Free-text provenance tag stored alongside the embedding.

        Raises:
            ValueError: if the audio is too short after preprocessing.

        Returns:
            The computed embedding (256-dim).
        """
        from resemblyzer import preprocess_wav

        wav = preprocess_wav(audio_float32, source_sr=16000)
        if len(wav) < 1600:
            raise ValueError("Audio too short for enrollment")

        embedding = self._encoder.embed_utterance(wav)
        blob = embedding.astype(np.float32).tobytes()
        now = time.time()

        with sqlite3.connect(self._db_path) as conn:
            conn.execute(
                "INSERT INTO voices (name, embedding, enrolled_at, source) VALUES (?, ?, ?, ?)",
                (name, blob, now, source),
            )

        self._cache.setdefault(name, []).append(embedding)
        logger.info("Enrolled speaker '%s' (source=%s, total=%d samples)",
                    name, source, len(self._cache[name]))
        return embedding

    def list_speakers(self):
        """Return enrolled speaker names with sample counts."""
        return {name: len(embs) for name, embs in self._cache.items()}

    def delete_speaker(self, name):
        """Remove all embeddings for a speaker (DB and cache).

        Returns:
            Number of samples removed; 0 if the speaker was unknown.
        """
        with sqlite3.connect(self._db_path) as conn:
            cur = conn.execute("DELETE FROM voices WHERE name = ?", (name,))
            # Fix: previously the return value came only from the cache, so a
            # speaker present in the DB but missing from the cache was deleted
            # yet reported as 0 (causing a spurious 404 at the endpoint).
            db_removed = cur.rowcount
        cached = self._cache.pop(name, None)
        removed = len(cached) if cached is not None else max(db_removed, 0)
        if removed:
            logger.info("Deleted speaker '%s' (%d samples)", name, removed)
        return removed
||||||
Reference in New Issue
Block a user