Probes the Edge TPU in a subprocess before loading — catches segfaults (libedgetpu ABI mismatch on Debian Trixie/Python 3.13) and falls back to CPU automatically. No more service crashes on Coral incompatibility. When the runtime is eventually fixed, Edge TPU will be used automatically with no config changes needed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
216 lines · 8.4 KiB · Python
"""
|
|
Sound Identification Module for HeadMic
|
|
YAMNet audio event classifier — CPU TFLite (Edge TPU ready)
|
|
"""
|
|
|
|
import collections
|
|
import csv
|
|
import logging
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
# Module-level logger for this component.
# NOTE(review): calling setLevel() in a library module overrides whatever the
# application's logging configuration chose — consider leaving level control
# to the root config; confirm nothing depends on INFO being forced here.
logger = logging.getLogger("sound_id")
logger.setLevel(logging.INFO)
|
|
|
|
# YAMNet expects 0.975s of 16kHz audio = 15600 samples
# (0.975 s x 16000 Hz = 15600 samples per inference window).
YAMNET_SAMPLE_RATE = 16000
YAMNET_INPUT_SAMPLES = 15600
|
|
|
|
# Category mapping: group YAMNet's 521 classes into useful buckets.
# Keys are our coarse category names; values are the exact YAMNet class
# display names (as they appear in the class-map CSV) belonging to each bucket.
# NOTE(review): "Doorbell" appears in both "alert" and "household"; the
# reverse lookup built in SoundClassifier.__init__ keeps the LAST occurrence,
# so "Doorbell" resolves to "household" — confirm that is intended.
CATEGORY_GROUPS = {
    # Human vocal activity.
    "speech": [
        "Speech", "Child speech, kid speaking", "Conversation",
        "Narration, monologue", "Babbling", "Speech synthesizer",
        "Shout", "Yell", "Whispering", "Laughter", "Chatter",
    ],
    # Attention-demanding sounds: doorbells, alarms, phones, sirens.
    "alert": [
        "Doorbell", "Ding-dong", "Knock", "Tap", "Alarm",
        "Alarm clock", "Smoke detector, smoke alarm", "Fire alarm",
        "Siren", "Buzzer", "Telephone", "Telephone bell ringing",
        "Ringtone", "Telephone dialing, DTMF", "Bell", "Church bell",
    ],
    # Instruments and genres.
    "music": [
        "Music", "Musical instrument", "Singing", "Song", "Guitar",
        "Electric guitar", "Acoustic guitar", "Bass guitar", "Piano",
        "Keyboard (musical)", "Drum", "Drum kit", "Snare drum",
        "Bass drum", "Cymbal", "Hi-hat", "Violin, fiddle", "Flute",
        "Trumpet", "Saxophone", "Harmonica", "Organ", "Synthesizer",
        "Plucked string instrument", "Strum", "Hip hop music",
        "Pop music", "Rock music", "Heavy metal", "Punk rock",
        "Grunge", "Progressive rock", "Rock and roll", "Jazz",
        "Blues", "Soul music", "Reggae", "Country", "Swing music",
        "Bluegrass", "Funk", "Folk music", "Middle Eastern music",
        "Electronic music", "House music", "Techno", "Dubstep",
        "Drum and bass", "Electronica", "Electronic dance music",
        "Ambient music", "Trance music", "Music of Latin America",
        "Salsa music", "Flamenco", "Gospel music", "Christian music",
        "Music of Africa", "Afrobeat", "Music of Asia",
    ],
    # Pets, birds, insects, and other animal vocalizations.
    "animal": [
        "Dog", "Bark", "Howl", "Bow-wow", "Growling", "Whimper",
        "Cat", "Purr", "Meow", "Hiss", "Caterwaul",
        "Bird", "Bird vocalization, bird call, bird song", "Chirp, tweet",
        "Squawk", "Crow", "Caw", "Owl", "Pigeon, dove",
        "Insect", "Cricket", "Mosquito", "Fly, housefly", "Bee, wasp, etc.",
        "Frog", "Rooster", "Chicken, hen", "Duck", "Goose",
    ],
    # Indoor domestic activity.
    "household": [
        "Door", "Doorbell", "Sliding door", "Slam",
        "Drawer open or close", "Cupboard open or close",
        "Cutlery, silverware", "Dishes, pots, and pans",
        "Glass", "Chink, clink", "Water tap, faucet",
        "Sink (filling or washing)", "Bathtub (filling or washing)",
        "Toilet flush", "Vacuum cleaner", "Blender",
        "Microwave oven", "Typing", "Computer keyboard",
        "Writing", "Keys jangling", "Coin (dropping)",
        "Footsteps", "Walk, footsteps", "Run",
        "Clapping", "Finger snapping",
    ],
    # Weather, water, traffic, vehicles, and room/outdoor ambience.
    "environment": [
        "Wind", "Rustling leaves", "Wind noise (microphone)",
        "Rain", "Raindrop", "Rain on surface", "Thunder",
        "Thunderstorm", "Stream", "Waterfall", "Ocean",
        "Waves, surf", "Traffic noise, roadway noise",
        "Car", "Engine", "Idling", "Truck", "Bus", "Motorcycle",
        "Aircraft", "Aircraft engine", "Helicopter",
        "Train", "Railroad car, train wagon",
        "Inside, small room", "Inside, large room or hall",
        "Outside, urban or manmade", "Outside, rural or natural",
    ],
    # Absence of salient events / broadband noise.
    "silence": [
        "Silence", "White noise", "Static",
    ],
}
|
|
|
|
|
|
class SoundClassifier:
    """YAMNet audio-event classifier (CPU TFLite, Edge TPU when usable).

    The Edge TPU delegate is probed in a *subprocess* first, so a segfaulting
    libedgetpu (e.g. ABI mismatch) kills only the throwaway child process;
    this process then falls back to the CPU model automatically.
    """

    @staticmethod
    def _probe_edgetpu(model_path) -> bool:
        """Test Edge TPU in a subprocess to catch segfaults safely.

        Args:
            model_path: path to the ``*_edgetpu.tflite`` model (str or Path).

        Returns:
            True iff the delegate loaded and tensors allocated in the child.
        """
        import subprocess
        import sys

        # The model path is passed via argv rather than interpolated into the
        # code string — a path containing quotes or backslashes would
        # otherwise produce a broken (or injectable) probe script.
        probe_code = (
            "import sys; "
            "import ai_edge_litert.interpreter as tfl; "
            "d = tfl.load_delegate('libedgetpu.so.1'); "
            "i = tfl.Interpreter(model_path=sys.argv[1], "
            "experimental_delegates=[d]); "
            "i.allocate_tensors(); "
            "print('ok')"
        )
        try:
            result = subprocess.run(
                [sys.executable, "-c", probe_code, str(model_path)],
                capture_output=True,
                text=True,
                timeout=10,  # a hung delegate counts as a failed probe
            )
        except Exception as e:
            # Timeout, spawn failure, etc. — treat as "no usable Edge TPU".
            logger.warning("Edge TPU probe error: %s", e)
            return False

        if result.returncode == 0 and "ok" in result.stdout:
            logger.info("Edge TPU probe: OK")
            return True
        # A segfault typically surfaces as a negative return code with empty
        # stderr, hence the fallback to printing the exit status.
        logger.warning(
            "Edge TPU probe failed: %s",
            result.stderr.strip() or f"exit {result.returncode}",
        )
        return False

    def __init__(self, model_path, class_map_path, use_edgetpu=False):
        """Load the class map and the TFLite model.

        Args:
            model_path: path to the YAMNet ``.tflite`` model (the
                ``_edgetpu`` variant when ``use_edgetpu`` is requested).
            class_map_path: CSV whose third column holds each class's
                display name, with one header row.
            use_edgetpu: attempt Edge TPU acceleration; falls back to the
                CPU model if the probe or the delegate load fails.
        """
        # --- Class names --------------------------------------------------
        self._class_names = []
        with open(class_map_path, encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; tolerate an empty file
            for row in reader:
                self._class_names.append(row[2])  # third column = display name
        logger.info("Loaded %d YAMNet class names", len(self._class_names))

        # Reverse lookup: class display name -> category bucket.
        # (Dict comprehension keeps the last occurrence on duplicates,
        # exactly like the equivalent nested loop would.)
        self._class_to_category = {
            name: category
            for category, names in CATEGORY_GROUPS.items()
            for name in names
        }

        # --- TFLite model -------------------------------------------------
        # Imported lazily so merely importing this module never touches the
        # TFLite runtime.
        import ai_edge_litert.interpreter as tfl

        if use_edgetpu and self._probe_edgetpu(model_path):
            try:
                delegate = tfl.load_delegate("libedgetpu.so.1")
                self._interp = tfl.Interpreter(
                    model_path=str(model_path),
                    experimental_delegates=[delegate],
                )
                logger.info("YAMNet loaded on Edge TPU")
            except Exception as e:
                # The probe passed but the in-process load failed anyway
                # (e.g. device removed in between) — fall back to CPU.
                logger.warning(
                    "Edge TPU load failed (%s) — falling back to CPU", e
                )
                use_edgetpu = False
        elif use_edgetpu:
            logger.warning("Edge TPU probe failed (segfault or error) — falling back to CPU")
            use_edgetpu = False

        if not use_edgetpu:
            # Use CPU model (swap edgetpu model path for CPU model if needed)
            cpu_path = str(model_path).replace("_edgetpu.tflite", ".tflite")
            self._interp = tfl.Interpreter(model_path=cpu_path)
            logger.info("YAMNet loaded on CPU")

        self._interp.allocate_tensors()
        self._input = self._interp.get_input_details()[0]
        self._output = self._interp.get_output_details()[0]
        logger.info(
            "YAMNet ready: input %s %s, output %s",
            self._input["shape"], self._input["dtype"], self._output["shape"],
        )

        # Rolling window of recent classifications used for smoothing.
        self._history = collections.deque(maxlen=20)

    def classify(self, audio_int16):
        """Classify an audio buffer.

        Args:
            audio_int16: int16 PCM samples at 16kHz — a numpy array or any
                sequence convertible to one.

        Returns:
            dict with category, top_classes, dominant_category, timestamp,
            and audio_float32 (the float buffer actually fed to the model).
        """
        # Convert int16 PCM to float32 [-1.0, 1.0]; asarray also accepts
        # plain Python sequences.
        audio_f32 = np.asarray(audio_int16).astype(np.float32) / 32768.0

        # Pad with silence, or keep the most recent samples, to match the
        # model's fixed input length.
        if len(audio_f32) < YAMNET_INPUT_SAMPLES:
            audio_f32 = np.pad(audio_f32, (0, YAMNET_INPUT_SAMPLES - len(audio_f32)))
        elif len(audio_f32) > YAMNET_INPUT_SAMPLES:
            audio_f32 = audio_f32[-YAMNET_INPUT_SAMPLES:]

        self._interp.set_tensor(self._input["index"], audio_f32)
        self._interp.invoke()

        scores = self._interp.get_tensor(self._output["index"])[0]

        # Top 5 classes by score, best first.
        top_indices = np.argsort(scores)[-5:][::-1]
        top_classes = []
        for idx in top_indices:
            name = self._class_names[idx] if idx < len(self._class_names) else "Unknown"
            top_classes.append({"name": name, "score": round(float(scores[idx]), 3)})

        # Map the single best class into one of our category buckets.
        top_name = top_classes[0]["name"]
        category = self._class_to_category.get(top_name, "other")

        now = time.time()
        self._history.append({"category": category, "timestamp": now})

        # Majority vote over the recent window smooths per-frame jitter.
        if self._history:
            cat_counts = collections.Counter(h["category"] for h in self._history)
            dominant = cat_counts.most_common(1)[0][0]
        else:
            dominant = category

        return {
            "category": category,
            "top_classes": top_classes,
            "dominant_category": dominant,
            "timestamp": now,
            "audio_float32": audio_f32,
        }

    def get_history(self, seconds=30):
        """Return recent classification history (entries from the last ``seconds`` seconds)."""
        cutoff = time.time() - seconds
        return [h for h in self._history if h["timestamp"] >= cutoff]