Files
vixy-vision/server/detector.py
Alex 1bcf32889f Add label whitelist to filter detection types
DETECTION_LABELS env var accepts comma-separated list (e.g. "person,cat,dog").
Only matching detections are reported; others are ignored. Empty = report all.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 19:08:31 -06:00

222 lines
7.3 KiB
Python

#!/usr/bin/env python3
"""
Object Detection Module
Lightweight object detection using TensorFlow Lite with MobileNet V2 SSD.
Designed to run on Raspberry Pi 4/5 with minimal overhead.
The model is lazy-loaded on first detect() call to avoid startup delay.
"""
import cv2
import logging
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
@dataclass
class Detection:
    """One object found in a frame.

    Attributes:
        label: Human-readable class name resolved from the label map.
        confidence: Model score for this detection.
        bbox: Bounding box as (x_min, y_min, x_max, y_max), normalized 0-1.
        class_id: Raw integer class index from the model output.
    """
    label: str
    confidence: float
    bbox: tuple  # (x_min, y_min, x_max, y_max) normalized 0-1
    class_id: int
class ObjectDetector:
    """
    Object detection using TFLite MobileNet V2 SSD.

    Lazy-loads the model on first detect() call. Designed to be called
    from the motion detection thread after motion is confirmed.
    """

    def __init__(
        self,
        model_path: str,
        labels_path: str,
        confidence_threshold: float = 0.5,
        label_whitelist: Optional[set[str]] = None,
    ):
        """
        Args:
            model_path: Path to the .tflite model file.
            labels_path: Path to the newline-separated label map file.
            confidence_threshold: Minimum score for a detection to be reported.
            label_whitelist: If set and non-empty, only detections whose label
                is in this set are reported; None/empty reports everything.
        """
        self.model_path = Path(model_path)
        self.labels_path = Path(labels_path)
        self.confidence_threshold = confidence_threshold
        self.label_whitelist = label_whitelist
        # All of the below are populated lazily by _load_model() on the
        # first detect() call, to avoid paying model-load cost at startup.
        self._interpreter = None
        self._input_details = None
        self._output_details = None
        self._labels: list[str] = []
        self._input_height = 0
        self._input_width = 0

    def _load_model(self):
        """Load the TFLite runtime, the model file, and the label map.

        Raises:
            ImportError: No TFLite runtime package is installed.
            FileNotFoundError: The model file does not exist.
        """
        # Try ai-edge-litert (modern), then tflite-runtime (legacy)
        try:
            from ai_edge_litert import interpreter as tflite
        except ImportError:
            try:
                import tflite_runtime.interpreter as tflite
            except ImportError:
                raise ImportError(
                    "No TFLite runtime found. Install one of:\n"
                    " pip install ai-edge-litert (Python 3.12+)\n"
                    " pip install tflite-runtime (Python 3.9-3.11)"
                )
        if not self.model_path.exists():
            raise FileNotFoundError(
                f"Model file not found: {self.model_path}\n"
                f"Run download_model.sh to download the model."
            )
        # Load labels; a missing label file is non-fatal — detections fall
        # back to synthetic "class_N" names in detect().
        if self.labels_path.exists():
            self._labels = self.labels_path.read_text().strip().splitlines()
        else:
            logger.warning(f"Labels file not found: {self.labels_path}")
            self._labels = []
        # Try XNNPACK delegate for ARM acceleration; fall back silently to
        # the default CPU kernels if the shared library is not present.
        delegates = []
        try:
            delegates = [tflite.load_delegate('libXNNPACK.so')]
            logger.info("XNNPACK delegate loaded")
        except (ValueError, OSError):
            logger.info("XNNPACK delegate not available, using default CPU")
        # Load model
        self._interpreter = tflite.Interpreter(
            model_path=str(self.model_path),
            experimental_delegates=delegates or None,
        )
        self._interpreter.allocate_tensors()
        self._input_details = self._interpreter.get_input_details()
        self._output_details = self._interpreter.get_output_details()
        # Get expected input size; shape is [batch, height, width, channels]
        input_shape = self._input_details[0]['shape']
        self._input_height = input_shape[1]
        self._input_width = input_shape[2]
        logger.info(
            f"Object detection model loaded: {self.model_path.name} "
            f"(input: {self._input_width}x{self._input_height}, "
            f"{len(self._labels)} classes)"
        )

    def detect(self, frame: np.ndarray) -> list[Detection]:
        """
        Run object detection on a frame.

        Args:
            frame: BGR numpy array from OpenCV

        Returns:
            List of Detection objects above confidence threshold (and, if a
            label whitelist is configured, matching the whitelist).
        """
        # Lazy load
        if self._interpreter is None:
            self._load_model()
        # Preprocess: convert BGR to RGB and resize to the model's input size
        input_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_frame = cv2.resize(input_frame, (self._input_width, self._input_height))
        input_data = np.expand_dims(input_frame, axis=0)
        # Match the model's expected input dtype. Float models take inputs
        # normalized to [0, 1]; quantized models (uint8/int8) take raw pixels.
        input_dtype = self._input_details[0]['dtype']
        if input_dtype == np.float32:
            input_data = (input_data / 255.0).astype(np.float32)
        else:
            input_data = input_data.astype(input_dtype)
        # Run inference
        self._interpreter.set_tensor(self._input_details[0]['index'], input_data)
        self._interpreter.invoke()
        # Parse outputs (SSD MobileNet post-processed format):
        #   [0] bounding boxes: [1, N, 4] (y_min, x_min, y_max, x_max) normalized
        #   [1] class IDs: [1, N]
        #   [2] scores: [1, N]
        #   [3] number of detections: [1]
        boxes = self._interpreter.get_tensor(self._output_details[0]['index'])[0]
        class_ids = self._interpreter.get_tensor(self._output_details[1]['index'])[0]
        scores = self._interpreter.get_tensor(self._output_details[2]['index'])[0]
        num_detections = int(self._interpreter.get_tensor(self._output_details[3]['index'])[0])
        # Some models report a count larger than the fixed-size output
        # tensors actually hold; clamp so indexing below cannot raise.
        num_detections = min(num_detections, len(scores), len(class_ids), len(boxes))
        # Filter by confidence (and whitelist, if configured)
        detections = []
        for i in range(num_detections):
            score = float(scores[i])
            if score < self.confidence_threshold:
                continue
            class_id = int(class_ids[i])
            # Fall back to a synthetic name if the id is outside the label map
            label = self._labels[class_id] if class_id < len(self._labels) else f"class_{class_id}"
            # Skip labels not in whitelist (if set)
            if self.label_whitelist and label not in self.label_whitelist:
                continue
            # Convert from (y_min, x_min, y_max, x_max) to (x_min, y_min, x_max, y_max),
            # clipping to [0, 1] since models can emit slightly out-of-range boxes
            y_min, x_min, y_max, x_max = boxes[i]
            bbox = (
                float(np.clip(x_min, 0, 1)),
                float(np.clip(y_min, 0, 1)),
                float(np.clip(x_max, 0, 1)),
                float(np.clip(y_max, 0, 1)),
            )
            detections.append(Detection(
                label=label,
                confidence=score,
                bbox=bbox,
                class_id=class_id,
            ))
        return detections
def annotate_frame(frame: np.ndarray, detections: list[Detection]) -> np.ndarray:
    """
    Draw bounding boxes and labels on a frame.

    Args:
        frame: BGR numpy array (will be copied, not modified in place)
        detections: List of Detection objects

    Returns:
        Annotated copy of the frame
    """
    annotated = frame.copy()
    h, w = annotated.shape[:2]
    for det in detections:
        # Scale normalized bbox coordinates to pixel positions
        x1 = int(det.bbox[0] * w)
        y1 = int(det.bbox[1] * h)
        x2 = int(det.bbox[2] * w)
        y2 = int(det.bbox[3] * h)
        # Green box with label
        color = (0, 255, 0)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
        label_text = f"{det.label} {det.confidence:.0%}"
        font_scale = 0.6
        thickness = 1
        (tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
        # Keep the label visible: when the box touches the top edge the
        # background would be drawn above y=0 and clipped away entirely,
        # so shift the label baseline down far enough to stay in frame.
        y_label = max(y1, th + 8)
        # Filled background behind the text for readability
        cv2.rectangle(annotated, (x1, y_label - th - 8), (x1 + tw + 4, y_label), color, -1)
        cv2.putText(annotated, label_text, (x1 + 2, y_label - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)
    return annotated