Adds voice-based speaker ID triggered by YAMNet speech detection.
New speaker_id.py module with SQLite-backed voice enrollment and
cosine similarity matching. Endpoints: POST /speakers/enroll,
POST /speakers/enroll-from-mic, GET /speakers, DELETE /speakers/{name}.
Orange LED animation during enrollment.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
186 lines
7.1 KiB
Python
"""
|
|
Sound Identification Module for HeadMic
|
|
YAMNet audio event classifier — CPU TFLite (Edge TPU ready)
|
|
"""
|
|
|
|
import collections
|
|
import csv
|
|
import logging
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
logger = logging.getLogger("sound_id")
|
|
logger.setLevel(logging.INFO)
|
|
|
|
# YAMNet expects 0.975s of 16kHz audio = 15600 samples
|
|
YAMNET_SAMPLE_RATE = 16000
|
|
YAMNET_INPUT_SAMPLES = 15600
|
|
|
|
# Category mapping: group YAMNet's 521 classes into useful buckets.
# Keys are the coarse categories reported by SoundClassifier; values are
# the exact YAMNet class display names that fold into each bucket.
# NOTE(review): "Doorbell" appears in both "alert" and "household"; the
# reverse lookup built from this dict overwrites earlier entries, so the
# later group ("household") wins for duplicated names — confirm intended.
CATEGORY_GROUPS = {
    # Human vocal activity (used to trigger speaker ID).
    "speech": [
        "Speech", "Child speech, kid speaking", "Conversation",
        "Narration, monologue", "Babbling", "Speech synthesizer",
        "Shout", "Yell", "Whispering", "Laughter", "Chatter",
    ],
    # Attention-demanding sounds: alarms, bells, phones.
    "alert": [
        "Doorbell", "Ding-dong", "Knock", "Tap", "Alarm",
        "Alarm clock", "Smoke detector, smoke alarm", "Fire alarm",
        "Siren", "Buzzer", "Telephone", "Telephone bell ringing",
        "Ringtone", "Telephone dialing, DTMF", "Bell", "Church bell",
    ],
    # Instruments and genres, all folded into one bucket.
    "music": [
        "Music", "Musical instrument", "Singing", "Song", "Guitar",
        "Electric guitar", "Acoustic guitar", "Bass guitar", "Piano",
        "Keyboard (musical)", "Drum", "Drum kit", "Snare drum",
        "Bass drum", "Cymbal", "Hi-hat", "Violin, fiddle", "Flute",
        "Trumpet", "Saxophone", "Harmonica", "Organ", "Synthesizer",
        "Plucked string instrument", "Strum", "Hip hop music",
        "Pop music", "Rock music", "Heavy metal", "Punk rock",
        "Grunge", "Progressive rock", "Rock and roll", "Jazz",
        "Blues", "Soul music", "Reggae", "Country", "Swing music",
        "Bluegrass", "Funk", "Folk music", "Middle Eastern music",
        "Electronic music", "House music", "Techno", "Dubstep",
        "Drum and bass", "Electronica", "Electronic dance music",
        "Ambient music", "Trance music", "Music of Latin America",
        "Salsa music", "Flamenco", "Gospel music", "Christian music",
        "Music of Africa", "Afrobeat", "Music of Asia",
    ],
    # Pets, birds, insects, and other animal vocalizations.
    "animal": [
        "Dog", "Bark", "Howl", "Bow-wow", "Growling", "Whimper",
        "Cat", "Purr", "Meow", "Hiss", "Caterwaul",
        "Bird", "Bird vocalization, bird call, bird song", "Chirp, tweet",
        "Squawk", "Crow", "Caw", "Owl", "Pigeon, dove",
        "Insect", "Cricket", "Mosquito", "Fly, housefly", "Bee, wasp, etc.",
        "Frog", "Rooster", "Chicken, hen", "Duck", "Goose",
    ],
    # Indoor domestic activity: doors, kitchen, appliances, movement.
    "household": [
        "Door", "Doorbell", "Sliding door", "Slam",
        "Drawer open or close", "Cupboard open or close",
        "Cutlery, silverware", "Dishes, pots, and pans",
        "Glass", "Chink, clink", "Water tap, faucet",
        "Sink (filling or washing)", "Bathtub (filling or washing)",
        "Toilet flush", "Vacuum cleaner", "Blender",
        "Microwave oven", "Typing", "Computer keyboard",
        "Writing", "Keys jangling", "Coin (dropping)",
        "Footsteps", "Walk, footsteps", "Run",
        "Clapping", "Finger snapping",
    ],
    # Outdoor/ambient sounds: weather, water, traffic, room tone.
    "environment": [
        "Wind", "Rustling leaves", "Wind noise (microphone)",
        "Rain", "Raindrop", "Rain on surface", "Thunder",
        "Thunderstorm", "Stream", "Waterfall", "Ocean",
        "Waves, surf", "Traffic noise, roadway noise",
        "Car", "Engine", "Idling", "Truck", "Bus", "Motorcycle",
        "Aircraft", "Aircraft engine", "Helicopter",
        "Train", "Railroad car, train wagon",
        "Inside, small room", "Inside, large room or hall",
        "Outside, urban or manmade", "Outside, rural or natural",
    ],
    # Near-silence / steady noise floor.
    "silence": [
        "Silence", "White noise", "Static",
    ],
}
class SoundClassifier:
|
|
def __init__(self, model_path, class_map_path, use_edgetpu=False):
|
|
# Load class names
|
|
self._class_names = []
|
|
with open(class_map_path) as f:
|
|
reader = csv.reader(f)
|
|
next(reader) # skip header
|
|
for row in reader:
|
|
self._class_names.append(row[2])
|
|
logger.info("Loaded %d YAMNet class names", len(self._class_names))
|
|
|
|
# Build reverse lookup: class_name -> category
|
|
self._class_to_category = {}
|
|
for category, names in CATEGORY_GROUPS.items():
|
|
for name in names:
|
|
self._class_to_category[name] = category
|
|
|
|
# Load TFLite model
|
|
import ai_edge_litert.interpreter as tfl
|
|
|
|
if use_edgetpu:
|
|
delegate = tfl.load_delegate("libedgetpu.so.1")
|
|
self._interp = tfl.Interpreter(
|
|
model_path=str(model_path),
|
|
experimental_delegates=[delegate],
|
|
)
|
|
logger.info("YAMNet loaded on Edge TPU")
|
|
else:
|
|
self._interp = tfl.Interpreter(model_path=str(model_path))
|
|
logger.info("YAMNet loaded on CPU")
|
|
|
|
self._interp.allocate_tensors()
|
|
self._input = self._interp.get_input_details()[0]
|
|
self._output = self._interp.get_output_details()[0]
|
|
logger.info(
|
|
"YAMNet ready: input %s %s, output %s",
|
|
self._input["shape"], self._input["dtype"], self._output["shape"],
|
|
)
|
|
|
|
# Classification history for smoothing
|
|
self._history = collections.deque(maxlen=20)
|
|
|
|
def classify(self, audio_int16):
|
|
"""Classify an audio buffer.
|
|
|
|
Args:
|
|
audio_int16: numpy int16 array of PCM samples at 16kHz
|
|
|
|
Returns:
|
|
dict with category, top_classes, dominant_category, timestamp
|
|
"""
|
|
# Convert int16 PCM to float32 [-1.0, 1.0]
|
|
audio_f32 = audio_int16.astype(np.float32) / 32768.0
|
|
|
|
# Pad or trim to exact input size
|
|
if len(audio_f32) < YAMNET_INPUT_SAMPLES:
|
|
audio_f32 = np.pad(audio_f32, (0, YAMNET_INPUT_SAMPLES - len(audio_f32)))
|
|
elif len(audio_f32) > YAMNET_INPUT_SAMPLES:
|
|
audio_f32 = audio_f32[-YAMNET_INPUT_SAMPLES:]
|
|
|
|
self._interp.set_tensor(self._input["index"], audio_f32)
|
|
self._interp.invoke()
|
|
|
|
scores = self._interp.get_tensor(self._output["index"])[0]
|
|
|
|
# Top 5 classes
|
|
top_indices = np.argsort(scores)[-5:][::-1]
|
|
top_classes = []
|
|
for idx in top_indices:
|
|
name = self._class_names[idx] if idx < len(self._class_names) else "Unknown"
|
|
top_classes.append({"name": name, "score": round(float(scores[idx]), 3)})
|
|
|
|
# Map top class to category
|
|
top_name = top_classes[0]["name"]
|
|
category = self._class_to_category.get(top_name, "other")
|
|
|
|
# Update history
|
|
now = time.time()
|
|
self._history.append({"category": category, "timestamp": now})
|
|
|
|
# Dominant category from recent history
|
|
if self._history:
|
|
cat_counts = collections.Counter(h["category"] for h in self._history)
|
|
dominant = cat_counts.most_common(1)[0][0]
|
|
else:
|
|
dominant = category
|
|
|
|
return {
|
|
"category": category,
|
|
"top_classes": top_classes,
|
|
"dominant_category": dominant,
|
|
"timestamp": now,
|
|
"audio_float32": audio_f32,
|
|
}
|
|
|
|
def get_history(self, seconds=30):
|
|
"""Return recent classification history."""
|
|
cutoff = time.time() - seconds
|
|
return [h for h in self._history if h["timestamp"] >= cutoff]
|