Add pose estimation to spatial service (production file)

Integrates MoveNet Lightning on Coral 2 into oak_service_spatial.py,
which is the actual production service running on head-vixy. Reuses
the existing RGB frame grab (shared with face recognition) for pose
estimation. Adds /pose and /pose/summary endpoints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Alex
2026-02-08 19:33:44 -06:00
parent cdbf7ff394
commit 0bbd54b40f

View File

@@ -21,6 +21,7 @@ import cv2
import numpy as np
from face_recognition import FaceRecognizer
from pose_estimator import PoseEstimator
logger = logging.getLogger("oak-service")
logging.basicConfig(level=logging.INFO)
@@ -42,6 +43,10 @@ FACE_DETECT_MODEL = MODELS_DIR / "ssd_mobilenet_v2_face_quant_postprocess_edgetp
FACE_EMBED_MODEL = MODELS_DIR / "facenet.tflite"
FACE_DB_PATH = Path(__file__).parent / "faces.db"
# Pose estimation
POSE_MODEL_PATH = MODELS_DIR / "movenet_single_pose_lightning_ptq_edgetpu.tflite"
POSE_CORAL_DEVICE = 1 # Second Coral (device 0 is headmic/YAMNet)
# ============== Global State ==============
pipeline_ctx = None
detection_queue = None
@@ -51,6 +56,7 @@ detection_thread = None
running = False
labels = []
face_recognizer = None
pose_estimator = None
presence_state = {
"present": False,
@@ -69,6 +75,16 @@ presence_state = {
"recognition_confidence": None,
}
# Latest pose-estimation result. Mutated in place by _run_pose_estimation()
# after each successful inference, and cleared by the detection loop when no
# person is present. Served by the /pose and /pose/summary endpoints.
pose_state = {
    "active": False,         # True once a pose has been computed (reset when nobody is seen)
    "keypoints": [],         # keypoint list copied from PoseEstimator.estimate()
    "posture": {},           # derived summary from PoseEstimator.derive_posture()
    "num_valid": 0,          # "num_valid" field from the estimate() result — presumably count of confident keypoints; TODO confirm
    "mean_confidence": 0.0,  # mean keypoint confidence from the last inference
    "inference_ms": 0.0,     # inference latency reported by the estimator, in milliseconds
    "last_update": None,     # timestamp of the last successful estimate (from the result dict)
}
def init_face_recognition():
"""Initialize Coral face detection + FaceNet embedding."""
@@ -147,6 +163,10 @@ def init_oak():
pipeline_ctx = pipeline
print("✅ OAK-D initialized with SPATIAL person detection!")
# Initialize pose estimator on Coral 2
_init_pose_estimator()
return True
except Exception as e:
@@ -174,6 +194,48 @@ def cleanup_oak():
pipeline_ctx = None
def _init_pose_estimator():
    """Set up MoveNet Lightning on the second Coral Edge TPU.

    On success, binds the module-level ``pose_estimator``; on a missing
    model file or a construction failure it stays/returns to ``None`` so the
    pose endpoints can report the estimator as unavailable.
    """
    global pose_estimator

    if not POSE_MODEL_PATH.exists():
        print(f"⚠️ Pose model not found: {POSE_MODEL_PATH}")
        return

    try:
        estimator = PoseEstimator(
            model_path=str(POSE_MODEL_PATH),
            device_index=POSE_CORAL_DEVICE,
        )
    except Exception as e:
        pose_estimator = None
        print(f"⚠️ Pose estimator failed to initialize: {e}")
    else:
        pose_estimator = estimator
        print("✅ Pose estimator initialized on Coral 2!")
def _run_pose_estimation(rgb_frame):
    """Run MoveNet pose estimation on an RGB frame via Coral 2.

    Updates the module-level ``pose_state`` dict in place with the latest
    keypoints, derived posture, and timing info. A no-op when the estimator
    failed to initialize. Errors are logged and swallowed so the detection
    loop keeps running.

    Args:
        rgb_frame: image array grabbed from the OAK-D RGB queue (the same
            frame shared with face recognition).
    """
    # NOTE: no `global` needed — pose_state is only mutated, never rebound.
    if pose_estimator is None:
        return
    try:
        result = pose_estimator.estimate(rgb_frame)
        posture = pose_estimator.derive_posture(result["keypoints"])
        pose_state.update(
            active=True,
            keypoints=result["keypoints"],
            posture=posture,
            num_valid=result["num_valid"],
            mean_confidence=result["mean_confidence"],
            inference_ms=result["inference_ms"],
            last_update=result["timestamp"],
        )
    except Exception as e:
        # Consistency fix: the face-recognition error path uses
        # logger.warning, not print — do the same here.
        logger.warning("Pose estimation error: %s", e)
def detection_loop():
"""Background thread for SPATIAL presence detection."""
global running, presence_state, detection_queue
@@ -212,18 +274,26 @@ def detection_loop():
presence_state["spatial_z"] = best.spatialCoordinates.z
presence_state["distance_mm"] = best.spatialCoordinates.z
# Face recognition
# Grab RGB frame for face recognition + pose estimation
face_results = []
if face_recognizer and rgb_queue:
rgb_frame = None
if rgb_queue:
rgb_data = rgb_queue.tryGet()
if rgb_data is not None:
rgb_frame = rgb_data.getCvFrame()
try:
face_results = face_recognizer.process_frame(
rgb_frame, persons
)
except Exception as e:
logger.warning("Face recognition error: %s", e)
# Face recognition
if face_recognizer and rgb_frame is not None:
try:
face_results = face_recognizer.process_frame(
rgb_frame, persons
)
except Exception as e:
logger.warning("Face recognition error: %s", e)
# Pose estimation (runs on Coral 2, parallel-safe)
if rgb_frame is not None:
_run_pose_estimation(rgb_frame)
det_list = []
best_recognized = None
@@ -265,6 +335,14 @@ def detection_loop():
presence_state["recognized_name"] = None
presence_state["recognition_confidence"] = None
# Clear pose when no person
if pose_state["active"]:
pose_state["active"] = False
pose_state["keypoints"] = []
pose_state["posture"] = {}
pose_state["num_valid"] = 0
pose_state["mean_confidence"] = 0.0
# Check timeout
if presence_state["last_seen"]:
if now - presence_state["last_seen"] > PRESENCE_TIMEOUT:
@@ -304,8 +382,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(
title="OAK-D SPATIAL Vision Service",
description="Vixy's eyes with SPATIAL presence detection + face recognition! 🦊👀📏",
version="0.5.0",
description="Vixy's eyes with SPATIAL presence detection + face recognition + pose estimation! 🦊👀📏",
version="0.6.0",
lifespan=lifespan
)
@@ -316,11 +394,12 @@ async def health():
return {
"status": "healthy",
"service": "oak-service",
"version": "0.5.0",
"version": "0.6.0",
"oak_connected": pipeline_ctx is not None,
"detection_model": DETECTION_MODEL,
"spatial_enabled": True,
"face_recognition_enabled": face_recognizer is not None,
"pose_model_loaded": pose_estimator is not None,
"timestamp": time.time()
}
@@ -415,6 +494,43 @@ async def depth_frame():
raise HTTPException(status_code=500, detail=str(e))
# ============== Pose Estimation API ==============
@app.get("/pose")
async def pose():
    """Return the most recent pose keypoints plus inference stats."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    payload = {
        field: pose_state[field]
        for field in (
            "active",
            "keypoints",
            "num_valid",
            "mean_confidence",
            "inference_ms",
            "last_update",
        )
    }
    payload["timestamp"] = time.time()
    return payload
@app.get("/pose/summary")
async def pose_summary():
    """Return the derived posture summary (label, orientation, arm state)."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    posture = pose_state["posture"]
    return {
        "active": pose_state["active"],
        "posture": posture.get("posture", "unknown"),
        "facing_camera": posture.get("facing_camera", False),
        "arms_raised": posture.get("arms_raised", False),
        "mean_confidence": pose_state["mean_confidence"],
        "num_valid": pose_state["num_valid"],
        "timestamp": time.time(),
    }
# ============== Face Enrollment API ==============