package and pre-code

This commit is contained in:
Aditya Pulipaka
2026-03-04 13:54:17 -06:00
parent 21958eaa2c
commit 7178ec89a4
12 changed files with 672 additions and 0 deletions

View File

View File

@@ -0,0 +1,213 @@
"""
ROS 2 node that subscribes to rectified stereo images, runs MMPose
keypoint detection, and displays the results in OpenCV windows.
"""
import cv2
import numpy as np
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
from mmpose.apis import MMPoseInferencer
# ── COCO-17 drawing helpers ─────────────────────────────────────────────
# Joint names in COCO-17 order; the position of a name in this list is the
# index used by SKELETON below.
KEYPOINT_NAMES = [
    'nose',
    'left_eye', 'right_eye',
    'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder',
    'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist',
    'left_hip', 'right_hip',
    'left_knee', 'right_knee',
    'left_ankle', 'right_ankle',
]

# Name -> index lookup, used only to build SKELETON.
_KEYPOINT_INDEX = {name: index for index, name in enumerate(KEYPOINT_NAMES)}

# Bones expressed by joint name so the connectivity is readable at a glance.
_BONES = (
    ('nose', 'left_eye'), ('nose', 'right_eye'),                      # nose → eyes
    ('left_eye', 'left_ear'), ('right_eye', 'right_ear'),             # eyes → ears
    ('left_shoulder', 'right_shoulder'),                              # shoulders
    ('left_shoulder', 'left_elbow'), ('left_elbow', 'left_wrist'),    # left arm
    ('right_shoulder', 'right_elbow'), ('right_elbow', 'right_wrist'),  # right arm
    ('left_shoulder', 'left_hip'), ('right_shoulder', 'right_hip'),   # shoulders → hips
    ('left_hip', 'right_hip'),                                        # hips
    ('left_hip', 'left_knee'), ('left_knee', 'left_ankle'),           # left leg
    ('right_hip', 'right_knee'), ('right_knee', 'right_ankle'),       # right leg
)

# Pairs of keypoint indices that are joined by a line when drawing.
SKELETON = [(_KEYPOINT_INDEX[a], _KEYPOINT_INDEX[b]) for a, b in _BONES]
def draw_keypoints(frame, keypoints, keypoint_scores, threshold=0.3):
    """Overlay skeleton lines and joint markers on *frame* (modified in place).

    Args:
        frame: image to draw on; only ``frame.shape[:2]`` (height, width) is
            read for bounds checks.
        keypoints: one sequence of ``(x, y, ...)`` points per person.
        keypoint_scores: one sequence of scores per person, parallel to
            *keypoints*.
        threshold: a joint is drawn only when its score is strictly greater
            than this value; a bone needs both of its joints to qualify.

    Returns:
        The same *frame* object that was passed in.
    """
    height, width = frame.shape[:2]

    def as_pixel(point):
        # Truncate the first two coordinates to integer pixel positions.
        return (int(point[0]), int(point[1]))

    def in_frame(pixel):
        return 0 <= pixel[0] < width and 0 <= pixel[1] < height

    for joints, confidences in zip(keypoints, keypoint_scores):
        # Bones: skipped unless both ends are confident and inside the image.
        for first, second in SKELETON:
            if not (confidences[first] > threshold):
                continue
            if not (confidences[second] > threshold):
                continue
            start = as_pixel(joints[first])
            end = as_pixel(joints[second])
            if in_frame(start) and in_frame(end):
                cv2.line(frame, start, end, (0, 255, 0), 2)

        # Joints: filled dot (green above 0.7, otherwise yellow-ish) plus a
        # thin white outline.
        for joint, confidence in zip(joints, confidences):
            if not (confidence > threshold):
                continue
            centre = as_pixel(joint)
            if not in_frame(centre):
                continue
            if confidence > 0.7:
                fill = (0, 255, 0)
            else:
                fill = (0, 200, 200)
            cv2.circle(frame, centre, 4, fill, -1)
            cv2.circle(frame, centre, 5, (255, 255, 255), 1)

    return frame
# ── ROS 2 Node ───────────────────────────────────────────────────────────
class KeypointNode(Node):
    """Subscribe to stereo rectified images, run MMPose, and display.

    Every image received on ``/stereo/left/image_rect`` and
    ``/stereo/right/image_rect`` is converted to an 8-bit BGR frame, run
    through one shared ``MMPoseInferencer`` and annotated with the detected
    skeletons.  A ~30 Hz timer shows the latest annotated frame of each
    camera in an OpenCV window and polls the keyboard; pressing ``q`` shuts
    the rclpy context down so that ``rclpy.spin`` returns.

    ROS parameters:
        threshold (default ``0.3``): minimum keypoint score that is drawn.
        device (default ``'cuda:0'``): device string passed to MMPose.
    """

    def __init__(self):
        super().__init__('keypoint_node')

        # ── Parameters ──────────────────────────────────────────────────
        self.declare_parameter('threshold', 0.3)
        self.declare_parameter('device', 'cuda:0')
        self._threshold = self.get_parameter('threshold').value
        device = self.get_parameter('device').value

        # ── MMPose (single shared instance) ─────────────────────────────
        self.get_logger().info(
            f'Loading MMPose model on {device} (this may take a moment)…'
        )
        self._inferencer = MMPoseInferencer(
            pose2d='human',
            device=device,
        )
        self.get_logger().info('MMPose model loaded.')

        # ── cv_bridge ───────────────────────────────────────────────────
        self._bridge = CvBridge()

        # Keys of the one-shot diagnostics that were already emitted
        # (replaces attributes created on the fly with setattr()).
        self._logged_once: set[tuple[str, str]] = set()

        # Latest frames (written by subscribers, read by display timer)
        self._left_frame: np.ndarray | None = None
        self._right_frame: np.ndarray | None = None

        # ── Subscribers (independent — no time-sync needed) ─────────────
        self.create_subscription(
            Image,
            '/stereo/left/image_rect',
            self._left_cb,
            10,
        )
        self.create_subscription(
            Image,
            '/stereo/right/image_rect',
            self._right_cb,
            10,
        )

        # ── Timer for cv2.waitKey (keeps GUI responsive) ────────────────
        # ~30 Hz is plenty for display refresh
        self.create_timer(1.0 / 30.0, self._display_timer_cb)

        self.get_logger().info(
            'Subscribed to /stereo/left/image_rect and '
            '/stereo/right/image_rect — waiting for images…'
        )

    # ── Helpers ─────────────────────────────────────────────────────────

    def _first_time(self, kind: str, label: str) -> bool:
        """Return True exactly once for each ``(kind, label)`` pair."""
        key = (kind, label)
        if key in self._logged_once:
            return False
        self._logged_once.add(key)
        return True

    @staticmethod
    def _to_bgr8(frame: np.ndarray, encoding: str) -> np.ndarray:
        """Convert a passthrough image to 3-channel BGR for MMPose/drawing.

        Args:
            frame: array returned by cv_bridge with ``passthrough`` encoding.
            encoding: the ``encoding`` field of the source message; used to
                tell RGB(A) channel order apart from BGR(A).

        Returns:
            A BGR image.  The input array itself is returned when it already
            is 3-channel and not RGB-ordered.
        """
        # 16-bit data (e.g. mono16) is scaled from the full uint16 range down
        # to 8 bit; imshow and the drawing colours assume 0-255 values.
        if frame.dtype == np.uint16:
            frame = cv2.convertScaleAbs(frame, alpha=255.0 / 65535.0)

        if frame.ndim == 2 or frame.shape[2] == 1:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

        # 'rgb8', 'rgba8', 'rgb16', … carry red first; everything else is
        # treated as BGR-ordered, as before.
        rgb_order = str(encoding).lower().startswith('rgb')
        if frame.shape[2] == 4:
            code = cv2.COLOR_RGBA2BGR if rgb_order else cv2.COLOR_BGRA2BGR
            return cv2.cvtColor(frame, code)
        if rgb_order:
            return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame

    # ── Subscriber callbacks ────────────────────────────────────────────

    def _process_frame(self, msg: Image, label: str) -> np.ndarray:
        """Convert ROS Image -> OpenCV, run MMPose, draw keypoints.

        Args:
            msg: incoming image message.
            label: camera tag (``'left'`` or ``'right'``) used in log lines.

        Returns:
            The BGR frame with skeletons and a subject counter drawn on it.
        """
        # Log image metadata once to help diagnose blank-frame issues
        if self._first_time('meta', label):
            self.get_logger().info(
                f'[{label}] First frame: encoding={msg.encoding} '
                f'size={msg.width}x{msg.height} step={msg.step}'
            )

        # Use passthrough first so we can inspect the raw data
        frame = self._bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
        if self._first_time('pixels', label):
            self.get_logger().info(
                f'[{label}] Frame dtype={frame.dtype} shape={frame.shape} '
                f'min={frame.min()} max={frame.max()}'
            )

        frame = self._to_bgr8(frame, msg.encoding)
        # Drawing happens in place, so the array must be writable.
        # NOTE(review): presumably only relevant when cv_bridge hands back a
        # read-only view of the message buffer — confirm for your distro.
        if not frame.flags.writeable:
            frame = frame.copy()

        # The inferencer call yields results lazily; take the single result
        # for this frame.
        result = next(self._inferencer(
            frame,
            show=False,
            return_datasamples=False,
        ))

        # Missing key and empty list are both treated as "nobody detected"
        # instead of raising IndexError.
        batches = result.get('predictions') or [[]]
        predictions = batches[0]

        all_kps, all_scores = [], []
        for pred in predictions:
            kps = pred.get('keypoints', [])
            scores = pred.get('keypoint_scores', [])
            # Skip empty or inconsistent entries; the drawing code indexes
            # both sequences with the same joint indices.
            if len(kps) == 0 or len(scores) != len(kps):
                continue
            all_kps.append(np.asarray(kps))
            all_scores.append(np.asarray(scores))

        if all_kps:
            draw_keypoints(frame, all_kps, all_scores, self._threshold)

        cv2.putText(
            frame,
            f'Subjects: {len(all_kps)}',
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 255, 255),
            2,
        )
        return frame

    def _left_cb(self, msg: Image) -> None:
        """Store the annotated frame of the left camera."""
        self._left_frame = self._process_frame(msg, 'left')

    def _right_cb(self, msg: Image) -> None:
        """Store the annotated frame of the right camera."""
        self._right_frame = self._process_frame(msg, 'right')

    # ── Display timer ───────────────────────────────────────────────────

    def _display_timer_cb(self) -> None:
        """Refresh the OpenCV windows and handle the ``q`` quit key."""
        if self._left_frame is not None:
            cv2.imshow('Left - MMPose', self._left_frame)
        if self._right_frame is not None:
            cv2.imshow('Right - MMPose', self._right_frame)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            self.get_logger().info('Quit requested — shutting down.')
            # Only shut the context down here.  Destroying the node from
            # inside one of its own callbacks tears down the timer that is
            # currently executing; whoever spins the node owns its
            # destruction once spin() has returned.
            rclpy.try_shutdown()
# ── Entry point ──────────────────────────────────────────────────────────
def main(args=None):
    """Initialise rclpy, spin a ``KeypointNode`` and clean up afterwards.

    Args:
        args: command-line arguments forwarded to ``rclpy.init``.

    Ctrl-C (``KeyboardInterrupt``) and an externally triggered context
    shutdown (``ExternalShutdownException``) both end the program quietly.
    Clean-up runs in every case, including when constructing the node fails.
    """
    # Local import: only this function needs the name.
    from rclpy.executors import ExternalShutdownException

    rclpy.init(args=args)
    node = None
    try:
        # Construct inside the try block so a failure while building the node
        # (e.g. model loading) still reaches the shutdown below.
        node = KeypointNode()
        rclpy.spin(node)
    except (KeyboardInterrupt, ExternalShutdownException):
        pass
    finally:
        cv2.destroyAllWindows()
        if node is not None:
            node.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()

22
keypoint_pose/package.xml Normal file
View File

@@ -0,0 +1,22 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>keypoint_pose</name>
  <version>0.0.0</version>
  <description>MMPose keypoint detection on stereo rectified images</description>
  <maintainer email="pulipakaa24@outlook.com">sentry</maintainer>
  <license>TODO: License declaration</license>

  <!-- ROS packages used by the node -->
  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>

  <!-- System Python modules the node imports directly (cv2, numpy) -->
  <depend>python3-opencv</depend>
  <depend>python3-numpy</depend>

  <!-- NOTE(review): the node also imports mmpose; no rosdep key is declared
       for it here, so it presumably has to be installed separately (e.g.
       with pip) - confirm and document the required version. -->

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

4
keypoint_pose/setup.cfg Normal file
View File

@@ -0,0 +1,4 @@
# Both sections point at the same lib/keypoint_pose directory under the
# install prefix ($base), so the generated console scripts land in one place.
# NOTE(review): presumably this is where `ros2 run keypoint_pose ...` looks
# for executables - confirm for your ROS 2 distro.
[develop]
script_dir=$base/lib/keypoint_pose
[install]
install_scripts=$base/lib/keypoint_pose

30
keypoint_pose/setup.py Normal file
View File

@@ -0,0 +1,30 @@
from setuptools import find_packages, setup
package_name = 'keypoint_pose'

# Non-Python files copied into the install tree: the marker file placed under
# the ament resource index path, and the package manifest.
_data_files = [
    (
        'share/ament_index/resource_index/packages',
        ['resource/' + package_name],
    ),
    ('share/' + package_name, ['package.xml']),
]

# Executables generated at install time, as "<command> = <module>:<function>".
_console_scripts = [
    'keypoint_node = keypoint_pose.keypoint_node:main',
]

# Extra requirements that are only needed for running the tests.
_test_requirements = ['pytest']

setup(
    name=package_name,
    version='0.0.0',
    description='MMPose keypoint detection on stereo rectified images',
    license='TODO: License declaration',
    maintainer='sentry',
    maintainer_email='pulipakaa24@outlook.com',
    packages=find_packages(exclude=['test']),
    data_files=_data_files,
    install_requires=['setuptools'],
    extras_require={'test': _test_requirements},
    entry_points={'console_scripts': _console_scripts},
    zip_safe=True,
)

View File

@@ -0,0 +1,25 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_copyright.main import main
import pytest
# Remove the `skip` decorator once the source file(s) have a copyright header
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    """Run the ament_copyright check on the paths ``.`` and ``test``."""
    exit_code = main(argv=['.', 'test'])
    assert exit_code == 0, 'Found errors'

View File

@@ -0,0 +1,25 @@
# Copyright 2017 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    """Run ament_flake8 and report every collected message on failure."""
    exit_code, violations = main_with_errors(argv=[])
    summary = 'Found %d code style errors / warnings:\n' % len(violations)
    assert exit_code == 0, summary + '\n'.join(violations)

View File

@@ -0,0 +1,23 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_pep257.main import main
import pytest
@pytest.mark.linter
@pytest.mark.pep257
def test_pep257():
    """Run the ament_pep257 docstring check on the paths ``.`` and ``test``."""
    exit_code = main(argv=['.', 'test'])
    assert exit_code == 0, 'Found code style errors / warnings'