Add pose estimation to spatial service (production file)
Integrates MoveNet Lightning on Coral 2 into oak_service_spatial.py, which is the actual production service running on head-vixy. Reuses the existing RGB frame grab (shared with face recognition) for pose estimation. Adds /pose and /pose/summary endpoints. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from face_recognition import FaceRecognizer
|
||||
from pose_estimator import PoseEstimator
|
||||
|
||||
logger = logging.getLogger("oak-service")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -42,6 +43,10 @@ FACE_DETECT_MODEL = MODELS_DIR / "ssd_mobilenet_v2_face_quant_postprocess_edgetp
|
||||
FACE_EMBED_MODEL = MODELS_DIR / "facenet.tflite"
|
||||
FACE_DB_PATH = Path(__file__).parent / "faces.db"
|
||||
|
||||
# Pose estimation
|
||||
POSE_MODEL_PATH = MODELS_DIR / "movenet_single_pose_lightning_ptq_edgetpu.tflite"
|
||||
POSE_CORAL_DEVICE = 1 # Second Coral (device 0 is headmic/YAMNet)
|
||||
|
||||
# ============== Global State ==============
|
||||
pipeline_ctx = None
|
||||
detection_queue = None
|
||||
@@ -51,6 +56,7 @@ detection_thread = None
|
||||
running = False
|
||||
labels = []
|
||||
face_recognizer = None
|
||||
pose_estimator = None
|
||||
|
||||
presence_state = {
|
||||
"present": False,
|
||||
@@ -69,6 +75,16 @@ presence_state = {
|
||||
"recognition_confidence": None,
|
||||
}
|
||||
|
||||
pose_state = {
|
||||
"active": False,
|
||||
"keypoints": [],
|
||||
"posture": {},
|
||||
"num_valid": 0,
|
||||
"mean_confidence": 0.0,
|
||||
"inference_ms": 0.0,
|
||||
"last_update": None,
|
||||
}
|
||||
|
||||
|
||||
def init_face_recognition():
|
||||
"""Initialize Coral face detection + FaceNet embedding."""
|
||||
@@ -147,6 +163,10 @@ def init_oak():
|
||||
pipeline_ctx = pipeline
|
||||
|
||||
print("✅ OAK-D initialized with SPATIAL person detection!")
|
||||
|
||||
# Initialize pose estimator on Coral 2
|
||||
_init_pose_estimator()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -174,6 +194,48 @@ def cleanup_oak():
|
||||
pipeline_ctx = None
|
||||
|
||||
|
||||
def _init_pose_estimator():
    """Initialize MoveNet Lightning on the second Coral Edge TPU.

    Sets the module-level ``pose_estimator`` global. Failure is non-fatal:
    if the model file is missing or the Edge TPU fails to open, the global
    stays ``None`` and the service runs without pose estimation (the /pose
    endpoints then return 503).
    """
    global pose_estimator

    if not POSE_MODEL_PATH.exists():
        # Consistent with the face-recognition path: report via the module
        # logger rather than bare print, with lazy %-style args.
        logger.warning("Pose model not found: %s", POSE_MODEL_PATH)
        return

    try:
        pose_estimator = PoseEstimator(
            model_path=str(POSE_MODEL_PATH),
            # POSE_CORAL_DEVICE selects Coral 2 (device 0 is headmic/YAMNet).
            device_index=POSE_CORAL_DEVICE,
        )
        logger.info("Pose estimator initialized on Coral device %d", POSE_CORAL_DEVICE)
    except Exception as e:
        # Broad catch is deliberate: any TPU/delegate init error should
        # degrade gracefully instead of killing service startup.
        logger.warning("Pose estimator failed to initialize: %s", e)
        pose_estimator = None
|
||||
|
||||
|
||||
def _run_pose_estimation(rgb_frame):
    """Run pose estimation on an RGB frame via Coral 2.

    Updates the module-level ``pose_state`` dict in place with the latest
    keypoints, derived posture, and inference stats. No-op when the
    estimator failed to initialize. Errors are logged and swallowed so the
    detection loop keeps running.

    Args:
        rgb_frame: BGR/RGB image array from the OAK-D color queue
            (as produced by ``getCvFrame()``) — passed straight to
            ``PoseEstimator.estimate``.
    """
    if pose_estimator is None:
        return

    try:
        result = pose_estimator.estimate(rgb_frame)
        posture = pose_estimator.derive_posture(result["keypoints"])

        # Single update() call instead of seven separate assignments:
        # fewer moments where readers of pose_state see a half-written
        # state. (No `global` needed — the dict is mutated, not rebound.)
        pose_state.update(
            active=True,
            keypoints=result["keypoints"],
            posture=posture,
            num_valid=result["num_valid"],
            mean_confidence=result["mean_confidence"],
            inference_ms=result["inference_ms"],
            last_update=result["timestamp"],
        )
    except Exception as e:
        # Consistent with the face-recognition error path in
        # detection_loop: log a warning, never crash the loop.
        logger.warning("Pose estimation error: %s", e)
|
||||
|
||||
|
||||
def detection_loop():
|
||||
"""Background thread for SPATIAL presence detection."""
|
||||
global running, presence_state, detection_queue
|
||||
@@ -212,12 +274,16 @@ def detection_loop():
|
||||
presence_state["spatial_z"] = best.spatialCoordinates.z
|
||||
presence_state["distance_mm"] = best.spatialCoordinates.z
|
||||
|
||||
# Face recognition
|
||||
# Grab RGB frame for face recognition + pose estimation
|
||||
face_results = []
|
||||
if face_recognizer and rgb_queue:
|
||||
rgb_frame = None
|
||||
if rgb_queue:
|
||||
rgb_data = rgb_queue.tryGet()
|
||||
if rgb_data is not None:
|
||||
rgb_frame = rgb_data.getCvFrame()
|
||||
|
||||
# Face recognition
|
||||
if face_recognizer and rgb_frame is not None:
|
||||
try:
|
||||
face_results = face_recognizer.process_frame(
|
||||
rgb_frame, persons
|
||||
@@ -225,6 +291,10 @@ def detection_loop():
|
||||
except Exception as e:
|
||||
logger.warning("Face recognition error: %s", e)
|
||||
|
||||
# Pose estimation (runs on Coral 2, parallel-safe)
|
||||
if rgb_frame is not None:
|
||||
_run_pose_estimation(rgb_frame)
|
||||
|
||||
det_list = []
|
||||
best_recognized = None
|
||||
best_recog_conf = 0.0
|
||||
@@ -265,6 +335,14 @@ def detection_loop():
|
||||
presence_state["recognized_name"] = None
|
||||
presence_state["recognition_confidence"] = None
|
||||
|
||||
# Clear pose when no person
|
||||
if pose_state["active"]:
|
||||
pose_state["active"] = False
|
||||
pose_state["keypoints"] = []
|
||||
pose_state["posture"] = {}
|
||||
pose_state["num_valid"] = 0
|
||||
pose_state["mean_confidence"] = 0.0
|
||||
|
||||
# Check timeout
|
||||
if presence_state["last_seen"]:
|
||||
if now - presence_state["last_seen"] > PRESENCE_TIMEOUT:
|
||||
@@ -304,8 +382,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
app = FastAPI(
|
||||
title="OAK-D SPATIAL Vision Service",
|
||||
description="Vixy's eyes with SPATIAL presence detection + face recognition! 🦊👀📏",
|
||||
version="0.5.0",
|
||||
description="Vixy's eyes with SPATIAL presence detection + face recognition + pose estimation! 🦊👀📏",
|
||||
version="0.6.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
@@ -316,11 +394,12 @@ async def health():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "oak-service",
|
||||
"version": "0.5.0",
|
||||
"version": "0.6.0",
|
||||
"oak_connected": pipeline_ctx is not None,
|
||||
"detection_model": DETECTION_MODEL,
|
||||
"spatial_enabled": True,
|
||||
"face_recognition_enabled": face_recognizer is not None,
|
||||
"pose_model_loaded": pose_estimator is not None,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
@@ -415,6 +494,43 @@ async def depth_frame():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============== Pose Estimation API ==============
|
||||
|
||||
|
||||
@app.get("/pose")
async def pose():
    """Return the most recent pose keypoints and inference statistics."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    # Copy the published fields out of the shared state dict, then stamp
    # the response with the current server time.
    fields = (
        "active",
        "keypoints",
        "num_valid",
        "mean_confidence",
        "inference_ms",
        "last_update",
    )
    payload = {name: pose_state[name] for name in fields}
    payload["timestamp"] = time.time()
    return payload
|
||||
|
||||
|
||||
@app.get("/pose/summary")
async def pose_summary():
    """Return a compact posture summary derived from the latest pose."""
    if pose_estimator is None:
        raise HTTPException(status_code=503, detail="Pose estimator not available")

    # The derived-posture dict may be empty before the first detection,
    # hence the .get() defaults.
    posture_info = pose_state["posture"]
    return {
        "active": pose_state["active"],
        "posture": posture_info.get("posture", "unknown"),
        "facing_camera": posture_info.get("facing_camera", False),
        "arms_raised": posture_info.get("arms_raised", False),
        "mean_confidence": pose_state["mean_confidence"],
        "num_valid": pose_state["num_valid"],
        "timestamp": time.time(),
    }
|
||||
|
||||
|
||||
# ============== Face Enrollment API ==============
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user