diff --git a/oak_service_spatial.py b/oak_service_spatial.py
index 2150f76..f485838 100644
--- a/oak_service_spatial.py
+++ b/oak_service_spatial.py
@@ -21,6 +21,7 @@
 import cv2
 import numpy as np
 from face_recognition import FaceRecognizer
+from pose_estimator import PoseEstimator
 
 logger = logging.getLogger("oak-service")
 logging.basicConfig(level=logging.INFO)
@@ -42,6 +43,10 @@ FACE_DETECT_MODEL = MODELS_DIR / "ssd_mobilenet_v2_face_quant_postprocess_edgetp
 FACE_EMBED_MODEL = MODELS_DIR / "facenet.tflite"
 FACE_DB_PATH = Path(__file__).parent / "faces.db"
 
+# Pose estimation
+POSE_MODEL_PATH = MODELS_DIR / "movenet_single_pose_lightning_ptq_edgetpu.tflite"
+POSE_CORAL_DEVICE = 1  # Second Coral (device 0 is headmic/YAMNet)
+
 # ============== Global State ==============
 pipeline_ctx = None
 detection_queue = None
@@ -51,6 +56,7 @@ detection_thread = None
 running = False
 labels = []
 face_recognizer = None
+pose_estimator = None
 
 presence_state = {
     "present": False,
@@ -69,6 +75,16 @@ presence_state = {
     "recognition_confidence": None,
 }
 
+pose_state = {
+    "active": False,
+    "keypoints": [],
+    "posture": {},
+    "num_valid": 0,
+    "mean_confidence": 0.0,
+    "inference_ms": 0.0,
+    "last_update": None,
+}
+
 
 def init_face_recognition():
     """Initialize Coral face detection + FaceNet embedding."""
@@ -147,6 +163,10 @@ def init_oak():
         pipeline_ctx = pipeline
 
         print("✅ OAK-D initialized with SPATIAL person detection!")
+
+        # Initialize pose estimator on Coral 2
+        _init_pose_estimator()
+
         return True
 
     except Exception as e:
@@ -174,6 +194,48 @@ def cleanup_oak():
         pipeline_ctx = None
 
 
+def _init_pose_estimator():
+    """Initialize MoveNet Lightning on the second Coral Edge TPU."""
+    global pose_estimator
+
+    if not POSE_MODEL_PATH.exists():
+        print(f"⚠️ Pose model not found: {POSE_MODEL_PATH}")
+        return
+
+    try:
+        pose_estimator = PoseEstimator(
+            model_path=str(POSE_MODEL_PATH),
+            device_index=POSE_CORAL_DEVICE,
+        )
+        print("✅ Pose estimator initialized on Coral 2!")
+    except Exception as e:
+        print(f"⚠️ Pose estimator failed to initialize: {e}")
+        pose_estimator = None
+
+
+def _run_pose_estimation(rgb_frame):
+    """Run pose estimation on an RGB frame via Coral 2."""
+    global pose_state
+
+    if pose_estimator is None:
+        return
+
+    try:
+        result = pose_estimator.estimate(rgb_frame)
+        posture = pose_estimator.derive_posture(result["keypoints"])
+
+        pose_state["active"] = True
+        pose_state["keypoints"] = result["keypoints"]
+        pose_state["posture"] = posture
+        pose_state["num_valid"] = result["num_valid"]
+        pose_state["mean_confidence"] = result["mean_confidence"]
+        pose_state["inference_ms"] = result["inference_ms"]
+        pose_state["last_update"] = result["timestamp"]
+
+    except Exception as e:
+        print(f"Pose estimation error: {e}")
+
+
 def detection_loop():
     """Background thread for SPATIAL presence detection."""
     global running, presence_state, detection_queue
@@ -212,18 +274,26 @@ def detection_loop():
                 presence_state["spatial_z"] = best.spatialCoordinates.z
                 presence_state["distance_mm"] = best.spatialCoordinates.z
 
-            # Face recognition
+            # Grab RGB frame for face recognition + pose estimation
             face_results = []
-            if face_recognizer and rgb_queue:
+            rgb_frame = None
+            if rgb_queue:
                 rgb_data = rgb_queue.tryGet()
                 if rgb_data is not None:
                     rgb_frame = rgb_data.getCvFrame()
-                    try:
-                        face_results = face_recognizer.process_frame(
-                            rgb_frame, persons
-                        )
-                    except Exception as e:
-                        logger.warning("Face recognition error: %s", e)
+
+            # Face recognition
+            if face_recognizer and rgb_frame is not None:
+                try:
+                    face_results = face_recognizer.process_frame(
+                        rgb_frame, persons
+                    )
+                except Exception as e:
+                    logger.warning("Face recognition error: %s", e)
+
+            # Pose estimation (runs on Coral 2, parallel-safe)
+            if rgb_frame is not None:
+                _run_pose_estimation(rgb_frame)
 
             det_list = []
             best_recognized = None
@@ -265,6 +335,14 @@
                 presence_state["recognized_name"] = None
                 presence_state["recognition_confidence"] = None
 
+                # Clear pose when no person
+                if pose_state["active"]:
+                    pose_state["active"] = False
+                    pose_state["keypoints"] = []
+                    pose_state["posture"] = {}
+                    pose_state["num_valid"] = 0
+                    pose_state["mean_confidence"] = 0.0
+
             # Check timeout
             if presence_state["last_seen"]:
                 if now - presence_state["last_seen"] > PRESENCE_TIMEOUT:
@@ -304,8 +382,8 @@ async def lifespan(app: FastAPI):
 
 
 app = FastAPI(
     title="OAK-D SPATIAL Vision Service",
-    description="Vixy's eyes with SPATIAL presence detection + face recognition! 🦊👀📏",
-    version="0.5.0",
+    description="Vixy's eyes with SPATIAL presence detection + face recognition + pose estimation! 🦊👀📏",
+    version="0.6.0",
     lifespan=lifespan
 )
@@ -316,11 +394,12 @@ async def health():
     return {
         "status": "healthy",
         "service": "oak-service",
-        "version": "0.5.0",
+        "version": "0.6.0",
         "oak_connected": pipeline_ctx is not None,
         "detection_model": DETECTION_MODEL,
         "spatial_enabled": True,
         "face_recognition_enabled": face_recognizer is not None,
+        "pose_model_loaded": pose_estimator is not None,
         "timestamp": time.time()
     }
 
@@ -415,6 +494,43 @@ async def depth_frame():
         raise HTTPException(status_code=500, detail=str(e))
 
 
+# ============== Pose Estimation API ==============
+
+
+@app.get("/pose")
+async def pose():
+    """Get current pose keypoints."""
+    if pose_estimator is None:
+        raise HTTPException(status_code=503, detail="Pose estimator not available")
+
+    return {
+        "active": pose_state["active"],
+        "keypoints": pose_state["keypoints"],
+        "num_valid": pose_state["num_valid"],
+        "mean_confidence": pose_state["mean_confidence"],
+        "inference_ms": pose_state["inference_ms"],
+        "last_update": pose_state["last_update"],
+        "timestamp": time.time(),
+    }
+
+
+@app.get("/pose/summary")
+async def pose_summary():
+    """Get derived posture summary."""
+    if pose_estimator is None:
+        raise HTTPException(status_code=503, detail="Pose estimator not available")
+
+    return {
+        "active": pose_state["active"],
+        "posture": pose_state["posture"].get("posture", "unknown"),
+        "facing_camera": pose_state["posture"].get("facing_camera", False),
+        "arms_raised": pose_state["posture"].get("arms_raised", False),
+        "mean_confidence": pose_state["mean_confidence"],
+        "num_valid": pose_state["num_valid"],
+        "timestamp": time.time(),
+    }
+
+
 # ============== Face Enrollment API ==============