Files
headmic/sound_id.py
Alex 5e3c16659f Add YAMNet sound classification to headmic
New sound_id.py module with SoundClassifier class that runs YAMNet
(521 audio event categories) on CPU TFLite. Classifies audio every
0.5s from a ring buffer fed by the existing audio stream.

Categories: speech, alert, music, animal, household, environment, silence
(top classes outside these buckets fall through to an "other" category).
Smoothing via 20-sample history window for stable dominant category.

New endpoints: GET /sounds, GET /sounds/history
Updated: /health (sound_classification_enabled), /status (audio_scene)
Graceful degradation if model files not present.

Model download (not tracked in git):
  curl -sL 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite' -o models/yamnet.tflite

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-01 20:41:44 -06:00

183 lines
6.9 KiB
Python

"""
Sound Identification Module for HeadMic
YAMNet audio event classifier — CPU TFLite (Edge TPU ready)
"""
import collections
import csv
import logging
import time
import numpy as np
logger = logging.getLogger("sound_id")
logger.setLevel(logging.INFO)
# YAMNet expects 0.975s of 16kHz audio = 15600 samples
YAMNET_SAMPLE_RATE = 16000
YAMNET_INPUT_SAMPLES = 15600

# Category mapping: group YAMNet's 521 classes into useful buckets
# NOTE: each entry must match the display_name column of the YAMNet
# class-map CSV exactly — an unmatched name silently falls through to
# the "other" category in SoundClassifier.classify().
CATEGORY_GROUPS = {
    # Human voice and vocal sounds
    "speech": [
        "Speech", "Child speech, kid speaking", "Conversation",
        "Narration, monologue", "Babbling", "Speech synthesizer",
        "Shout", "Yell", "Whispering", "Laughter", "Chatter",
    ],
    # Attention-demanding sounds (doorbells, alarms, phones)
    "alert": [
        "Doorbell", "Ding-dong", "Knock", "Tap", "Alarm",
        "Alarm clock", "Smoke detector, smoke alarm", "Fire alarm",
        "Siren", "Buzzer", "Telephone", "Telephone bell ringing",
        "Ringtone", "Telephone dialing, DTMF", "Bell", "Church bell",
    ],
    # Music, instruments, and genres
    "music": [
        "Music", "Musical instrument", "Singing", "Song", "Guitar",
        "Electric guitar", "Acoustic guitar", "Bass guitar", "Piano",
        "Keyboard (musical)", "Drum", "Drum kit", "Snare drum",
        "Bass drum", "Cymbal", "Hi-hat", "Violin, fiddle", "Flute",
        "Trumpet", "Saxophone", "Harmonica", "Organ", "Synthesizer",
        "Plucked string instrument", "Strum", "Hip hop music",
        "Pop music", "Rock music", "Heavy metal", "Punk rock",
        "Grunge", "Progressive rock", "Rock and roll", "Jazz",
        "Blues", "Soul music", "Reggae", "Country", "Swing music",
        "Bluegrass", "Funk", "Folk music", "Middle Eastern music",
        "Electronic music", "House music", "Techno", "Dubstep",
        "Drum and bass", "Electronica", "Electronic dance music",
        "Ambient music", "Trance music", "Music of Latin America",
        "Salsa music", "Flamenco", "Gospel music", "Christian music",
        "Music of Africa", "Afrobeat", "Music of Asia",
    ],
    # Pets, birds, insects, and other animals
    "animal": [
        "Dog", "Bark", "Howl", "Bow-wow", "Growling", "Whimper",
        "Cat", "Purr", "Meow", "Hiss", "Caterwaul",
        "Bird", "Bird vocalization, bird call, bird song", "Chirp, tweet",
        "Squawk", "Crow", "Caw", "Owl", "Pigeon, dove",
        "Insect", "Cricket", "Mosquito", "Fly, housefly", "Bee, wasp, etc.",
        "Frog", "Rooster", "Chicken, hen", "Duck", "Goose",
    ],
    # Indoor domestic activity
    "household": [
        "Door", "Doorbell", "Sliding door", "Slam",
        "Drawer open or close", "Cupboard open or close",
        "Cutlery, silverware", "Dishes, pots, and pans",
        "Glass", "Chink, clink", "Water tap, faucet",
        "Sink (filling or washing)", "Bathtub (filling or washing)",
        "Toilet flush", "Vacuum cleaner", "Blender",
        "Microwave oven", "Typing", "Computer keyboard",
        "Writing", "Keys jangling", "Coin (dropping)",
        "Footsteps", "Walk, footsteps", "Run",
        "Clapping", "Finger snapping",
    ],
    # Weather, water, and vehicle/traffic noise
    "environment": [
        "Wind", "Rustling leaves", "Wind noise (microphone)",
        "Rain", "Raindrop", "Rain on surface", "Thunder",
        "Thunderstorm", "Stream", "Waterfall", "Ocean",
        "Waves, surf", "Traffic noise, roadway noise",
        "Car", "Engine", "Idling", "Truck", "Bus", "Motorcycle",
        "Aircraft", "Aircraft engine", "Helicopter",
        "Train", "Railroad car, train wagon",
    ],
    # Quiet / featureless audio
    "silence": [
        "Silence", "White noise", "Static",
    ],
}
class SoundClassifier:
    """Runs the YAMNet TFLite audio-event classifier over PCM buffers.

    Maps YAMNet's 521 raw classes into the coarse buckets defined in
    CATEGORY_GROUPS and smooths results over a short history window so
    the reported dominant category is stable between calls.
    """

    def __init__(self, model_path, class_map_path, use_edgetpu=False):
        """Load the class map and the TFLite model.

        Args:
            model_path: path to the YAMNet .tflite model file.
            class_map_path: path to the YAMNet class-map CSV
                (header row, then index/mid/display_name per class).
            use_edgetpu: if True, attach the Edge TPU delegate
                ("libedgetpu.so.1") instead of running on CPU.
        """
        # Load class display names (third CSV column). Skip the header
        # and any malformed/short rows so a stray blank line in the CSV
        # can't crash startup with an IndexError.
        self._class_names = []
        with open(class_map_path) as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                if len(row) >= 3:
                    self._class_names.append(row[2])
        logger.info("Loaded %d YAMNet class names", len(self._class_names))

        # Build reverse lookup: YAMNet display name -> coarse category.
        self._class_to_category = {}
        for category, names in CATEGORY_GROUPS.items():
            for name in names:
                self._class_to_category[name] = category

        # Imported lazily so the module stays importable when the TFLite
        # runtime isn't installed (graceful degradation at the call site).
        import ai_edge_litert.interpreter as tfl
        if use_edgetpu:
            delegate = tfl.load_delegate("libedgetpu.so.1")
            self._interp = tfl.Interpreter(
                model_path=str(model_path),
                experimental_delegates=[delegate],
            )
            logger.info("YAMNet loaded on Edge TPU")
        else:
            self._interp = tfl.Interpreter(model_path=str(model_path))
            logger.info("YAMNet loaded on CPU")
        self._interp.allocate_tensors()
        self._input = self._interp.get_input_details()[0]
        self._output = self._interp.get_output_details()[0]
        logger.info(
            "YAMNet ready: input %s %s, output %s",
            self._input["shape"], self._input["dtype"], self._output["shape"],
        )

        # Recent per-call categories; bounded so smoothing tracks the
        # current scene instead of the whole session.
        self._history = collections.deque(maxlen=20)

    def classify(self, audio_int16):
        """Classify one audio buffer.

        Args:
            audio_int16: array-like of int16 PCM samples at 16 kHz.
                Any shape is accepted (e.g. (n, 1) capture buffers);
                it is flattened before inference.

        Returns:
            dict with keys "category" (bucket of the top class, or
            "other" if unmapped), "top_classes" (top-5 list of
            {"name", "score"}), "dominant_category" (most common
            bucket over the recent history window), and "timestamp".
        """
        # int16 PCM -> float32 in [-1.0, 1.0). Flatten first so
        # channel-shaped buffers like (n, 1) also work.
        audio_f32 = np.asarray(audio_int16).reshape(-1).astype(np.float32) / 32768.0

        # Pad with silence, or keep the most recent samples, to match
        # the model's fixed 15600-sample (0.975 s) input window.
        if len(audio_f32) < YAMNET_INPUT_SAMPLES:
            audio_f32 = np.pad(audio_f32, (0, YAMNET_INPUT_SAMPLES - len(audio_f32)))
        elif len(audio_f32) > YAMNET_INPUT_SAMPLES:
            audio_f32 = audio_f32[-YAMNET_INPUT_SAMPLES:]

        self._interp.set_tensor(self._input["index"], audio_f32)
        self._interp.invoke()
        scores = self._interp.get_tensor(self._output["index"])[0]

        # Top 5 classes by score, highest first.
        top_indices = np.argsort(scores)[-5:][::-1]
        top_classes = []
        for idx in top_indices:
            name = self._class_names[idx] if idx < len(self._class_names) else "Unknown"
            top_classes.append({"name": name, "score": round(float(scores[idx]), 3)})

        # Bucket the single top class; anything unmapped is "other".
        category = self._class_to_category.get(top_classes[0]["name"], "other")

        now = time.time()
        self._history.append({"category": category, "timestamp": now})

        # Dominant category = most frequent bucket in the history window.
        # The window is never empty here — we just appended to it.
        cat_counts = collections.Counter(h["category"] for h in self._history)
        dominant = cat_counts.most_common(1)[0][0]

        return {
            "category": category,
            "top_classes": top_classes,
            "dominant_category": dominant,
            "timestamp": now,
        }

    def get_history(self, seconds=30):
        """Return history entries newer than `seconds` ago, oldest first."""
        cutoff = time.time() - seconds
        return [h for h in self._history if h["timestamp"] >= cutoff]