diff --git a/server/detector.py b/server/detector.py index bd4b291..9d52685 100644 --- a/server/detector.py +++ b/server/detector.py @@ -40,10 +40,12 @@ class ObjectDetector: model_path: str, labels_path: str, confidence_threshold: float = 0.5, + label_whitelist: Optional[set[str]] = None, ): self.model_path = Path(model_path) self.labels_path = Path(labels_path) self.confidence_threshold = confidence_threshold + self.label_whitelist = label_whitelist self._interpreter = None self._input_details = None @@ -159,6 +161,10 @@ class ObjectDetector: class_id = int(class_ids[i]) label = self._labels[class_id] if class_id < len(self._labels) else f"class_{class_id}" + # Skip labels not in whitelist (if set) + if self.label_whitelist and label not in self.label_whitelist: + continue + # Convert from (y_min, x_min, y_max, x_max) to (x_min, y_min, x_max, y_max) y_min, x_min, y_max, x_max = boxes[i] bbox = ( diff --git a/server/env.example b/server/env.example index 27ed369..9e4d06e 100644 --- a/server/env.example +++ b/server/env.example @@ -60,6 +60,9 @@ DETECTION_CONFIDENCE=0.5 # Set to false to keep reporting all motion events DETECTION_SUPPRESS_EMPTY=true +# Only report these object types (comma-separated, empty = all) +DETECTION_LABELS=person,cat,dog + # ============ Event Collector ============ # URL to POST motion events to (collector on Mac mini) diff --git a/server/main.py b/server/main.py index 0a83639..94e338c 100644 --- a/server/main.py +++ b/server/main.py @@ -47,6 +47,7 @@ DETECTION_MODEL_PATH = os.getenv("DETECTION_MODEL_PATH", "models/ssd_mobilenet_v DETECTION_LABELS_PATH = os.getenv("DETECTION_LABELS_PATH", "models/coco_labels.txt") DETECTION_CONFIDENCE = float(os.getenv("DETECTION_CONFIDENCE", "0.5")) DETECTION_SUPPRESS_EMPTY = os.getenv("DETECTION_SUPPRESS_EMPTY", "true").lower() == "true" +DETECTION_LABELS = os.getenv("DETECTION_LABELS", "") # Comma-separated whitelist (empty = all) if not API_KEY: raise ValueError("API_KEY not set in .env file") @@ -153,6 +154,7 @@ if MOTION_ENABLED: detection_labels_path=DETECTION_LABELS_PATH, detection_confidence=DETECTION_CONFIDENCE, detection_suppress_empty=DETECTION_SUPPRESS_EMPTY, + detection_labels=DETECTION_LABELS if DETECTION_LABELS else None, ) @@ -241,6 +243,7 @@ def enable_motion(api_key: str = Security(verify_api_key)): detection_labels_path=DETECTION_LABELS_PATH, detection_confidence=DETECTION_CONFIDENCE, detection_suppress_empty=DETECTION_SUPPRESS_EMPTY, + detection_labels=DETECTION_LABELS if DETECTION_LABELS else None, ) motion_detector.start(camera_manager.get_raw_frame) diff --git a/server/motion.py b/server/motion.py index e97f770..2f9f687 100644 --- a/server/motion.py +++ b/server/motion.py @@ -61,6 +61,7 @@ class MotionDetector: detection_labels_path: Optional[str] = None, detection_confidence: float = 0.5, detection_suppress_empty: bool = True, + detection_labels: Optional[str] = None, ): self.camera_id = camera_id self.collector_url = collector_url @@ -82,10 +83,14 @@ class MotionDetector: if detection_enabled and detection_model_path: try: from detector import ObjectDetector + label_whitelist = None + if detection_labels: + label_whitelist = set(l.strip() for l in detection_labels.split(",")) self._detector = ObjectDetector( model_path=detection_model_path, labels_path=detection_labels_path or "", confidence_threshold=detection_confidence, + label_whitelist=label_whitelist, ) logger.info(f"Object detection enabled (model: {detection_model_path})") except ImportError as e: