Compare commits

11 Commits: 53999d6023 ... main

| Author | SHA1 | Date |
|---|---|---|
| | f02926f421 | |
| | 66dd16ddc7 | |
| | 378512b8a4 | |
| | 49435bcfa2 | |
| | 59b4ef2415 | |
| | 1389959e96 | |
| | b000850d68 | |
| | a97e8d54bb | |
| | 4d15d927ad | |
| | 5811b8acbf | |
| | 0e9d5740d1 | |
2 .gitignore vendored
@@ -1,3 +1,5 @@
**/__pycache__/

*.pyc

.DS_Store
26 calibs/1_25235293_right.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: stereo/right
camera_matrix:
  rows: 3
  cols: 3
  data: [1269.74474, 0. , 760.15551,
         0. , 1267.45607, 522.31581,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.381294, 0.153290, 0.000474, -0.000930, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99938087, 0.00859203, 0.03411828,
         -0.00801615, 0.9998237 , -0.01697993,
         -0.03425815, 0.01669592, 0.99927355]
projection_matrix:
  rows: 3
  cols: 4
  data: [1216.34701, 0. , 740.37407, -556.04457,
         0. , 1216.34701, 531.57756, 0. ,
         0. , 0. , 1. , 0. ]
26 calibs/2_25282106_left.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: stereo/left
camera_matrix:
  rows: 3
  cols: 3
  data: [1286.66791, 0. , 745.16448,
         0. , 1282.77949, 544.48338,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.367441, 0.127306, -0.000022, 0.001156, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99996745, -0.0080634 , -0.00026715,
          0.00806676, 0.99982566, 0.01683959,
          0.00013132, -0.0168412 , 0.99985817]
projection_matrix:
  rows: 3
  cols: 4
  data: [1216.34701, 0. , 740.37407, 0. ,
         0. , 1216.34701, 531.57756, 0. ,
         0. , 0. , 1. , 0. ]
26 calibs/3_25462560_right.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: narrow_stereo/right
camera_matrix:
  rows: 3
  cols: 3
  data: [1290.47011, 0. , 739.78524,
         0. , 1289.71737, 567.53111,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.381983, 0.159777, 0.000497, -0.000491, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99462851, 0.01416803, -0.1025348 ,
         -0.01549098, 0.99980658, -0.01211759,
          0.10234328, 0.01364086, 0.99465561]
projection_matrix:
  rows: 3
  cols: 4
  data: [1237.34409, 0. , 940.91473, -575.62908,
         0. , 1237.34409, 584.32687, 0. ,
         0. , 0. , 1. , 0. ]
26 calibs/4_25502289_left.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: narrow_stereo/left
camera_matrix:
  rows: 3
  cols: 3
  data: [1279.12874, 0. , 703.5194 ,
         0. , 1276.40173, 578.67346,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.385976, 0.161647, -0.003145, 0.002069, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99248595, -0.02289772, -0.12019708,
          0.02444886, 0.99963556, 0.01144597,
          0.11989119, -0.01429864, 0.99268406]
projection_matrix:
  rows: 3
  cols: 4
  data: [1237.34409, 0. , 940.91473, 0. ,
         0. , 1237.34409, 584.32687, 0. ,
         0. , 0. , 1. , 0. ]
26 calibs/5_25503480_right.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: narrow_stereo/right
camera_matrix:
  rows: 3
  cols: 3
  data: [1140.50437, 0. , 729.33819,
         0. , 1138.09869, 573.8711 ,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.351645, 0.119879, -0.000250, -0.001973, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99879507, 0.01187418, -0.04761745,
         -0.0123277 , 0.9998813 , -0.00924195,
          0.04750206, 0.00981782, 0.99882289]
projection_matrix:
  rows: 3
  cols: 4
  data: [1079.14113, 0. , 823.44205, -493.59296,
         0. , 1079.14113, 588.52183, 0. ,
         0. , 0. , 1. , 0. ]
26 calibs/6_25503478_left.yaml Normal file
@@ -0,0 +1,26 @@
image_width: 1440
image_height: 1080
camera_name: narrow_stereo/left
camera_matrix:
  rows: 3
  cols: 3
  data: [1135.37523, 0. , 726.25468,
         0. , 1135.35182, 577.53868,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.344777, 0.117188, -0.002895, 0.000324, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99840959, -0.02412373, -0.05095416,
          0.02460856, 0.99965746, 0.0089091 ,
          0.05072179, -0.01014884, 0.99866125]
projection_matrix:
  rows: 3
  cols: 4
  data: [1079.14113, 0. , 823.44205, 0. ,
         0. , 1079.14113, 588.52183, 0. ,
         0. , 0. , 1. , 0. ]
BIN demos/top-down v1.webm Normal file
Binary file not shown.
BIN demos/top-down v2.webm Normal file
Binary file not shown.
75 readme.md
@@ -1 +1,74 @@
Yo
# GDC_ATRIUM

Repository for robust, adaptive 3D person tracking and reidentification in busy areas. Please contact adipu@utexas.edu with questions. Property of AMRL - UT Austin.

## Environment Setup

**Your environment MUST have MMPOSE and the tracking_re_id ROS2 package properly installed.** Follow the instructions below to set it up.

```bash
# wherever you want your workspace root to be (anywhere upstream of tracking_re_id):
colcon build --symlink-install
source install/setup.bash

# 3.10 is the latest Python version supported by MMPOSE
conda create -n mmpose python=3.10 -y
conda activate mmpose

# find your max. CUDA version
nvidia-smi

# We need torch 2.1.x for mmcv 2.1.0, so here's an example using CUDA 12.1
# (this is what was used on the sentry laptop)
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121

# mmpose dependencies
pip install -U openmim
mim install mmengine
mim install "mmcv==2.1.0"
mim install "mmdet==3.2.0"
mim install "mmpose==1.2.0"
pip install "numpy<2"
pip install "opencv-python==4.11.0.86"
```
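As a quick sanity check (a minimal sketch, not part of the repository), the pinned versions can be confirmed from inside the activated environment:

```python
# Hypothetical sanity check -- run inside the activated `mmpose` conda env.
import torch
import mmcv

print(torch.__version__)          # expect a 2.1.x build
print(torch.cuda.is_available())  # expect True on a CUDA machine
print(mmcv.__version__)           # expect 2.1.0
```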
**This installation is referenced in each launchfile by finding the location of your conda base environment, then finding an environment named "mmpose" - if you use a different name, you must change each launch file.**

## Camera Driver / Calibration

Many nodes in the ```tracking_re_id``` package rely on the ```/stereo/left/camera_info``` and ```/stereo/right/camera_info``` topics publishing calibration data for the cameras. This requires calibration to be performed, replacing ```left.yaml``` and ```right.yaml``` in ```tracking_re_id/calibration```.

Visit https://docs.ros.org/en/jazzy/p/camera_calibration/doc/tutorial_stereo.html for instructions on running the calibration. The generated left and right yaml files always had ```camera_name=narrow_stereo/left``` and ```camera_name=narrow_stereo/right```, which have to be changed to ```stereo/left``` and ```stereo/right```, respectively.
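For reference, here is a minimal sketch (not part of the package) of reading one of these camera_info YAML files into numpy arrays; the file path is just an example:

```python
# Minimal sketch: parse a ROS camera_info YAML (see the calibration files
# in this diff) into the usual K, D, R, P matrices. Path is illustrative.
import numpy as np
import yaml

with open('tracking_re_id/calibration/left.yaml') as f:
    cal = yaml.safe_load(f)

K = np.array(cal['camera_matrix']['data']).reshape(3, 3)         # intrinsics
D = np.array(cal['distortion_coefficients']['data'])             # plumb_bob coefficients
R = np.array(cal['rectification_matrix']['data']).reshape(3, 3)  # rectification rotation
P = np.array(cal['projection_matrix']['data']).reshape(3, 4)     # rectified projection
print(cal['camera_name'], 'fx =', K[0, 0], 'fy =', K[1, 1])
```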
To run the camera driver after calibration has been conducted, simply execute ```ros2 launch tracking_re_id start_cameras_only.launch.py```. Running ```colcon build``` again is **NOT** required, as you should have run it with ```--symlink-install``` previously.

# The following must only be run after the camera driver is started and once mmpose is installed:
## single_person_loc_node

This node forms the basis for the 3D pose tracking methods. It takes raw image feeds and camera_info from both cameras, generates joint keypoints using MMPOSE, projects these points onto the image plane, then traces rays from each camera's copy of a keypoint through the respective focal center to their intersection: the 3D keypoint.
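The following is an illustrative sketch of that triangulation (assumed geometry: rectified projection matrices with identity rotation, as in the calibration files above; the node's actual implementation may differ):

```python
# Illustrative sketch of two-ray triangulation, NOT the node's exact code.
# Assumes rectified projection matrices P = [K | K*t] (identity rotation).
import numpy as np

def backproject(P, uv):
    """Return (origin, unit direction) of the ray through pixel uv."""
    K, Kt = P[:, :3], P[:, 3]
    origin = -np.linalg.inv(K) @ Kt                       # camera centre
    d = np.linalg.inv(K) @ np.array([uv[0], uv[1], 1.0])  # viewing ray
    return origin, d / np.linalg.norm(d)

def triangulate(P_left, uv_left, P_right, uv_right):
    """Midpoint of closest approach between the two back-projected rays."""
    o1, d1 = backproject(P_left, uv_left)
    o2, d2 = backproject(P_right, uv_right)
    A = np.stack([d1, -d2], axis=1)                    # 3x2 system
    s, t = np.linalg.lstsq(A, o2 - o1, rcond=None)[0]  # ray parameters
    return 0.5 * ((o1 + s * d1) + (o2 + t * d2))       # the 3D keypoint
```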
```ros2 launch tracking_re_id single_person_demo.launch.py``` will start a visualization of the keypoints and show the average normal distance and average coordinate relative to the left camera.

Other launchfiles run ```single_person_loc_node``` in headless mode, drawing upon the published 3D keypoints.
## ground_plane_node

Listens to keypoints from ```single_person_loc_node```, waiting for ankle keypoints that stay within a 5 cm radius in 3D for more than 5 frames. Once at least 3 such keypoints are found, with at least one lying at least 25 cm off the line connecting the other two, a plane is formed via a least-squares method. The plane is published as a large, disc-shaped marker, along with the keypoints it was based on.
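A minimal sketch of such a least-squares fit (assuming an SVD-based formulation; the node may use a different solver):

```python
# Illustrative least-squares plane fit, not necessarily the node's solver.
import numpy as np

def fit_plane(points):
    """points: (N, 3) stable ankle keypoints -> (centroid, unit normal)."""
    centroid = points.mean(axis=0)
    # The right-singular vector with the smallest singular value is the
    # direction of least variance, i.e. the plane normal.
    _, _, vt = np.linalg.svd(points - centroid)
    normal = vt[-1]
    return centroid, normal / np.linalg.norm(normal)
```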
## overlay_node

Provides visualization of the ground plane, transformed back into the camera feed and overlaid on the undistorted image along with the 3D keypoints.
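The projection back into the image amounts to applying the camera's 3×4 projection matrix; a minimal sketch (assumed formulation, illustrative only):

```python
# Illustrative sketch: project camera-frame 3D points into the rectified
# image with the 3x4 projection matrix P from camera_info.
import numpy as np

def project(P, pts3d):
    """pts3d: (N, 3) camera-frame points -> (N, 2) pixel coordinates."""
    homo = np.hstack([pts3d, np.ones((len(pts3d), 1))])  # homogeneous coords
    uvw = (P @ homo.T).T
    return uvw[:, :2] / uvw[:, 2:3]                      # perspective divide
```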
**```ros2 launch tracking_re_id full_pipeline.launch.py``` will launch ```single_person_loc_node```, ```ground_plane_node```, and ```overlay_node```, allowing for visualization of the ground plane while people are walking through.**
## reid_utils and reid_node

These work together to deliver KeyRe-ID functionality (from ```KeyRe_ID_model.py```) based on weights stored in ```tracking_re_id/weights```. Currently, the pipeline uses the weights I finetuned from ViT-base on the iLIDS-VID dataset.
The Re-ID system works on top of MMPOSE, using the generated keypoints to create attention heatmaps and using these to reidentify subjects.
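A minimal sketch of one way to build such part heatmaps from COCO keypoints (the part grouping and Gaussian rendering here are assumptions for illustration; the node's own construction may differ):

```python
# Illustrative only: COCO keypoints -> six part heatmaps (head, torso,
# arms, legs) at the 256x128 resolution KeyRe_ID consumes.
import numpy as np

PARTS = {  # assumed grouping of COCO keypoint ids
    0: [0, 1, 2, 3, 4],   # head
    1: [5, 6, 11, 12],    # torso
    2: [7, 9],            # left arm
    3: [8, 10],           # right arm
    4: [13, 15],          # left leg
    5: [14, 16],          # right leg
}

def part_heatmaps(kps, h=256, w=128, sigma=6.0):
    """kps: (17, 2) keypoints in crop pixels -> (6, h, w) heatmaps."""
    ys, xs = np.mgrid[0:h, 0:w]
    maps = np.zeros((6, h, w), dtype=np.float32)
    for part, ids in PARTS.items():
        for i in ids:
            x, y = kps[i]
            maps[part] += np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))
        peak = maps[part].max()
        if peak > 0:
            maps[part] /= peak  # normalise each part map to [0, 1]
    return maps
```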
Running ```ros2 launch reid_pipeline.launch.py``` will display a visualization of the live camera feed along with reidentification bounding boxes. The confidence value on the first identification is always 1.0; the confidence on a new box with the same ID is the similarity between the current subject and the subject they were matched with; and the confidence on a new box with a different ID is the highest confidence KeyRe-ID had when attempting to match to an existing subject.

The Re-ID implementation is based on [KeyRe-ID](https://arxiv.org/abs/2507.07393):

```
@article{kim2025keyreid,
  title   = {KeyRe-ID: Keypoint-Guided Person Re-Identification using Part-Aware Representation in Videos},
  author  = {Jinseong Kim and Jeonghoon Song and Gyeongseon Baek and Byeongjoon Noh},
  journal = {arXiv preprint arXiv:2507.07393},
  year    = {2025}
}
```
30 tracking_re_id/calibration/left.yaml Normal file
@@ -0,0 +1,30 @@
# Left camera calibration file.
# Replace this with the output from your stereo calibration tool (e.g. camera_calibration ROS package).
# The file must follow the ROS camera_info YAML format consumed by camera_info_url.

image_width: 1440
image_height: 1080
camera_name: stereo/left
camera_matrix:
  rows: 3
  cols: 3
  data: [1286.66791, 0. , 745.16448,
         0. , 1282.77949, 544.48338,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.367441, 0.127306, -0.000022, 0.001156, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99996745, -0.0080634 , -0.00026715,
          0.00806676, 0.99982566, 0.01683959,
          0.00013132, -0.0168412 , 0.99985817]
projection_matrix:
  rows: 3
  cols: 4
  data: [1216.34701, 0. , 740.37407, 0. ,
         0. , 1216.34701, 531.57756, 0. ,
         0. , 0. , 1. , 0. ]
30 tracking_re_id/calibration/right.yaml Normal file
@@ -0,0 +1,30 @@
# Right camera calibration file.
# Replace this with the output from your stereo calibration tool (e.g. camera_calibration ROS package).
# The file must follow the ROS camera_info YAML format consumed by camera_info_url.

image_width: 1440
image_height: 1080
camera_name: stereo/right
camera_matrix:
  rows: 3
  cols: 3
  data: [1269.74474, 0. , 760.15551,
         0. , 1267.45607, 522.31581,
         0. , 0. , 1. ]
distortion_model: plumb_bob
distortion_coefficients:
  rows: 1
  cols: 5
  data: [-0.381294, 0.153290, 0.000474, -0.000930, 0.000000]
rectification_matrix:
  rows: 3
  cols: 3
  data: [ 0.99938087, 0.00859203, 0.03411828,
         -0.00801615, 0.9998237 , -0.01697993,
         -0.03425815, 0.01669592, 0.99927355]
projection_matrix:
  rows: 3
  cols: 4
  data: [1216.34701, 0. , 740.37407, -556.04457,
         0. , 1216.34701, 531.57756, 0. ,
         0. , 0. , 1. , 0. ]
66 tracking_re_id/launch/_conda_utils.py Normal file
@@ -0,0 +1,66 @@
"""Utilities for locating conda-environment Python interpreters at launch time."""

import os


def find_conda_python(env_name: str) -> str:
    """Return the path to ``python3`` inside a named conda environment.

    Resolution order
    ----------------
    1. ``<ENV_NAME_UPPER>_PYTHON`` environment variable (explicit override),
       e.g. for *mmpose*: ``MMPOSE_PYTHON=/path/to/python3``.
    2. ``CONDA_EXE`` environment variable — set by ``conda init`` and points to
       the conda binary inside the base installation. The base directory is two
       levels up (``<base>/bin/conda → <base>``).
    3. A scan of common conda base-directory locations.

    Parameters
    ----------
    env_name:
        Name of the conda environment (e.g. ``"mmpose"``).

    Returns
    -------
    str
        Absolute path to the Python 3 interpreter.

    Raises
    ------
    FileNotFoundError
        If no interpreter is found through any of the above methods.
    """
    # 1. Explicit override via environment variable
    override_var = f"{env_name.upper()}_PYTHON"
    if p := os.environ.get(override_var):
        return p

    rel = os.path.join("envs", env_name, "bin", "python3")

    # 2. Derive conda base from CONDA_EXE (most reliable when conda is initialised)
    if conda_exe := os.environ.get("CONDA_EXE"):
        base = os.path.dirname(os.path.dirname(conda_exe))
        candidate = os.path.join(base, rel)
        if os.path.isfile(candidate):
            return candidate

    # 3. Scan common install locations
    common_bases = [
        "~/miniconda3",
        "~/anaconda3",
        "~/mambaforge",
        "~/miniforge3",
        "~/opt/miniconda3",
        "/opt/conda",
        "/opt/miniconda3",
        "/opt/anaconda3",
    ]
    for base in common_bases:
        candidate = os.path.join(os.path.expanduser(base), rel)
        if os.path.isfile(candidate):
            return candidate

    raise FileNotFoundError(
        f"Cannot locate Python for conda env '{env_name}'. "
        f"Set the {override_var} environment variable to the full path of the interpreter."
    )
@@ -15,14 +15,17 @@ To view the annotated image:
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser(
        '~/miniconda3/envs/mmpose/bin/python3'
    )
    python_exe = find_conda_python('mmpose')

    return LaunchDescription([
@@ -6,14 +6,17 @@ Python environment.
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser(
        '~/miniconda3/envs/mmpose/bin/python3'
    )
    python_exe = find_conda_python('mmpose')

    return LaunchDescription([
@@ -6,12 +6,17 @@ To view output:
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser('~/miniconda3/envs/mmpose/bin/python3')
    python_exe = find_conda_python('mmpose')
    keyreID_path = os.path.expanduser('~/KeyRe-ID')

    return LaunchDescription([
@@ -23,12 +23,17 @@ Viewing the output
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser('~/miniconda3/envs/mmpose/bin/python3')
    python_exe = find_conda_python('mmpose')
    keyreID_path = os.path.expanduser('~/KeyRe-ID')

    return LaunchDescription([
@@ -1,14 +1,17 @@
"""Launch single_person_loc_node using the mmpose conda environment's Python."""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser(
        '~/miniconda3/envs/mmpose/bin/python3'
    )
    python_exe = find_conda_python('mmpose')

    node_module = 'tracking_re_id.single_person_loc_node'
@@ -6,14 +6,17 @@ pipeline where visualisation is handled elsewhere.
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = os.path.expanduser(
        '~/miniconda3/envs/mmpose/bin/python3'
    )
    python_exe = find_conda_python('mmpose')

    node_module = 'tracking_re_id.single_person_loc_node'
@@ -11,12 +11,17 @@ RIGHT_CAMERA_NAME = 'right'


def generate_launch_description():
    pkg_share = get_package_share_directory('tracking_re_id')

    config_path = os.path.join(
        get_package_share_directory('spinnaker_camera_driver'),
        'config',
        'blackfly_s.yaml'
    )

    left_cal = 'file://' + os.path.join(pkg_share, 'calibration', 'left.yaml')
    right_cal = 'file://' + os.path.join(pkg_share, 'calibration', 'right.yaml')

    # ── Camera Drivers (component_container_mt) ─────────────────────────
    #
    # Both cameras share a single process because the FLIR Spinnaker SDK
@@ -38,8 +43,8 @@ def generate_launch_description():
                namespace='stereo',
                parameters=[{
                    'parameter_file': config_path,
                    'serial_number': '25282106',
                    'camera_info_url': 'file:///home/sentry/camera_ws/stereoCal/left.yaml',
                    'serial_number': '25503478',
                    'camera_info_url': left_cal,
                }],
                extra_arguments=[{'use_intra_process_comms': True}],
            ),
@@ -50,8 +55,8 @@
                namespace='stereo',
                parameters=[{
                    'parameter_file': config_path,
                    'serial_number': '25235293',
                    'camera_info_url': 'file:///home/sentry/camera_ws/stereoCal/right.yaml',
                    'serial_number': '25503480',
                    'camera_info_url': right_cal,
                }],
                extra_arguments=[{'use_intra_process_comms': True}],
            ),
72 tracking_re_id/launch/top_down_launch.py Normal file
@@ -0,0 +1,72 @@
"""Launch the full 3D tracking + ground-plane estimation + top-down view pipeline.

Nodes started:
  1. single_person_loc_node -- headless stereo keypoint triangulator
     publishes: /keypoint_markers (MarkerArray)
                /keypoints_3d (PointCloud2)
  2. ground_plane_node -- ground-plane estimator
     publishes: /ground_plane_markers (MarkerArray)
                /ground_plane_pose (PoseStamped)
  3. top_down_node -- synthetic bird's-eye view of the ground plane
     publishes: /top_down_image (Image)

To view the top-down image:
    ros2 run rqt_image_view rqt_image_view → select /top_down_image
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__))
from _conda_utils import find_conda_python  # noqa: E402

from launch import LaunchDescription
from launch.actions import ExecuteProcess


def generate_launch_description():
    python_exe = find_conda_python('mmpose')

    return LaunchDescription([

        # ── 1. Keypoint triangulator (headless) ──────────────────────────────
        ExecuteProcess(
            cmd=[
                python_exe, '-m', 'tracking_re_id.single_person_loc_node',
                '--ros-args',
                '-p', 'threshold:=0.3',
                '-p', 'device:=cuda:0',
                '-p', 'max_residual:=0.10',
                '-p', 'headless:=true',
            ],
            output='screen',
            env={**os.environ},
        ),

        # ── 2. Ground-plane estimator ─────────────────────────────────────────
        ExecuteProcess(
            cmd=[
                python_exe, '-m', 'tracking_re_id.ground_plane_node',
                '--ros-args',
                '-p', 'stable_frames:=5',
                '-p', 'stable_radius:=0.05',
                '-p', 'duplicate_radius:=0',
                '-p', 'collinearity_threshold:=0.25',
                '-p', 'max_ground_points:=100',
                '-p', 'min_plane_points:=5',
            ],
            output='screen',
            env={**os.environ},
        ),

        # ── 3. Top-down bird's-eye visualiser ────────────────────────────────
        ExecuteProcess(
            cmd=[
                python_exe, '-m', 'tracking_re_id.top_down_node',
                '--ros-args',
            ],
            output='screen',
            env={**os.environ},
        ),

    ])
@@ -14,6 +14,7 @@ setup(
        ('share/' + package_name, ['package.xml']),
        (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
        (os.path.join('share', package_name, 'weights'), glob('weights/*.pth')),
        (os.path.join('share', package_name, 'calibration'), glob('calibration/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
@@ -31,6 +32,7 @@ setup(
            'single_person_loc_node = tracking_re_id.single_person_loc_node:main',
            'ground_plane_node = tracking_re_id.ground_plane_node:main',
            'overlay_node = tracking_re_id.overlay_node:main',
            'top_down_node = tracking_re_id.top_down_node:main',
            'reid_node = tracking_re_id.reid_node:main',
        ],
    },
300 tracking_re_id/tracking_re_id/KeyRe_ID_model.py Normal file
@@ -0,0 +1,300 @@
# Source: https://github.com/jinseong0115/KeyRe-ID

import torch
import torch.nn as nn
import copy
from .vit_ID import TransReID, Block
from functools import partial
from torch.nn import functional as F
from .vit_ID import resize_pos_embed


def TCSS(features, shift, b, t):
    # aggregate features at patch level
    features = features.view(b, features.size(1), t * features.size(2))
    token = features[:, 0:1]

    batchsize = features.size(0)
    dim = features.size(-1)

    # shift the patches with amount=shift
    features = torch.cat([features[:, shift:], features[:, 1:shift]], dim=1)

    # Patch Shuffling by 2 part
    try:
        features = features.view(batchsize, 2, -1, dim)
    except:
        features = torch.cat([features, features[:, -2:-1, :]], dim=1)
        features = features.view(batchsize, 2, -1, dim)

    features = torch.transpose(features, 1, 2).contiguous()
    features = features.view(batchsize, -1, dim)

    return features, token


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)


class KeyRe_ID(nn.Module):
    def __init__(self, num_classes, camera_num, pretrainpath):
        super(KeyRe_ID, self).__init__()
        self.in_planes = 768
        self.num_classes = num_classes

        self.base = TransReID(
            img_size=[256, 128], patch_size=16, stride_size=[16, 16], embed_dim=768, depth=12,
            num_heads=12, mlp_ratio=4, qkv_bias=True, camera=camera_num, drop_path_rate=0.1,
            drop_rate=0.0, attn_drop_rate=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-6), cam_lambda=3.0)

        # state_dict = torch.load(pretrainpath, map_location='cpu')
        # self.base.load_param(state_dict, load=True)
        if pretrainpath:
            state_dict = torch.load(pretrainpath, map_location='cpu', weights_only=False)
            self.base.load_param(state_dict, load=True)

        #-------------------Global Branch-------------
        block = self.base.blocks[-1]
        layer_norm = self.base.norm
        self.b1 = nn.Sequential(
            copy.deepcopy(block),
            copy.deepcopy(layer_norm)
        )

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)

        #-------------------Local Branch-------------
        # building local video stream
        dpr = [x.item() for x in torch.linspace(0, 0, 12)]  # stochastic depth decay rule

        self.block1 = Block(
            dim=3072, num_heads=12, mlp_ratio=4, qkv_bias=True, qk_scale=None,
            drop=0, attn_drop=0, drop_path=dpr[11], norm_layer=partial(nn.LayerNorm, eps=1e-6))
        self.b2 = nn.Sequential(
            self.block1,
            nn.LayerNorm(3072)  # copy.deepcopy(layer_norm)
        )

        self.bottleneck_1 = nn.BatchNorm1d(3072)
        self.bottleneck_1.bias.requires_grad_(False)
        self.bottleneck_1.apply(weights_init_kaiming)
        self.bottleneck_2 = nn.BatchNorm1d(3072)
        self.bottleneck_2.bias.requires_grad_(False)
        self.bottleneck_2.apply(weights_init_kaiming)
        self.bottleneck_3 = nn.BatchNorm1d(3072)
        self.bottleneck_3.bias.requires_grad_(False)
        self.bottleneck_3.apply(weights_init_kaiming)
        self.bottleneck_4 = nn.BatchNorm1d(3072)
        self.bottleneck_4.bias.requires_grad_(False)
        self.bottleneck_4.apply(weights_init_kaiming)
        self.bottleneck_5 = nn.BatchNorm1d(3072)
        self.bottleneck_5.bias.requires_grad_(False)
        self.bottleneck_5.apply(weights_init_kaiming)
        self.bottleneck_6 = nn.BatchNorm1d(3072)
        self.bottleneck_6.bias.requires_grad_(False)
        self.bottleneck_6.apply(weights_init_kaiming)

        self.classifier_1 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_1.apply(weights_init_classifier)
        self.classifier_2 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_2.apply(weights_init_classifier)
        self.classifier_3 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_3.apply(weights_init_classifier)
        self.classifier_4 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_4.apply(weights_init_classifier)
        self.classifier_5 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_5.apply(weights_init_classifier)
        self.classifier_6 = nn.Linear(3072, self.num_classes, bias=False)
        self.classifier_6.apply(weights_init_classifier)

        #-------------------video attention-------------
        self.middle_dim = 256  # middle layer dimension
        self.attention_conv = nn.Conv2d(self.in_planes, self.middle_dim, [1, 1])  # 7,4 corresponds to 224, 112 input image size
        self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1)
        self.attention_conv.apply(weights_init_kaiming)
        self.attention_tconv.apply(weights_init_kaiming)
        #------------------------------------------
        self.shift_num = 5
        self.part = 6
        self.rearrange = True

    def forward(self, x, heatmaps, label=None, cam_label=None, view_label=None):  # label is unused if self.cos_layer == 'no'
        b = x.size(0)
        t = x.size(1)

        x = x.view(x.size(0) * x.size(1), x.size(2), x.size(3), x.size(4))
        features = self.base(x, cam_label=cam_label)

        #-------------------Global Branch-------------
        b1_feat = self.b1(features)  # [64, 129, 3072]
        global_feat = b1_feat[:, 0]

        global_feat = global_feat.unsqueeze(dim=2).unsqueeze(dim=3)
        a = F.relu(self.attention_conv(global_feat))
        a = a.view(b, t, self.middle_dim)
        a = a.permute(0, 2, 1)
        a = F.relu(self.attention_tconv(a))
        a = a.view(b, t)
        a_vals = a

        a = F.softmax(a, dim=1)
        x = global_feat.view(b, t, -1)
        a = torch.unsqueeze(a, -1)
        a = a.expand(b, t, self.in_planes)
        att_x = torch.mul(x, a)
        att_x = torch.sum(att_x, 1)

        global_feat = att_x.view(b, self.in_planes)
        feat = self.bottleneck(global_feat)

        #-------------------Local Branch-------------
        # Heatmap Processing
        heatmaps = heatmaps.view(b * t, 6, 256, 128)                                 # [B*T, 6, 256, 128]
        heatmap_patches = F.unfold(heatmaps, kernel_size=16, stride=16)              # [B*T, 6*16*16, 128]
        heatmap_patches = heatmap_patches.view(b * t, 6, 16 * 16, 128).mean(dim=2)   # [B*T, 6, 128]
        heatmap_weights = heatmap_patches.transpose(1, 2)                            # [B*T, 128, 6]
        heatmap_weights = heatmap_weights.view(b, t, 128, 6).mean(dim=1)             # [B, 128, 6]

        # Temporal clip shift and shuffled
        x, token = TCSS(features, self.shift_num, b, t)
        patch_feats = x

        # Part 1: Head
        part1_weight = heatmap_weights[:, :, 0].unsqueeze(-1)
        part1 = patch_feats * part1_weight
        part1 = self.b2(torch.cat((token, part1), dim=1))
        part1_f = part1[:, 0]

        # Part 2: Torso
        part2_weight = heatmap_weights[:, :, 1].unsqueeze(-1)
        part2 = patch_feats * part2_weight
        part2 = self.b2(torch.cat((token, part2), dim=1))
        part2_f = part2[:, 0]

        # Part 3: Left Arm
        part3_weight = heatmap_weights[:, :, 2].unsqueeze(-1)
        part3 = patch_feats * part3_weight
        part3 = self.b2(torch.cat((token, part3), dim=1))
        part3_f = part3[:, 0]

        # Part 4: Right Arm
        part4_weight = heatmap_weights[:, :, 3].unsqueeze(-1)
        part4 = patch_feats * part4_weight
        part4 = self.b2(torch.cat((token, part4), dim=1))
        part4_f = part4[:, 0]

        # Part 5: Left Leg
        part5_weight = heatmap_weights[:, :, 4].unsqueeze(-1)
        part5 = patch_feats * part5_weight
        part5 = self.b2(torch.cat((token, part5), dim=1))
        part5_f = part5[:, 0]

        # Part 6: Right Leg
        part6_weight = heatmap_weights[:, :, 5].unsqueeze(-1)
        part6 = patch_feats * part6_weight
        part6 = self.b2(torch.cat((token, part6), dim=1))
        part6_f = part6[:, 0]

        # Apply batch normalization
        part1_bn = self.bottleneck_1(part1_f)
        part2_bn = self.bottleneck_2(part2_f)
        part3_bn = self.bottleneck_3(part3_f)
        part4_bn = self.bottleneck_4(part4_f)
        part5_bn = self.bottleneck_5(part5_f)
        part6_bn = self.bottleneck_6(part6_f)

        if self.training:
            Global_ID = self.classifier(feat)
            Local_ID1 = self.classifier_1(part1_bn)
            Local_ID2 = self.classifier_2(part2_bn)
            Local_ID3 = self.classifier_3(part3_bn)
            Local_ID4 = self.classifier_4(part4_bn)
            Local_ID5 = self.classifier_5(part5_bn)
            Local_ID6 = self.classifier_6(part6_bn)

            return [Global_ID, Local_ID1, Local_ID2, Local_ID3, Local_ID4, Local_ID5, Local_ID6], \
                   [global_feat, part1_f, part2_f, part3_f, part4_f, part5_f, part6_f], a_vals
        else:
            return torch.cat([feat, part1_bn / self.part, part2_bn / self.part, part3_bn / self.part,
                              part4_bn / self.part, part5_bn / self.part, part6_bn / self.part], dim=1)

    def load_param(self, trained_path, load=False):
        print("Run load_param")
        if not load:
            param_dict = torch.load(trained_path, map_location='cpu', weights_only=False)
        else:
            param_dict = trained_path

        if 'model' in param_dict:
            param_dict = param_dict['model']
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']

        model_dict = self.state_dict()  # Get the state_dict of the current model
        new_param_dict = {}

        for k, v in param_dict.items():
            if 'head' in k or 'dist' in k:
                continue

            # Patch embedding Conv-based transformation processing
            if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
                O, I, H, W = self.base.patch_embed.proj.weight.shape
                v = v.reshape(O, -1, H, W)
            # Resize Positional Embedding
            elif k == 'pos_embed' and v.shape != self.base.pos_embed.shape:
                v = resize_pos_embed(v, self.base.pos_embed, self.base.patch_embed.num_y, self.base.patch_embed.num_x)

            # Handling `base.` prefix
            new_k = k
            if k.startswith("base.") and k[5:] in model_dict:
                new_k = k[5:]  # Remove base.
            elif not k.startswith("base.") and ("base." + k) in model_dict:
                new_k = "base." + k  # Add base.

            if new_k in ['Cam', 'base.Cam'] and new_k in model_dict:
                expected_shape = model_dict[new_k].shape  # Cam size that the current model expects
                print(f"[Before Resizing] {new_k}: {v.shape} -> Expected: {expected_shape}")

                if v.shape[0] > expected_shape[0]:    # Keep only the front part if the size is larger
                    v = v[:expected_shape[0], :, :]
                elif v.shape[0] < expected_shape[0]:  # Create a new tensor for smaller sizes
                    new_v = torch.randn(expected_shape)  # Random initialization (other values are possible)
                    new_v[:v.shape[0], :, :] = v         # Keep existing values
                    v = new_v

                print(f"[After Resizing] {new_k}: {v.shape}")  # Confirm after changing the size
                new_param_dict[new_k] = v
                continue

            # Update only if the shape fits
            if new_k in model_dict and model_dict[new_k].shape == v.shape:
                new_param_dict[new_k] = v

        # Finally, update the state_dict
        model_dict.update(new_param_dict)
        self.load_state_dict(model_dict, strict=False)
        print("Checkpoint loaded successfully.")
@@ -31,7 +31,6 @@ Pipeline (per frame)
Parameters
──────────
weights_path   str   path to iLIDSVIDbest_CMC.pth (required)
keyreID_path   str   path to KeyRe-ID source directory
num_classes    int   training split size (150 for iLIDS-VID split-0)
camera_num     int   cameras in training set (2 for iLIDS-VID)
device         str   'cuda:0' or 'cpu'
@@ -44,7 +43,6 @@ Parameters
"""

import os
import sys
import time
import colorsys

@@ -140,8 +138,6 @@ class ReIDNode(Node):
            os.path.join(
                get_package_share_directory('tracking_re_id'),
                'weights', 'iLIDSVIDbest_CMC.pth'))
        self.declare_parameter('keyreID_path',
                               os.path.expanduser('~/KeyRe-ID'))
        self.declare_parameter('num_classes', 150)
        self.declare_parameter('camera_num', 2)
        self.declare_parameter('device', 'cuda:0')
@@ -153,7 +149,6 @@ class ReIDNode(Node):
        self.declare_parameter('headless', False)

        weights_path = self.get_parameter('weights_path').value
        keyreID_path = self.get_parameter('keyreID_path').value
        num_classes = self.get_parameter('num_classes').value
        camera_num = self.get_parameter('camera_num').value
        device_str = self.get_parameter('device').value
@@ -175,14 +170,7 @@ class ReIDNode(Node):
        self.get_logger().info('MMPose loaded.')

        # ── KeyRe-ID ─────────────────────────────────────────────────────────
        if keyreID_path not in sys.path:
            sys.path.insert(0, keyreID_path)
        try:
            from KeyRe_ID_model import KeyRe_ID  # noqa: PLC0415
        except ImportError as exc:
            self.get_logger().fatal(
                f'Cannot import KeyRe_ID_model from {keyreID_path}: {exc}')
            raise
        from .KeyRe_ID_model import KeyRe_ID  # noqa: PLC0415

        self.get_logger().info(f'Loading KeyRe-ID weights from {weights_path} …')
        self._model = KeyRe_ID(
350 tracking_re_id/tracking_re_id/top_down_node.py Normal file
@@ -0,0 +1,350 @@
"""
ROS 2 node: top_down_node

Renders a synthetic bird's-eye (top-down) 2-D view of the detected ground
plane. Each detected person is shown as a filled circle (centred on their
ankle midpoint) with a facing arrow. The facing direction is derived by
projecting the 3-D vector from the shoulder midpoint to the nose onto the
ground plane — no cross-product or disambiguation required.

Before the ground plane is established the window shows "Waiting for plane…".

Subscriptions
─────────────
/keypoints_3d        sensor_msgs/PointCloud2    (from single_person_loc_node)
/ground_plane_pose   geometry_msgs/PoseStamped  (from ground_plane_node)

Publications
────────────
/top_down_image      sensor_msgs/Image          (BGR8, fixed-size canvas)
"""

import numpy as np
import cv2
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image, PointCloud2
from geometry_msgs.msg import PoseStamped
from sensor_msgs_py import point_cloud2 as pc2
from cv_bridge import CvBridge


# ═══════════════════════════════════════════════════════════════════════════════
# TUNABLE PARAMETERS
# ═══════════════════════════════════════════════════════════════════════════════

# Canvas side length in pixels (square output image).
CANVAS_SIZE: int = 800

# Physical half-extent of the canvas from the plane origin (metres).
# A value of 4.0 yields an 8 m × 8 m visible area.
CANVAS_HALF_M: float = 4.0

# Derived: pixels per metre.
SCALE: float = CANVAS_SIZE / (2.0 * CANVAS_HALF_M)

# Canvas background colour (B, G, R).
BG_COLOR: tuple = (30, 30, 30)

# Regular grid line colour (B, G, R).
GRID_COLOR: tuple = (70, 70, 70)

# Grid line spacing in metres.
GRID_SPACING: float = 0.5

# Axis (u=0, v=0) line colour (B, G, R).
AXIS_COLOR: tuple = (0, 180, 0)

# Radius of the per-person presence circle in metres.
PERSON_RADIUS_M: float = 0.20

# Fill alpha of the presence circle (0.0–1.0).
CIRCLE_ALPHA: float = 0.45

# Thickness of the presence circle outline in pixels.
CIRCLE_OUTLINE_PX: int = 2

# Length of the facing arrow beyond the circle edge, in metres.
ARROW_LENGTH_M: float = 0.7

# Fraction of arrow length used for the arrowhead.
ARROW_TIP_FRACTION: float = 0.30

# Publish rate of the top-down image in Hz.
PUBLISH_HZ: float = 15.0

# Per-person colours, cycled by person_id (B, G, R).
PERSON_COLORS: list = [
    (0, 230, 0),     # green
    (0, 100, 255),   # orange
    (255, 50, 200),  # magenta
    (0, 200, 255),   # yellow
]


# ═══════════════════════════════════════════════════════════════════════════════
# Helpers
# ═══════════════════════════════════════════════════════════════════════════════

def _quat_to_rot(qx: float, qy: float, qz: float, qw: float) -> np.ndarray:
    """Return the 3×3 rotation matrix for unit quaternion (x, y, z, w)."""
    return np.array([
        [1 - 2*(qy*qy + qz*qz), 2*(qx*qy - qz*qw),     2*(qx*qz + qy*qw)],
        [2*(qx*qy + qz*qw),     1 - 2*(qx*qx + qz*qz), 2*(qy*qz - qx*qw)],
        [2*(qx*qz - qy*qw),     2*(qy*qz + qx*qw),     1 - 2*(qx*qx + qy*qy)],
    ], dtype=np.float64)


def _to_plane_uv(pt3d: np.ndarray,
                 origin: np.ndarray,
                 u_axis: np.ndarray,
                 v_axis: np.ndarray) -> np.ndarray:
    """Project a 3-D camera-frame point onto the plane, returning (u, v)."""
    d = pt3d - origin
    return np.array([np.dot(d, u_axis), np.dot(d, v_axis)], dtype=np.float64)


def _to_pixel(uv: np.ndarray) -> tuple[int, int]:
    """
    Map plane coordinates (u, v) to integer canvas pixel (x, y).

    u increases to the right; v is flipped so that "further from camera"
    appears upward on the canvas.
    """
    px = int(round( uv[0] * SCALE + CANVAS_SIZE * 0.5))
    py = int(round(-uv[1] * SCALE + CANVAS_SIZE * 0.5))
    return px, py


def _ankle_center(kp_dict: dict[int, np.ndarray],
                  origin: np.ndarray,
                  u_axis: np.ndarray,
                  v_axis: np.ndarray) -> np.ndarray:
    """Return the mean ankle plane-UV, falling back to the all-keypoint mean."""
    ankles = [_to_plane_uv(kp_dict[k], origin, u_axis, v_axis)
              for k in (15, 16) if k in kp_dict]
    if ankles:
        return np.mean(ankles, axis=0)
    return np.mean(
        [_to_plane_uv(v, origin, u_axis, v_axis) for v in kp_dict.values()],
        axis=0)


def _facing_dir(kp_dict: dict[int, np.ndarray],
                u_axis: np.ndarray,
                v_axis: np.ndarray) -> np.ndarray | None:
    """
    Return the normalised facing direction as a plane (u, v) unit vector, or None.

    Method
    ------
    The 3-D vector (nose − shoulder_midpoint) points directly forward from
    the person's body. Projecting it onto the ground plane and normalising
    gives an unambiguous facing direction — no cross-product or sign-flip
    disambiguation needed.

    Falls back to (nose − hip_midpoint) if shoulders are not detected.
    Returns None if neither fallback has enough data.
    """
    nose = kp_dict.get(0)
    if nose is None:
        return None

    # Prefer shoulders (5, 6), fall back to hips (11, 12)
    refs = [kp_dict[k] for k in (5, 6) if k in kp_dict]
    if not refs:
        refs = [kp_dict[k] for k in (11, 12) if k in kp_dict]
    if not refs:
        return None

    ref_mid = np.mean(refs, axis=0)  # shoulder (or hip) midpoint
    forward_3d = nose - ref_mid      # points toward the face

    facing_uv = np.array([np.dot(forward_3d, u_axis),
                          np.dot(forward_3d, v_axis)], dtype=np.float64)
    n = np.linalg.norm(facing_uv)
    if n < 1e-6:
        return None
    return facing_uv / n


# ═══════════════════════════════════════════════════════════════════════════════
# Node
# ═══════════════════════════════════════════════════════════════════════════════

class TopDownNode(Node):
    """
    Publishes a top-down 2-D view of the ground plane with:
      • a metric grid
      • one circle per detected person (centred on ankle midpoint)
      • a facing arrow per person (nose − shoulder direction, projected)
    """

    def __init__(self):
        super().__init__('top_down_node')

        self._bridge = CvBridge()

        # {person_id: {kp_id: np.ndarray([x, y, z])}}
        self._kps: dict[int, dict[int, np.ndarray]] = {}

        # (origin, R) — R cols: [u_axis | v_axis | normal]
        self._plane: tuple[np.ndarray, np.ndarray] | None = None

        self.create_subscription(
            PointCloud2, '/keypoints_3d', self._kp_cb, 10)
        self.create_subscription(
            PoseStamped, '/ground_plane_pose', self._plane_cb, 10)

        self._pub = self.create_publisher(Image, '/top_down_image', 10)
        self.create_timer(1.0 / PUBLISH_HZ, self._render)

        self.get_logger().info('TopDown node started — publishing /top_down_image')

    # ── Upstream data caches ──────────────────────────────────────────────────

    def _kp_cb(self, msg: PointCloud2):
        kps: dict[int, dict[int, np.ndarray]] = {}
        for pt in pc2.read_points(
                msg,
                field_names=('x', 'y', 'z', 'person_id', 'kp_id'),
                skip_nans=True):
            pid = int(pt[3])
            kid = int(pt[4])
            kps.setdefault(pid, {})[kid] = np.array(
                [pt[0], pt[1], pt[2]], dtype=np.float64)
        self._kps = kps

    def _plane_cb(self, msg: PoseStamped):
        o = msg.pose.position
        q = msg.pose.orientation
        origin = np.array([o.x, o.y, o.z], dtype=np.float64)
        R = _quat_to_rot(q.x, q.y, q.z, q.w)
        self._plane = (origin, R)

    # ── Main render ───────────────────────────────────────────────────────────

    def _render(self):
        canvas = np.full((CANVAS_SIZE, CANVAS_SIZE, 3), BG_COLOR, dtype=np.uint8)

        if self._plane is None:
            _draw_waiting(canvas)
            self._pub.publish(
                self._bridge.cv2_to_imgmsg(canvas, encoding='bgr8'))
            return

        origin, R = self._plane
        u_axis = R[:, 0]
        v_axis = R[:, 1]

        self._draw_grid(canvas)

        for person_id, kp_dict in self._kps.items():
            color = PERSON_COLORS[person_id % len(PERSON_COLORS)]

            center_uv = _ankle_center(kp_dict, origin, u_axis, v_axis)
            cx, cy = _to_pixel(center_uv)
            r_px = max(1, int(round(PERSON_RADIUS_M * SCALE)))

            # Filled circle (alpha-blended)
            overlay = canvas.copy()
            cv2.circle(overlay, (cx, cy), r_px, color, -1, cv2.LINE_AA)
            cv2.addWeighted(overlay, CIRCLE_ALPHA,
                            canvas, 1.0 - CIRCLE_ALPHA, 0, canvas)

            # Circle outline (fully opaque)
            cv2.circle(canvas, (cx, cy), r_px, color, CIRCLE_OUTLINE_PX, cv2.LINE_AA)

            # Person label
            cv2.putText(canvas, f'P{person_id}',
                        (cx + r_px + 4, cy + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2, cv2.LINE_AA)

            # Facing arrow: base at circle edge, tip beyond it
            fdir = _facing_dir(kp_dict, u_axis, v_axis)
            if fdir is not None:
                base_uv = center_uv + fdir * PERSON_RADIUS_M
                tip_uv = center_uv + fdir * (PERSON_RADIUS_M + ARROW_LENGTH_M)
                base_px = _to_pixel(base_uv)
                tip_px = _to_pixel(tip_uv)
                cv2.arrowedLine(canvas, base_px, tip_px,
                                (255, 255, 255), 2, cv2.LINE_AA,
                                tipLength=ARROW_TIP_FRACTION)

        self._pub.publish(
            self._bridge.cv2_to_imgmsg(canvas, encoding='bgr8'))

    # ── Grid ──────────────────────────────────────────────────────────────────

    def _draw_grid(self, canvas: np.ndarray):
        lim = CANVAS_HALF_M
        steps = np.arange(-lim, lim + GRID_SPACING * 0.5, GRID_SPACING)

        for s in steps:
            is_axis = abs(s) < 1e-6
            color = AXIS_COLOR if is_axis else GRID_COLOR
            thick = 2 if is_axis else 1

            x0, y0 = _to_pixel(np.array([s, -lim]))
            x1, y1 = _to_pixel(np.array([s, lim]))
            cv2.line(canvas, (x0, y0), (x1, y1), color, thick, cv2.LINE_AA)

            x0, y0 = _to_pixel(np.array([-lim, s]))
            x1, y1 = _to_pixel(np.array([ lim, s]))
            cv2.line(canvas, (x0, y0), (x1, y1), color, thick, cv2.LINE_AA)

        # Origin marker
        ox, oy = _to_pixel(np.zeros(2))
        cv2.circle(canvas, (ox, oy), 5, AXIS_COLOR, -1, cv2.LINE_AA)
        cv2.putText(canvas, 'origin', (ox + 8, oy - 8),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.38, AXIS_COLOR, 1, cv2.LINE_AA)

        # Axis tip labels
        cv2.putText(canvas, 'u+',
                    _to_pixel(np.array([lim - 0.4, 0.0])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, AXIS_COLOR, 1, cv2.LINE_AA)
        cv2.putText(canvas, 'v+',
                    _to_pixel(np.array([0.0, lim - 0.4])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, AXIS_COLOR, 1, cv2.LINE_AA)

        # 1 m scale bar
        bx0, by0 = _to_pixel(np.array([-lim + 0.2, -lim + 0.2]))
        bx1, by1 = _to_pixel(np.array([-lim + 1.2, -lim + 0.2]))
        cv2.line(canvas, (bx0, by0), (bx1, by1), (180, 180, 180), 2, cv2.LINE_AA)
        cv2.putText(canvas, '1 m', (bx0, by0 - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.38, (180, 180, 180), 1, cv2.LINE_AA)


# ═══════════════════════════════════════════════════════════════════════════════
# Module-level helpers
# ═══════════════════════════════════════════════════════════════════════════════

def _draw_waiting(canvas: np.ndarray):
    text = 'Waiting for plane...'
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 1.0
    thick = 2
    (tw, th), _ = cv2.getTextSize(text, font, scale, thick)
    cv2.putText(canvas, text,
                ((CANVAS_SIZE - tw) // 2, (CANVAS_SIZE + th) // 2),
                font, scale, (200, 200, 200), thick, cv2.LINE_AA)


# ═══════════════════════════════════════════════════════════════════════════════
# Entry point
# ═══════════════════════════════════════════════════════════════════════════════

def main(args=None):
    rclpy.init(args=args)
    node = TopDownNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()
352
tracking_re_id/tracking_re_id/vit_ID.py
Normal file
352
tracking_re_id/tracking_re_id/vit_ID.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import math
|
||||
from itertools import repeat
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import collections.abc
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

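# Design note (illustrative): this is the standard pre-norm transformer block.
# LayerNorm is applied *before* attention and MLP, and DropPath sits on each
# residual branch, so a dropped branch reduces that sub-layer to the identity
# for the affected sample.
#
#   blk = Block(dim=768, num_heads=12, drop_path=0.1)
#   y = blk(torch.randn(2, 197, 768))   # shape unchanged: (2, 197, 768)
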
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

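# Illustrative sketch: with the defaults (224x224 image, 16x16 patches), the
# strided conv projection yields a (224 / 16)^2 = 196-token sequence of 768-dim
# embeddings.
#
#   pe = PatchEmbed()
#   tokens = pe(torch.randn(1, 3, 224, 224))
#   assert tokens.shape == (1, 196, 768)
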
class PatchEmbed_overlap(nn.Module):
    """Image to Patch Embedding with overlapping patches"""
    def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_size_tuple = to_2tuple(stride_size)
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        print('using stride: {}, patch grid is num_y={} * num_x={}'.format(stride_size, self.num_y, self.num_x))
        num_patches = self.num_x * self.num_y
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

        return x

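# Worked example (illustrative): overlapping patches increase the token count.
# For a typical ReID input of 256x128 with 16x16 patches and stride 12,
# num_y = (256 - 16) // 12 + 1 = 21 and num_x = (128 - 16) // 12 + 1 = 10,
# i.e. 210 tokens instead of the 128 a non-overlapping 16-stride grid would give.
#
#   pe = PatchEmbed_overlap(img_size=(256, 128), patch_size=16, stride_size=12)
#   tokens = pe(torch.randn(1, 3, 256, 128))
#   assert tokens.shape == (1, 210, 768)
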
class TransReID(nn.Module):
    """Transformer-based Object Re-Identification"""
    def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cam_lambda=3.0):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.cam_num = camera
        self.cam_lambda = cam_lambda

        self.patch_embed = PatchEmbed_overlap(img_size=img_size, patch_size=patch_size, stride_size=stride_size,
                                              in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # Learnable per-camera embedding (side information), one row per camera
        self.Cam = nn.Parameter(torch.zeros(camera, 1, embed_dim))

        trunc_normal_(self.Cam, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.fc  # the classifier head is stored as self.fc (there is no self.head)

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.fc = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, camera_id):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # Add the position embedding plus the weighted per-camera embedding for each sample
        x = x + self.pos_embed + self.cam_lambda * self.Cam[camera_id]
        x = self.pos_drop(x)

        # NOTE: the last block is intentionally skipped; features are returned
        # before the final layer, and self.norm is not applied here
        for blk in self.blocks[:-1]:
            x = blk(x)
        return x

    def forward(self, x, cam_label=None):
        x = self.forward_features(x, cam_label)
        return x

    def load_param(self, model_path, load=False):
        print("Run load_param")
        if not load:
            param_dict = torch.load(model_path, map_location='cpu', weights_only=False)
        else:
            param_dict = model_path

        if 'model' in param_dict:
            param_dict = param_dict['model']
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']

        model_dict = self.state_dict()
        new_param_dict = {}

        for k, v in param_dict.items():
            # Skip classifier heads and distillation tokens from the checkpoint
            if 'head' in k or 'dist' in k:
                continue

            # The camera embedding may have been trained with a different number
            # of cameras; truncate or pad it to match the current model.
            if k in ['Cam', 'base.Cam'] and k in model_dict:
                expected_shape = model_dict[k].shape
                if v.shape[0] > expected_shape[0]:
                    print(f"⚠️ Resizing '{k}' from {v.shape} to {expected_shape}")
                    v = v[:expected_shape[0], :, :]
                elif v.shape[0] < expected_shape[0]:
                    print(f"⚠️ Expanding '{k}' from {v.shape} to {expected_shape}")
                    new_v = torch.randn(expected_shape)  # extra camera rows start from random init
                    new_v[:v.shape[0], :, :] = v
                    v = new_v
                new_param_dict[k] = v
                continue

            # Copy only parameters whose shapes match the current model
            if k in model_dict and model_dict[k].shape == v.shape:
                new_param_dict[k] = v

        model_dict.update(new_param_dict)
        self.load_state_dict(model_dict, strict=False)
        print("✅ Checkpoint loaded successfully.")

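# Illustrative usage sketch (parameter values here are assumptions, not from the
# original file): build the model for a 6-camera ReID setup and run a batch.
#
#   model = TransReID(img_size=(256, 128), stride_size=12, camera=6, num_classes=751)
#   imgs = torch.randn(4, 3, 256, 128)
#   cam_ids = torch.tensor([0, 1, 2, 5])        # one camera index per sample
#   feats = model(imgs, cam_label=cam_ids)      # (4, 1 + num_patches, 768) token features
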
def resize_pos_embed(posemb, posemb_new, height, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    ntok_new = posemb_new.shape[1]

    # Split off the class token; only the spatial grid is interpolated
    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1

    gs_old = int(math.sqrt(len(posemb_grid)))
    print('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(
        posemb.shape, posemb_new.shape, height, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
    return posemb

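# Illustrative sketch: adapting a square 14x14 ImageNet grid (1 + 196 tokens) to
# the 21x10 grid produced by the overlapping patch embedding above.
#
#   posemb = torch.randn(1, 197, 768)        # pretrained: cls + 14*14 grid
#   posemb_new = torch.zeros(1, 211, 768)    # target: cls + 21*10 grid
#   resized = resize_pos_embed(posemb, posemb_new, 21, 10)
#   assert resized.shape == (1, 211, 768)
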
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
              "The distribution of values may be incorrect.")

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)