reid_stuff

This commit is contained in:
Aditya Pulipaka
2026-03-15 19:58:52 -05:00
parent 5c7b26c94a
commit 53999d6023
9 changed files with 884 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
"""Reid pipeline in headless mode (no cv2 display windows).
To view output:
ros2 run rqt_image_view rqt_image_view → select /reid/annotated
rviz2 → add MarkerArray on /reid/track_markers and /keypoint_markers
"""
import os
from launch import LaunchDescription
from launch.actions import ExecuteProcess
def generate_launch_description():
    """Assemble the fully headless tracking + re-ID pipeline.

    Two processes are started under the mmpose conda interpreter:
      * single_person_loc_node — stereo keypoint triangulation
      * reid_node              — KeyRe-ID person re-identification
    Both run with headless:=true, so output is viewed through
    rqt_image_view / rviz2 rather than local cv2 windows.
    """
    interpreter = os.path.expanduser('~/miniconda3/envs/mmpose/bin/python3')
    keyreid_dir = os.path.expanduser('~/KeyRe-ID')

    def _process(module, params):
        # Build `python -m <module> --ros-args -p k:=v ...`, inheriting the
        # caller's environment (conda paths, CUDA vars, ROS setup).
        cmd = [interpreter, '-m', module, '--ros-args']
        for param in params:
            cmd += ['-p', param]
        return ExecuteProcess(cmd=cmd, output='screen', env=dict(os.environ))

    loc_params = [
        'threshold:=0.3',
        'device:=cuda:0',
        'max_residual:=0.10',
        'headless:=true',
    ]
    reid_params = [
        f'keyreID_path:={keyreid_dir}',
        'num_classes:=150',
        'camera_num:=2',
        'device:=cuda:0',
        'seq_len:=4',
        'kp_threshold:=0.3',
        'match_threshold:=0.65',
        'track_dist_px:=120.0',
        'track_timeout:=3.0',
        'headless:=true',
    ]
    return LaunchDescription([
        _process('tracking_re_id.single_person_loc_node', loc_params),
        _process('tracking_re_id.reid_node', reid_params),
    ])

View File

@@ -0,0 +1,70 @@
"""Launch the KeyRe-ID re-identification pipeline alongside the existing
stereo triangulation pipeline.
Nodes started
─────────────
1. single_person_loc_node (unchanged stereo 3-D triangulation; not started
by this launch file — run it separately if 3-D output is needed)
publishes: /keypoint_markers (MarkerArray)
/keypoints_3d (PointCloud2)
2. reid_node (self-contained left-camera MMPose + KeyRe-ID)
publishes: /reid/annotated (Image)
/reid/track_markers (MarkerArray)
The two nodes are independent: reid_node runs its own MMPose instance on
the left camera only and does not depend on single_person_loc_node output.
Run them together to get both 3-D triangulation and persistent person IDs,
or launch reid_node on its own if only re-identification is needed.
Viewing the output
──────────────────
ros2 run rqt_image_view rqt_image_view → /reid/annotated
rviz2 → add MarkerArray /reid/track_markers and /keypoint_markers
"""
import os
from launch import LaunchDescription
from launch.actions import ExecuteProcess
def generate_launch_description():
    """Launch the KeyRe-ID re-identification node with its display window.

    Only reid_node is started here (headless:=false opens a local cv2
    window). The stereo triangulator single_person_loc_node is fully
    independent and can be launched on its own or via the headless launch
    file; the dead, commented-out entry for it has been removed.
    """
    python_exe = os.path.expanduser('~/miniconda3/envs/mmpose/bin/python3')
    keyreID_path = os.path.expanduser('~/KeyRe-ID')
    return LaunchDescription([
        # KeyRe-ID re-identification (self-contained: runs its own MMPose on
        # the left camera only).
        ExecuteProcess(
            cmd=[
                python_exe, '-m', 'tracking_re_id.reid_node',
                '--ros-args',
                '-p', f'keyreID_path:={keyreID_path}',
                '-p', 'num_classes:=150',
                '-p', 'camera_num:=2',
                '-p', 'device:=cuda:0',
                '-p', 'seq_len:=4',
                '-p', 'kp_threshold:=0.3',
                '-p', 'match_threshold:=0.65',
                '-p', 'track_dist_px:=120.0',
                '-p', 'track_timeout:=3.0',
                '-p', 'headless:=false',
            ],
            output='screen',
            env={**os.environ},
        ),
    ])

View File

@@ -13,6 +13,7 @@ setup(
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
(os.path.join('share', package_name, 'weights'), glob('weights/*.pth')),
],
install_requires=['setuptools'],
zip_safe=True,
@@ -30,6 +31,7 @@ setup(
'single_person_loc_node = tracking_re_id.single_person_loc_node:main',
'ground_plane_node = tracking_re_id.ground_plane_node:main',
'overlay_node = tracking_re_id.overlay_node:main',
'reid_node = tracking_re_id.reid_node:main',
],
},
)

View File

@@ -0,0 +1,431 @@
"""
reid_node.py
Self-contained ROS 2 node: MMPose 2-D pose estimation on the left stereo
camera + KeyRe-ID person re-identification.
Pipeline (per frame)
────────────────────
/stereo/left/image_raw
MMPoseInferencer (pose2d='human')
│ per-person keypoints (17, 2) + scores (17,)
bbox extraction → person crop (256×128)
keypoint → 6-channel body-part heatmap (matches training pipeline)
Hungarian matching → per-track Tracklet buffer (deque, maxlen=seq_len)
▼ (once buffer full and person_id is None)
KeyRe-ID inference → feature embedding
Gallery cosine match → assign / register persistent person_id
/reid/annotated (Image) left frame annotated with IDs
/reid/track_markers (MarkerArray) labelled text markers for RViz
Parameters
──────────
weights_path str path to iLIDSVIDbest_CMC.pth (required)
keyreID_path str path to KeyRe-ID source directory
num_classes int training split size (150 for iLIDS-VID split-0)
camera_num int cameras in training set (2 for iLIDS-VID)
device str 'cuda:0' or 'cpu'
seq_len int frames per tracklet clip (default 4)
kp_threshold float min keypoint confidence
match_threshold float cosine-similarity threshold for gallery match
track_dist_px float max centroid distance (px) to keep a track alive
track_timeout float seconds before an unseen track is dropped
headless bool suppress cv2 display window
"""
import os
import sys
import time
import colorsys
from ament_index_python.packages import get_package_share_directory
import cv2
import numpy as np
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image
from visualization_msgs.msg import Marker, MarkerArray
from cv_bridge import CvBridge
import torch
from .reid_utils import (
keypoints_to_heatmap,
keypoints_to_bbox,
clamp_bbox,
transform_keypoints_to_crop,
preprocess_crop,
preprocess_heatmap,
Tracklet,
Gallery,
)
# ── Hungarian matching ────────────────────────────────────────────────────────
def _hungarian_match(tracks: dict, detections: list, max_dist: float):
"""
Associate existing tracks to new per-frame detections by 2-D centroid
distance using the Hungarian algorithm.
Returns:
matched : list of (track_id, det_index)
unmatched_tracks : track_ids with no detection this frame
unmatched_dets : detection indices with no existing track
"""
from scipy.optimize import linear_sum_assignment
track_ids = list(tracks.keys())
n_t, n_d = len(track_ids), len(detections)
if n_t == 0:
return [], [], list(range(n_d))
if n_d == 0:
return [], track_ids, []
cost = np.full((n_t, n_d), max_dist + 1.0)
for i, tid in enumerate(track_ids):
tc = tracks[tid].centroid
if tc is None:
continue
for j, det in enumerate(detections):
dc = det['centroid']
if dc is not None:
cost[i, j] = float(np.hypot(tc[0] - dc[0], tc[1] - dc[1]))
row_ind, col_ind = linear_sum_assignment(cost)
matched, used_t, used_d = [], set(), set()
for ri, ci in zip(row_ind, col_ind):
if cost[ri, ci] <= max_dist:
matched.append((track_ids[ri], ci))
used_t.add(track_ids[ri])
used_d.add(ci)
return (
matched,
[tid for tid in track_ids if tid not in used_t],
[j for j in range(n_d) if j not in used_d],
)
# ── Stable per-person colour ──────────────────────────────────────────────────
def _id_colour(person_id: int) -> tuple:
r, g, b = colorsys.hsv_to_rgb(((person_id * 0.37) % 1.0), 0.9, 1.0)
return (int(b * 255), int(g * 255), int(r * 255))
# ═══════════════════════════════════════════════════════════════════════════════
class ReIDNode(Node):
    """KeyRe-ID re-identification node with integrated MMPose 2-D detection."""

    def __init__(self):
        """Declare parameters, load MMPose + KeyRe-ID, and wire up ROS I/O."""
        super().__init__('reid_node')
        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter('weights_path',
            os.path.join(
                get_package_share_directory('tracking_re_id'),
                'weights', 'iLIDSVIDbest_CMC.pth'))
        self.declare_parameter('keyreID_path',
            os.path.expanduser('~/KeyRe-ID'))
        self.declare_parameter('num_classes', 150)
        self.declare_parameter('camera_num', 2)
        self.declare_parameter('device', 'cuda:0')
        self.declare_parameter('seq_len', 4)
        self.declare_parameter('kp_threshold', 0.3)
        self.declare_parameter('match_threshold', 0.65)
        self.declare_parameter('track_dist_px', 120.0)
        self.declare_parameter('track_timeout', 3.0)
        self.declare_parameter('headless', False)
        weights_path = self.get_parameter('weights_path').value
        keyreID_path = self.get_parameter('keyreID_path').value
        num_classes = self.get_parameter('num_classes').value
        camera_num = self.get_parameter('camera_num').value
        device_str = self.get_parameter('device').value
        self._seq_len = self.get_parameter('seq_len').value
        self._kp_thresh = self.get_parameter('kp_threshold').value
        self._match_thresh = self.get_parameter('match_threshold').value
        self._track_dist = self.get_parameter('track_dist_px').value
        self._track_timeout = self.get_parameter('track_timeout').value
        self._headless = self.get_parameter('headless').value
        # Fall back to CPU when CUDA is unavailable so the node still runs.
        self._device = torch.device(
            device_str if torch.cuda.is_available() else 'cpu')
        self.get_logger().info(f'Using device: {self._device}')
        # ── MMPose ───────────────────────────────────────────────────────────
        from mmpose.apis import MMPoseInferencer  # noqa: PLC0415
        self.get_logger().info(f'Loading MMPose on {device_str}')
        self._inferencer = MMPoseInferencer(pose2d='human', device=device_str)
        self.get_logger().info('MMPose loaded.')
        # ── KeyRe-ID ─────────────────────────────────────────────────────────
        # KeyRe-ID lives in a source checkout, not an installed package, so
        # its directory must be on sys.path before the import.
        if keyreID_path not in sys.path:
            sys.path.insert(0, keyreID_path)
        try:
            from KeyRe_ID_model import KeyRe_ID  # noqa: PLC0415
        except ImportError as exc:
            self.get_logger().fatal(
                f'Cannot import KeyRe_ID_model from {keyreID_path}: {exc}')
            raise
        self.get_logger().info(f'Loading KeyRe-ID weights from {weights_path}')
        self._model = KeyRe_ID(
            num_classes=num_classes,
            camera_num=camera_num,
            pretrainpath=None,
        )
        self._model.load_param(weights_path, load=False)
        self._model.to(self._device)
        self._model.eval()
        self.get_logger().info('KeyRe-ID model ready.')
        # ── ROS infrastructure ───────────────────────────────────────────────
        self._bridge = CvBridge()
        self.create_subscription(
            Image, '/stereo/left/image_raw', self._image_cb, 10)
        self._vis_pub = self.create_publisher(Image, '/reid/annotated', 10)
        self._marker_pub = self.create_publisher(MarkerArray, '/reid/track_markers', 10)
        # ── State ────────────────────────────────────────────────────────────
        self._tracks: dict[int, Tracklet] = {}  # track_id → live Tracklet
        self._next_track_id: int = 0
        self._gallery = Gallery(threshold=self._match_thresh)
        self._display_frame = None  # latest annotated frame for the cv2 window
        if not self._headless:
            # ~30 Hz refresh of the local display window.
            self.create_timer(1.0 / 30.0, self._display_timer_cb)
        self.get_logger().info(
            'reid_node ready. Waiting for /stereo/left/image_raw …')

    # ── MMPose helper ─────────────────────────────────────────────────────────
    def _run_mmpose(self, frame: np.ndarray) -> list:
        """Return list of dicts {keypoints: (17,2), scores: (17,)}."""
        result = next(self._inferencer(frame, show=False, return_datasamples=False))
        people = []
        # `predictions` holds one list per input image; only one image is fed.
        for pred in result.get('predictions', [[]])[0]:
            kps = pred.get('keypoints', [])
            scores = pred.get('keypoint_scores', [])
            if len(kps) > 0:
                people.append({
                    'keypoints': np.array(kps, dtype=np.float32),
                    'scores': np.array(scores, dtype=np.float32),
                })
        return people

    # ── Main image callback ───────────────────────────────────────────────────
    def _image_cb(self, img_msg: Image):
        """Per-frame pipeline: detect → associate → re-ID → prune → publish."""
        now = time.time()
        frame = self._bridge.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        frame_h, frame_w = frame.shape[:2]
        # Detect people and compute centroids / bboxes
        raw_people = self._run_mmpose(frame)
        detections = []
        for person in raw_people:
            kps, scores = person['keypoints'], person['scores']
            bbox = keypoints_to_bbox(kps, scores, threshold=self._kp_thresh)
            if bbox is None:
                continue
            visible = scores > self._kp_thresh
            # Centroid of the visible keypoints only.
            centroid = (float(kps[visible, 0].mean()),
                        float(kps[visible, 1].mean())) if np.any(visible) else None
            if centroid is None:
                continue
            detections.append({
                'keypoints': kps,
                'scores': scores,
                'bbox': bbox,
                'centroid': centroid,
            })
        # Associate detections to existing tracks
        matched, unmatched_tracks, unmatched_dets = _hungarian_match(
            self._tracks, detections, self._track_dist)
        for tid, det_idx in matched:
            self._update_track(
                self._tracks[tid], detections[det_idx],
                frame, frame_w, frame_h, now)
        for det_idx in unmatched_dets:
            tid = self._next_track_id
            self._next_track_id += 1
            self._tracks[tid] = Tracklet(tid, seq_len=self._seq_len)
            self._update_track(
                self._tracks[tid], detections[det_idx],
                frame, frame_w, frame_h, now)
        # Re-ID every track whose clip buffer is full.
        # NOTE(review): this re-runs KeyRe-ID inference each frame once a
        # buffer fills; if a single assignment per track was intended, gate
        # on `track.person_id is None` as well — confirm desired behaviour.
        with torch.no_grad():
            for track in self._tracks.values():
                if track.is_ready():
                    self._run_reid(track)
        # Drop stale tracks (snapshot the ids first so the dict is not
        # mutated while iterating).
        for tid in [tid for tid, t in self._tracks.items()
                    if now - t.last_seen > self._track_timeout]:
            del self._tracks[tid]
        # Publish
        vis = self._build_visualisation(frame.copy())
        out = self._bridge.cv2_to_imgmsg(vis, encoding='bgr8')
        out.header = img_msg.header
        self._vis_pub.publish(out)
        self._publish_markers(img_msg.header.stamp)
        if not self._headless:
            self._display_frame = vis

    # ── Track update ──────────────────────────────────────────────────────────
    def _update_track(self, track: Tracklet, det: dict,
                      frame: np.ndarray, frame_w: int, frame_h: int,
                      timestamp: float):
        """Crop, heatmap and buffer one detection into *track*."""
        x1, y1, x2, y2 = clamp_bbox(*det['bbox'], frame_w, frame_h)
        # Reject degenerate crops that would break resizing / the model.
        if x2 - x1 < 10 or y2 - y1 < 10:
            return
        crop_bgr = frame[y1:y2, x1:x2]
        kp_xyc, crop_w, crop_h = transform_keypoints_to_crop(
            det['keypoints'], det['scores'], x1, y1, x2, y2)
        heatmap_np = keypoints_to_heatmap(
            kp_xyc, crop_w, crop_h, vis_thresh=self._kp_thresh)
        try:
            crop_t = preprocess_crop(crop_bgr)
            heatmap_t = preprocess_heatmap(heatmap_np)
        except Exception as exc:
            # Best-effort: a bad crop only skips this frame for the track.
            self.get_logger().warn(
                f'Preprocess failed for track {track.track_id}: {exc}')
            return
        track.add_frame(crop_t, heatmap_t,
                        centroid=det['centroid'],
                        bbox=(x1, y1, x2, y2),
                        timestamp=timestamp)

    # ── KeyRe-ID inference ────────────────────────────────────────────────────
    def _run_reid(self, track: Tracklet):
        """Embed the track's clip and assign / refresh its persistent id."""
        imgs, hmaps = track.get_model_inputs()
        imgs = imgs.to(self._device)
        hmaps = hmaps.to(self._device)
        # cam_label=0: left stereo camera → iLIDS-VID cam1 index
        feat = self._model(imgs, hmaps, None, cam_label=0)
        feature = feat[0].cpu()
        track.feature = feature
        track.person_id, track.match_sim = self._gallery.match_or_register(feature)

    # ── Visualisation ─────────────────────────────────────────────────────────
    def _build_visualisation(self, frame: np.ndarray) -> np.ndarray:
        """Draw per-track boxes/labels and a global status line onto *frame*."""
        for track in self._tracks.values():
            if track.bbox is None:
                continue
            x1, y1, x2, y2 = track.bbox
            if track.person_id is not None:
                colour = _id_colour(track.person_id)
                label = f'P{track.person_id} ({track.match_sim:.2f})'
            else:
                # Grey box while the clip buffer is still filling.
                colour = (160, 160, 160)
                label = f'T{track.track_id} ({len(track.crops)}/{self._seq_len})'
            cv2.rectangle(frame, (x1, y1), (x2, y2), colour, 2)
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
            cv2.rectangle(frame, (x1, y1 - th - 6), (x1 + tw + 4, y1), colour, -1)
            cv2.putText(frame, label, (x1 + 2, y1 - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.putText(frame,
                    f'Known: {len(self._gallery)} Tracks: {len(self._tracks)}',
                    (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                    (255, 255, 255), 2, cv2.LINE_AA)
        return frame

    # ── RViz markers ─────────────────────────────────────────────────────────
    def _publish_markers(self, stamp):
        """Publish one TEXT_VIEW_FACING label per identified track."""
        ma = MarkerArray()
        # Clear last frame's markers before adding the current set.
        delete = Marker()
        delete.action = Marker.DELETEALL
        delete.header.frame_id = 'left'
        delete.header.stamp = stamp
        ma.markers.append(delete)
        mid = 0
        for track in self._tracks.values():
            if track.centroid is None or track.person_id is None:
                continue
            colour = _id_colour(track.person_id)
            m = Marker()
            m.header.frame_id = 'left'
            m.header.stamp = stamp
            m.ns = 'reid_labels'
            m.id = mid; mid += 1
            m.type = Marker.TEXT_VIEW_FACING
            m.action = Marker.ADD
            # Pixel coordinates scaled by 1/100 — a convenience mapping into
            # RViz units, not a metric position (TODO confirm intended frame).
            m.pose.position.x = float(track.centroid[0]) / 100.0
            m.pose.position.y = float(track.centroid[1]) / 100.0
            m.pose.position.z = 2.0
            m.pose.orientation.w = 1.0
            m.scale.z = 0.15
            # Stored colour is BGR; markers want RGB floats in [0, 1].
            m.color.r = colour[2] / 255.0
            m.color.g = colour[1] / 255.0
            m.color.b = colour[0] / 255.0
            m.color.a = 1.0
            m.text = f'P{track.person_id}'
            m.lifetime.nanosec = 500_000_000  # auto-expire after 0.5 s
            ma.markers.append(m)
        self._marker_pub.publish(ma)

    # ── Display timer ─────────────────────────────────────────────────────────
    def _display_timer_cb(self):
        """Show the latest annotated frame; pressing 'q' shuts the node down."""
        if self._display_frame is not None:
            cv2.imshow('KeyRe-ID', self._display_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                self.get_logger().info('Quit requested.')
                self.destroy_node()
                rclpy.shutdown()
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
    """CLI entry point: spin the node until Ctrl-C, then clean up."""
    rclpy.init(args=args)
    node = ReIDNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # Close the cv2 window only when one could have been created.
        if not node._headless:
            cv2.destroyAllWindows()
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,326 @@
"""
reid_utils.py
Shared utilities for the KeyRe-ID ROS 2 node:
- keypoints_to_heatmap() -- matches the generate_heatmaps_ilids.py training pipeline
- preprocess_crop() -- matches val_transforms from heatmap_loader.py
- preprocess_heatmap() -- matches CustomHeatmapTransform from heatmap_loader.py
- keypoints_to_bbox() -- derive a bounding box from visible keypoints
- transform_keypoints_to_crop() -- remap full-frame kps to crop-relative coords
- Tracklet -- per-track frame buffer
- Gallery -- known-person feature store with cosine re-ID
"""
from collections import deque
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from scipy.ndimage import gaussian_filter
from torchvision.transforms import InterpolationMode
# ── COCO keypoint indices ─────────────────────────────────────────────────────
NOSE, LEYE, REYE, LEAR, REAR = 0, 1, 2, 3, 4
LS, RS, LE, RE, LW, RW = 5, 6, 7, 8, 9, 10
LH, RH, LK, RK, LA, RA = 11, 12, 13, 14, 15, 16
# ── Heatmap generation (identical to generate_heatmaps_ilids.py) ──────────────
def keypoints_to_heatmap(kp_array, img_w, img_h, vis_thresh=0.1):
    """
    Convert 17 COCO keypoints to 6-channel body-part heatmaps.
    This exactly reproduces the pipeline used to generate the training data so
    the heatmap distribution seen at inference matches what the model trained on.
    Args:
        kp_array : np.ndarray (17, 3) [x, y, confidence]
        img_w : int width of the source image (crop)
        img_h : int height of the source image (crop)
        vis_thresh : float confidence threshold
    Returns:
        np.ndarray (6, img_h, img_w) channels: head, torso, l-arm, r-arm, l-leg, r-leg
    """
    heatmaps = np.zeros((6, img_h, img_w), dtype=np.float32)
    # Drawing scale is tied to crop width so parts look consistent across
    # differently sized crops.
    blur_sigma = max(1.0, img_w / 16.0)
    line_thickness = max(1, int(img_w / 8))

    def vis(i):
        # Keypoint i passed the confidence threshold.
        return kp_array[i, 2] > vis_thresh

    def pt(i):
        return (int(kp_array[i, 0]), int(kp_array[i, 1]))

    def in_bounds(i):
        x, y = int(kp_array[i, 0]), int(kp_array[i, 1])
        return 0 <= x < img_w and 0 <= y < img_h

    def usable(i):
        return vis(i) and in_bounds(i)

    # Channel 0: Head — convex hull of the facial points when ≥3 are usable,
    # otherwise filled discs at whatever points remain.
    head_pts = [pt(i) for i in [NOSE, LEYE, REYE, LEAR, REAR] if usable(i)]
    if len(head_pts) >= 3:
        hull = cv2.convexHull(np.array(head_pts, dtype=np.int32))
        cv2.fillConvexPoly(heatmaps[0], hull, 1.0)
    elif head_pts:
        for p in head_pts:
            cv2.circle(heatmaps[0], p, line_thickness, 1.0, -1)
    # Channel 1: Torso — quad / hull / line / disc, degrading gracefully with
    # the number of usable shoulder+hip corners.
    torso_order = [LS, RS, RH, LH]
    torso_usable = [i for i in torso_order if usable(i)]
    if len(torso_usable) == 4:
        pts = np.array([pt(i) for i in torso_order], dtype=np.int32)
        cv2.fillPoly(heatmaps[1], [pts], 1.0)
    elif len(torso_usable) >= 3:
        hull = cv2.convexHull(np.array([pt(i) for i in torso_usable], dtype=np.int32))
        cv2.fillConvexPoly(heatmaps[1], hull, 1.0)
    elif len(torso_usable) == 2:
        pts_list = [pt(i) for i in torso_usable]
        cv2.line(heatmaps[1], pts_list[0], pts_list[1], 1.0, line_thickness)
    elif len(torso_usable) == 1:
        cv2.circle(heatmaps[1], pt(torso_usable[0]), line_thickness, 1.0, -1)
    # Channels 2-5: Limbs — thick segment from the mid joint to the end
    # joint, falling back to a disc when only one of the two is usable.
    limbs = [
        (2, LE, LW),  # left arm
        (3, RE, RW),  # right arm
        (4, LK, LA),  # left leg
        (5, RK, RA),  # right leg
    ]
    for ch, j1, j2 in limbs:
        u1, u2 = usable(j1), usable(j2)
        if u1 and u2:
            cv2.line(heatmaps[ch], pt(j1), pt(j2), 1.0, line_thickness)
        elif u1:
            cv2.circle(heatmaps[ch], pt(j1), line_thickness, 1.0, -1)
        elif u2:
            cv2.circle(heatmaps[ch], pt(j2), line_thickness, 1.0, -1)
    # Gaussian blur for soft edges
    for i in range(6):
        if heatmaps[i].max() > 0:
            heatmaps[i] = gaussian_filter(heatmaps[i], sigma=blur_sigma)
    # L/R sanity check based on torso orientation: if the elbows (or knees)
    # disagree with the shoulders about which side is image-left, the
    # detector likely flipped the limb labels, so swap those channel pairs.
    # Only applied when the shoulders are horizontally separated enough for
    # the x-order to be meaningful (i.e. not a profile view).
    if usable(LS) and usable(RS):
        shoulder_gap = abs(kp_array[LS, 0] - kp_array[RS, 0])
        if shoulder_gap > img_w * 0.05:
            facing_camera = kp_array[LS, 0] > kp_array[RS, 0]
            if usable(LE) and usable(RE):
                if facing_camera != (kp_array[LE, 0] > kp_array[RE, 0]):
                    heatmaps[2], heatmaps[3] = heatmaps[3].copy(), heatmaps[2].copy()
            if usable(LK) and usable(RK):
                if facing_camera != (kp_array[LK, 0] > kp_array[RK, 0]):
                    heatmaps[4], heatmaps[5] = heatmaps[5].copy(), heatmaps[4].copy()
    # Normalize each channel to [0, 1]
    for i in range(6):
        mx = heatmaps[i].max()
        if mx > 0:
            heatmaps[i] /= mx
    return heatmaps
# ── Bounding box utilities ────────────────────────────────────────────────────
def keypoints_to_bbox(keypoints, scores, threshold=0.3, margin=0.15):
    """
    Bounding box (x1, y1, x2, y2) around the keypoints scoring above
    *threshold*, padded by *margin* — extra above the head, less below the
    feet. Returns None when no keypoint is visible.
    """
    mask = scores > threshold
    if not np.any(mask):
        return None
    xs = keypoints[mask, 0]
    ys = keypoints[mask, 1]
    left, top = float(xs.min()), float(ys.min())
    right, bottom = float(xs.max()), float(ys.max())
    width, height = right - left, bottom - top
    return (
        left - width * margin,
        top - height * margin * 1.5,    # extra headroom
        right + width * margin,
        bottom + height * margin * 0.5,
    )
def clamp_bbox(x1, y1, x2, y2, frame_w, frame_h):
    """Truncate the box coords to int and clip them to the image extent."""
    return (
        max(int(x1), 0),
        max(int(y1), 0),
        min(int(x2), frame_w),
        min(int(y2), frame_h),
    )
def transform_keypoints_to_crop(keypoints, scores, x1, y1, x2, y2):
    """
    Shift full-frame keypoint coordinates into the crop's local frame so the
    heatmap is generated in the same coordinate space as the training data
    (where MMPose ran on individual person crops, not the full frame).

    Returns:
        kp_xyc : np.ndarray (17, 3) crop-relative [x, y, score]
        crop_w : int
        crop_h : int
    """
    offset = np.array([x1, y1], dtype=np.float32)
    local = keypoints.astype(np.float32) - offset
    kp_xyc = np.concatenate(
        [local, scores.astype(np.float32)[:, None]], axis=1)
    return kp_xyc, int(x2 - x1), int(y2 - y1)
# ── Preprocessing (must match heatmap_loader.py transforms) ──────────────────
# Image pipeline identical to val_transforms in heatmap_loader.py:
# resize to 256×128, convert to tensor, normalize to roughly [-1, 1].
_IMG_TRANSFORM = T.Compose([
    T.Resize([256, 128], interpolation=InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Same 0.5/0.5 normalization scheme, applied to all 6 body-part channels.
_HM_NORMALIZE = T.Normalize(mean=[0.5] * 6, std=[0.5] * 6)
def preprocess_crop(crop_bgr):
    """
    Turn a BGR person crop (numpy HxWx3) into a model-ready tensor.
    Pipeline matches val_transforms in heatmap_loader.py:
        Resize(256, 128) → ToTensor → Normalize(0.5, 0.5)
    Returns:
        torch.Tensor (3, 256, 128)
    """
    from PIL import Image as PILImage
    as_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    return _IMG_TRANSFORM(PILImage.fromarray(as_rgb))
def preprocess_heatmap(heatmap_np):
    """
    Turn a raw (6, H, W) numpy heatmap into a model-ready tensor.
    Pipeline matches CustomHeatmapTransform([256, 128]) in heatmap_loader.py:
        bilinear resize → per-channel min-max scale → Normalize(0.5, 0.5)
    Returns:
        torch.Tensor (6, 256, 128)
    """
    hm = torch.tensor(heatmap_np, dtype=torch.float32)
    # Add/remove a batch axis around interpolate, which expects (N, C, H, W).
    hm = F.interpolate(hm[None], size=(256, 128), mode='bilinear',
                       align_corners=False)[0]
    # Per-channel min-max scaling to [0, 1] (epsilon guards flat channels).
    lo = hm.amin(dim=(1, 2), keepdim=True)
    hi = hm.amax(dim=(1, 2), keepdim=True)
    hm = (hm - lo) / (hi - lo + 1e-6)
    return _HM_NORMALIZE(hm)
# ── Tracklet ──────────────────────────────────────────────────────────────────
class Tracklet:
    """
    Rolling clip of preprocessed (crop, heatmap) tensors for a single tracked
    person, plus the latest geometry and the most recent re-ID assignment.
    """

    def __init__(self, track_id: int, seq_len: int = 4):
        self.track_id = track_id
        self.seq_len = seq_len
        # Rolling buffers — once full, each new frame evicts the oldest.
        self.crops: deque = deque(maxlen=seq_len)      # Tensor (3, 256, 128)
        self.heatmaps: deque = deque(maxlen=seq_len)   # Tensor (6, 256, 128)
        self.centroid = None   # (cx, cy) pixel coords in the latest frame
        self.bbox = None       # (x1, y1, x2, y2) in the latest frame
        self.last_seen: float = 0.0  # time.time() of the latest update
        # Re-ID assignment (filled in after gallery matching)
        self.person_id: int | None = None
        self.match_sim: float = 0.0
        # Latest computed feature embedding (feat_dim,)
        self.feature: torch.Tensor | None = None

    def add_frame(self, crop_t: torch.Tensor, heatmap_t: torch.Tensor,
                  centroid, bbox, timestamp: float):
        """Push one preprocessed frame and refresh the track geometry."""
        self.last_seen = timestamp
        self.centroid = centroid
        self.bbox = bbox
        self.crops.append(crop_t)
        self.heatmaps.append(heatmap_t)

    def is_ready(self) -> bool:
        """True once the rolling buffer has been filled at least once."""
        return self.seq_len <= len(self.crops)

    def get_model_inputs(self):
        """
        Returns (imgs, hmaps) tensors shaped for KeyRe_ID:
            imgs  : (1, seq_len, 3, 256, 128)
            hmaps : (1, seq_len, 6, 256, 128)
        """
        clip = torch.stack(tuple(self.crops))
        parts = torch.stack(tuple(self.heatmaps))
        return clip[None], parts[None]
# ── Gallery ───────────────────────────────────────────────────────────────────
class Gallery:
    """
    Known-person store: one feature embedding per person id, queried by
    cosine similarity.
    A query whose best similarity falls below `threshold` is registered as a
    brand-new person; a matched entry is refreshed with an exponential moving
    average of its embedding.
    """

    def __init__(self, threshold: float = 0.65, ema_alpha: float = 0.9):
        self.threshold = threshold
        self.ema_alpha = ema_alpha
        self._embeddings: dict[int, torch.Tensor] = {}  # person_id → (feat_dim,)
        self._next_id: int = 1

    def match_or_register(self, feature: torch.Tensor) -> tuple[int, float]:
        """
        Compare *feature* against every gallery entry.
        Returns:
            (person_id, cosine_similarity)
        A similarity below the threshold triggers registration of a new id.
        """
        query = F.normalize(feature.unsqueeze(0), dim=1)  # (1, D)
        best_pid = None
        best_sim = -1.0
        for pid, emb in self._embeddings.items():
            candidate = F.normalize(emb.unsqueeze(0), dim=1)
            sim = float((query @ candidate.T).item())
            if sim > best_sim:
                best_pid, best_sim = pid, sim
        if best_pid is not None and best_sim >= self.threshold:
            # Refresh the matched embedding with an EMA update.
            previous = self._embeddings[best_pid]
            self._embeddings[best_pid] = (
                self.ema_alpha * previous
                + (1.0 - self.ema_alpha) * feature.detach()
            )
            return best_pid, best_sim
        # No acceptable match → register a new person.
        new_pid = self._next_id
        self._next_id += 1
        self._embeddings[new_pid] = feature.detach().clone()
        return new_pid, 1.0 if best_pid is None else best_sim

    def __len__(self) -> int:
        return len(self._embeddings)

Binary file not shown.