Add YAMNet sound classification to headmic
New sound_id.py module with a SoundClassifier class that runs YAMNet (521 audio event categories) on CPU via TFLite. Classifies audio every 0.5s from a ring buffer fed by the existing audio stream. Categories: speech, alert, music, animal, household, environment, silence. Smoothing via a 20-sample history window gives a stable dominant category. New endpoints: GET /sounds, GET /sounds/history. Updated: /health (sound_classification_enabled), /status (audio_scene). Graceful degradation if the model files are not present. Model download (not tracked in git): curl -sL 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite' -o models/yamnet.tflite Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
83
headmic.py
83
headmic.py
@@ -34,6 +34,7 @@ import wave
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import httpx
|
||||
import pvporcupine
|
||||
import webrtcvad
|
||||
@@ -117,9 +118,15 @@ class ServiceState:
|
||||
self.last_wake_time: Optional[float] = None
|
||||
self.wake_count = 0
|
||||
self.error: Optional[str] = None
|
||||
self.audio_scene: Optional[dict] = None
|
||||
self.sound_classification_enabled: bool = False
|
||||
|
||||
state = ServiceState()
|
||||
|
||||
# Sound classifier globals
|
||||
sound_classifier = None
|
||||
sound_ring_buffer = None # collections.deque, filled by listener_loop
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio Stream using ALSA directly (arecord)
|
||||
@@ -254,7 +261,11 @@ def listener_loop():
|
||||
|
||||
# Convert bytes to int16 array for Porcupine
|
||||
pcm = struct.unpack_from("h" * 512, frame_data)
|
||||
|
||||
|
||||
# Feed sound classifier ring buffer
|
||||
if sound_ring_buffer is not None:
|
||||
sound_ring_buffer.append(frame_data)
|
||||
|
||||
# Check for wake word
|
||||
keyword_index = porcupine.process(pcm)
|
||||
|
||||
@@ -327,6 +338,31 @@ def listener_loop():
|
||||
logger.info("Listener stopped")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Sound Classification Thread
|
||||
# ============================================================================
|
||||
|
||||
def sound_classifier_loop():
    """Classify buffered audio repeatedly until the service stops.

    Runs for as long as ``state.running`` is true. While the ring buffer is
    missing or holds fewer than 30 frames, the loop polls every 0.1s.
    Otherwise it snapshots the buffered frames, views the joined bytes as
    int16 samples, passes them to ``sound_classifier.classify`` and stores
    the result in ``state.audio_scene``, then waits 0.5s before the next
    pass. Any exception raised during a pass is logged as a warning and the
    loop keeps going.
    """
    logger.info("Sound classifier thread started")

    while state.running:
        enough_audio = (
            sound_ring_buffer is not None and len(sound_ring_buffer) >= 30
        )

        if not enough_audio:
            # Nothing useful to classify yet; poll again shortly.
            pause = 0.1
        else:
            try:
                # Snapshot first so the joined bytes come from one
                # consistent copy of the buffer contents.
                snapshot = list(sound_ring_buffer)
                samples = np.frombuffer(b"".join(snapshot), dtype=np.int16)
                state.audio_scene = sound_classifier.classify(samples)
            except Exception as exc:
                logger.warning("Sound classification error: %s", exc)
            pause = 0.5

        time.sleep(pause)

    logger.info("Sound classifier thread stopped")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FastAPI
|
||||
# ============================================================================
|
||||
@@ -336,7 +372,30 @@ app = FastAPI(title="HeadMic", description="Vixy's Ears 🦊👂")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
global sound_classifier, sound_ring_buffer
|
||||
|
||||
state.running = True
|
||||
|
||||
# Init sound classifier (optional — graceful if model missing)
|
||||
model_dir = Path(__file__).parent / "models"
|
||||
model_path = model_dir / "yamnet.tflite"
|
||||
class_map_path = model_dir / "yamnet_class_map.csv"
|
||||
if model_path.exists() and class_map_path.exists():
|
||||
try:
|
||||
from sound_id import SoundClassifier
|
||||
sound_classifier = SoundClassifier(str(model_path), str(class_map_path))
|
||||
# 31 frames of 512 samples = ~0.99s at 16kHz
|
||||
sound_ring_buffer = collections.deque(maxlen=31)
|
||||
state.sound_classification_enabled = True
|
||||
logger.info("Sound classification enabled (YAMNet)")
|
||||
|
||||
sc_thread = threading.Thread(target=sound_classifier_loop, daemon=True)
|
||||
sc_thread.start()
|
||||
except Exception as e:
|
||||
logger.warning("Sound classification unavailable: %s", e)
|
||||
else:
|
||||
logger.info("Sound classification models not found, skipping")
|
||||
|
||||
thread = threading.Thread(target=listener_loop, daemon=True)
|
||||
thread.start()
|
||||
logger.info("HeadMic started")
|
||||
@@ -365,6 +424,7 @@ async def health():
|
||||
"recording": state.recording,
|
||||
"processing": state.processing,
|
||||
"wake_count": state.wake_count,
|
||||
"sound_classification_enabled": state.sound_classification_enabled,
|
||||
"error": state.error
|
||||
}
|
||||
|
||||
@@ -378,6 +438,7 @@ async def status():
|
||||
"last_transcription": state.last_transcription,
|
||||
"last_wake_time": state.last_wake_time,
|
||||
"wake_count": state.wake_count,
|
||||
"audio_scene": state.audio_scene["dominant_category"] if state.audio_scene else None,
|
||||
"error": state.error
|
||||
}
|
||||
|
||||
@@ -390,6 +451,26 @@ async def last():
|
||||
}
|
||||
|
||||
|
||||
@app.get("/sounds")
async def sounds():
    """Return the most recent audio scene classification.

    Raises:
        HTTPException: 503 when sound classification is not enabled.

    Returns:
        The value stored in ``state.audio_scene``, or a placeholder payload
        with empty/None fields when no classification has been stored yet.
    """
    if not state.sound_classification_enabled:
        raise HTTPException(status_code=503, detail="Sound classification not available")

    scene = state.audio_scene
    if scene is not None:
        return scene

    # Enabled, but no result has been published yet.
    return {
        "category": None,
        "top_classes": [],
        "dominant_category": None,
        "timestamp": None,
    }
|
||||
|
||||
|
||||
@app.get("/sounds/history")
async def sounds_history(seconds: int = 30):
    """Return recent sound classification history.

    Args:
        seconds: Value forwarded to ``sound_classifier.get_history``;
            defaults to 30.

    Raises:
        HTTPException: 503 when sound classification is not enabled.

    Returns:
        A dict with a single ``"history"`` key. The list is empty when no
        classifier instance exists.
    """
    if not state.sound_classification_enabled:
        raise HTTPException(status_code=503, detail="Sound classification not available")

    entries = [] if sound_classifier is None else sound_classifier.get_history(seconds)
    return {"history": entries}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8446)
|
||||
|
||||
Reference in New Issue
Block a user