Add pose estimation to spatial service (production file)

Integrates MoveNet Lightning on Coral 2 into oak_service_spatial.py,
which is the actual production service running on head-vixy. Reuses
the existing RGB frame grab (shared with face recognition) for pose
estimation. Adds /pose and /pose/summary endpoints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Alex
2026-02-08 19:33:44 -06:00
parent cdbf7ff394
commit 0bbd54b40f

View File

@@ -21,6 +21,7 @@ import cv2
import numpy as np import numpy as np
from face_recognition import FaceRecognizer from face_recognition import FaceRecognizer
from pose_estimator import PoseEstimator
logger = logging.getLogger("oak-service") logger = logging.getLogger("oak-service")
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -42,6 +43,10 @@ FACE_DETECT_MODEL = MODELS_DIR / "ssd_mobilenet_v2_face_quant_postprocess_edgetp
FACE_EMBED_MODEL = MODELS_DIR / "facenet.tflite" FACE_EMBED_MODEL = MODELS_DIR / "facenet.tflite"
FACE_DB_PATH = Path(__file__).parent / "faces.db" FACE_DB_PATH = Path(__file__).parent / "faces.db"
# Pose estimation
POSE_MODEL_PATH = MODELS_DIR / "movenet_single_pose_lightning_ptq_edgetpu.tflite"
POSE_CORAL_DEVICE = 1 # Second Coral (device 0 is headmic/YAMNet)
# ============== Global State ============== # ============== Global State ==============
pipeline_ctx = None pipeline_ctx = None
detection_queue = None detection_queue = None
@@ -51,6 +56,7 @@ detection_thread = None
running = False running = False
labels = [] labels = []
face_recognizer = None face_recognizer = None
pose_estimator = None
presence_state = { presence_state = {
"present": False, "present": False,
@@ -69,6 +75,16 @@ presence_state = {
"recognition_confidence": None, "recognition_confidence": None,
} }
# Latest pose-estimation result. Mutated in place by the detection loop
# (_run_pose_estimation) and read by the /pose and /pose/summary endpoints.
pose_state = {
    "active": False,         # True after a successful estimate; cleared when no person is present
    "keypoints": [],         # raw keypoints from PoseEstimator.estimate()
    "posture": {},           # derived flags from PoseEstimator.derive_posture()
    "num_valid": 0,          # presumably keypoints above a confidence threshold — TODO confirm in PoseEstimator
    "mean_confidence": 0.0,  # mean keypoint confidence of the last estimate
    "inference_ms": 0.0,     # Coral inference latency of the last estimate
    "last_update": None,     # timestamp reported by the last successful estimate
}
def init_face_recognition(): def init_face_recognition():
"""Initialize Coral face detection + FaceNet embedding.""" """Initialize Coral face detection + FaceNet embedding."""
@@ -147,6 +163,10 @@ def init_oak():
pipeline_ctx = pipeline pipeline_ctx = pipeline
print("✅ OAK-D initialized with SPATIAL person detection!") print("✅ OAK-D initialized with SPATIAL person detection!")
# Initialize pose estimator on Coral 2
_init_pose_estimator()
return True return True
except Exception as e: except Exception as e:
@@ -174,6 +194,48 @@ def cleanup_oak():
pipeline_ctx = None pipeline_ctx = None
def _init_pose_estimator():
    """Initialize MoveNet Lightning on the second Coral Edge TPU.

    Sets the module-level ``pose_estimator``; leaves it as ``None`` when the
    model file is missing or construction fails, so the service keeps running
    without pose support.
    """
    global pose_estimator

    # Pose support is optional: bail out quietly if the model isn't installed.
    if not POSE_MODEL_PATH.exists():
        print(f"⚠️ Pose model not found: {POSE_MODEL_PATH}")
        return

    try:
        estimator = PoseEstimator(
            model_path=str(POSE_MODEL_PATH),
            device_index=POSE_CORAL_DEVICE,
        )
    except Exception as e:
        print(f"⚠️ Pose estimator failed to initialize: {e}")
        pose_estimator = None
    else:
        pose_estimator = estimator
        print("✅ Pose estimator initialized on Coral 2!")
def _run_pose_estimation(rgb_frame):
    """Run pose estimation on an RGB frame via Coral 2.

    Updates the module-level ``pose_state`` dict in place with the latest
    keypoints and derived posture. No-op when the estimator is unavailable.
    On error the previous state is left untouched and the error is logged —
    mirroring the face-recognition error handling in the detection loop.

    Args:
        rgb_frame: Color frame from the OAK RGB queue (the same frame that
            is fed to face recognition).
    """
    # NOTE: no `global` needed — pose_state is only mutated, never rebound.
    if pose_estimator is None:
        return

    try:
        result = pose_estimator.estimate(rgb_frame)
        posture = pose_estimator.derive_posture(result["keypoints"])
    except Exception as e:
        # Consistent with the loop's face-recognition handling: log via the
        # module logger (not print) and keep the previous state.
        logger.warning("Pose estimation error: %s", e)
        return

    # Single update keeps the transition near-atomic for endpoint readers.
    pose_state.update(
        active=True,
        keypoints=result["keypoints"],
        posture=posture,
        num_valid=result["num_valid"],
        mean_confidence=result["mean_confidence"],
        inference_ms=result["inference_ms"],
        last_update=result["timestamp"],
    )
def detection_loop(): def detection_loop():
"""Background thread for SPATIAL presence detection.""" """Background thread for SPATIAL presence detection."""
global running, presence_state, detection_queue global running, presence_state, detection_queue
@@ -212,18 +274,26 @@ def detection_loop():
presence_state["spatial_z"] = best.spatialCoordinates.z presence_state["spatial_z"] = best.spatialCoordinates.z
presence_state["distance_mm"] = best.spatialCoordinates.z presence_state["distance_mm"] = best.spatialCoordinates.z
# Face recognition # Grab RGB frame for face recognition + pose estimation
face_results = [] face_results = []
if face_recognizer and rgb_queue: rgb_frame = None
if rgb_queue:
rgb_data = rgb_queue.tryGet() rgb_data = rgb_queue.tryGet()
if rgb_data is not None: if rgb_data is not None:
rgb_frame = rgb_data.getCvFrame() rgb_frame = rgb_data.getCvFrame()
try:
face_results = face_recognizer.process_frame( # Face recognition
rgb_frame, persons if face_recognizer and rgb_frame is not None:
) try:
except Exception as e: face_results = face_recognizer.process_frame(
logger.warning("Face recognition error: %s", e) rgb_frame, persons
)
except Exception as e:
logger.warning("Face recognition error: %s", e)
# Pose estimation (runs on Coral 2, parallel-safe)
if rgb_frame is not None:
_run_pose_estimation(rgb_frame)
det_list = [] det_list = []
best_recognized = None best_recognized = None
@@ -265,6 +335,14 @@ def detection_loop():
presence_state["recognized_name"] = None presence_state["recognized_name"] = None
presence_state["recognition_confidence"] = None presence_state["recognition_confidence"] = None
# Clear pose when no person
if pose_state["active"]:
pose_state["active"] = False
pose_state["keypoints"] = []
pose_state["posture"] = {}
pose_state["num_valid"] = 0
pose_state["mean_confidence"] = 0.0
# Check timeout # Check timeout
if presence_state["last_seen"]: if presence_state["last_seen"]:
if now - presence_state["last_seen"] > PRESENCE_TIMEOUT: if now - presence_state["last_seen"] > PRESENCE_TIMEOUT:
@@ -304,8 +382,8 @@ async def lifespan(app: FastAPI):
app = FastAPI( app = FastAPI(
title="OAK-D SPATIAL Vision Service", title="OAK-D SPATIAL Vision Service",
description="Vixy's eyes with SPATIAL presence detection + face recognition! 🦊👀📏", description="Vixy's eyes with SPATIAL presence detection + face recognition + pose estimation! 🦊👀📏",
version="0.5.0", version="0.6.0",
lifespan=lifespan lifespan=lifespan
) )
@@ -316,11 +394,12 @@ async def health():
return { return {
"status": "healthy", "status": "healthy",
"service": "oak-service", "service": "oak-service",
"version": "0.5.0", "version": "0.6.0",
"oak_connected": pipeline_ctx is not None, "oak_connected": pipeline_ctx is not None,
"detection_model": DETECTION_MODEL, "detection_model": DETECTION_MODEL,
"spatial_enabled": True, "spatial_enabled": True,
"face_recognition_enabled": face_recognizer is not None, "face_recognition_enabled": face_recognizer is not None,
"pose_model_loaded": pose_estimator is not None,
"timestamp": time.time() "timestamp": time.time()
} }
@@ -415,6 +494,43 @@ async def depth_frame():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# ============== Pose Estimation API ==============
@app.get("/pose")
async def pose():
    """Get current pose keypoints."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    # Copy the relevant fields out of the shared state in a fixed order.
    payload = {
        field: pose_state[field]
        for field in (
            "active",
            "keypoints",
            "num_valid",
            "mean_confidence",
            "inference_ms",
            "last_update",
        )
    }
    payload["timestamp"] = time.time()
    return payload
@app.get("/pose/summary")
async def pose_summary():
    """Get derived posture summary."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    # Derived posture flags may be absent (e.g. before the first estimate),
    # so fall back to neutral defaults.
    posture_info = pose_state["posture"]
    return {
        "active": pose_state["active"],
        "posture": posture_info.get("posture", "unknown"),
        "facing_camera": posture_info.get("facing_camera", False),
        "arms_raised": posture_info.get("arms_raised", False),
        "mean_confidence": pose_state["mean_confidence"],
        "num_valid": pose_state["num_valid"],
        "timestamp": time.time(),
    }
# ============== Face Enrollment API ============== # ============== Face Enrollment API ==============