Skip to content

Commit

Permalink
Fix metric for small distances (#32)
Browse files Browse the repository at this point in the history
There was a bug in the metrics that caused the heatmap generation and hence the metric to fail for d=1 if the keypoints were annotated as floats.
This is now fixed by casting keypoints to the top-left (zero-index as in COCO) keypoint before training.

make sure keypoints are cast to zero-indexed integers before training
make sure heatmaps are centered on that pixel
make sure the AP for d=0 (i.e pixel-perfect keypoints) goes up to 1 for a dummy dataset (integration testing)

**These changes will break reproducibility**


commits:

* scrutinize the metric calculations:

- all keypoints are in ints (and represent topleft corner of their pixel, aka zero-indexed)
- heatmaps are really centered on the int pixel
- metric max distances are ints and represent <= L2 distances (i.e. pixel-perfect == dmax = 0)
- integration test: dummy dataset results in good scores for all mAP threholds down to 0

* remove deprecated test that is now broken

* update readme

* extend readme

* extend readme
  • Loading branch information
tlpss authored Aug 30, 2023
1 parent e88c5f7 commit 9f39a8e
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 56 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.vscode/**

datasets/
scripts/dummy_dataset/
**wandb/
lightning_logs**
**.ckpt
Expand Down
55 changes: 46 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<h1 align="center">Pytorch Keypoint Detection</h1>

This repo contains a Python package for 2D keypoint detection using [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) and [wandb](https://docs.wandb.ai/). Keypoints are trained using Gaussian Heatmaps, as in [Jakab et Al.](https://proceedings.neurips.cc/paper/2018/hash/1f36c15d6a3d18d52e8d493bc8187cb9-Abstract.html) or [Centernet](https://github.com/xingyizhou/CenterNet).
A Framework for keypoint detection using [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) and [wandb](https://docs.wandb.ai/). Keypoints are trained with Gaussian Heatmaps, as in [Jakab et Al.](https://proceedings.neurips.cc/paper/2018/hash/1f36c15d6a3d18d52e8d493bc8187cb9-Abstract.html) or [Centernet](https://github.com/xingyizhou/CenterNet).

This package is been used for research at the [AI and Robotics](https://airo.ugent.be/projects/computervision/) research group at Ghent University. You can see some applications below: The first image shows how this package is used to detect corners of cardboard boxes, in order to close the box with a robot. The second example shows how it is used to detect a varying number of flowers.
<div align="center">
Expand All @@ -10,15 +10,16 @@ This package is been used for research at the [AI and Robotics](https://airo.uge


## Main Features
- The detector can deal with an **arbitrary number of keypoint channels**, that can contain **a varying amount of keypoints**. You can easily configure which keypoint types from the COCO dataset should be mapped onto the different channels of the keypoint detector. This flexibility allows to e.g. combine different semantic locations that have symmetries onto the same channel to overcome this ambiguity.
- We use the standard **COCO dataset format**.

- This package contains **different backbones** (Unet-like, dilated CNN, Unet-like with pretrained ConvNeXt encoder). Furthermore you can easily add new backbones or loss functions. The head of the keypoint detector is a single CNN layer.
- The package uses the often-used **COCO dataset format**.
- The detector can deal with an **arbitrary number of keypoint channels**, that can contain **a varying amount of keypoints**. You can easily configure which keypoint types from the COCO dataset should be mapped onto the different channels of the keypoint detector.
- The package contains an implementation of the Average Precision metric for keypoint detection.
- Extensive **logging to wandb is provided**: The loss for each channel is logged, together with the AP metrics for all specified treshold distances. Furthermore, the raw heatmaps, detected keypoints and ground truth heatmaps are logged at every epoch for the first batch to provide insight in the training dynamics and to verify all data processing is as desired.
- **different backbones** can be used (Unet-like, dilated CNN, Unet-like with pretrained encoders). Furthermore you can easily add new backbones or loss functions. The head of the keypoint detector is a single CNN layer.

- The package contains an implementation of the Average Precision metric for keypoint detection. The threshold distance for classification of detections as FP or TP is based on L2 distance between the keypoints and ground truth keypoints.
- Extensive **logging to wandb is provided**: The train/val loss for each channel is logged, together with the AP metrics for all specified treshold distances and all channels. Furthermore, the raw heatmaps, detected keypoints and ground truth heatmaps are logged to provide insight in the training dynamics and to verify all data processing is as desired.
- All **hyperparameters are configurable** using a python argumentparser or wandb sweeps.

note: this is the second version of the package, for the older version that used a custom dataset format, see the github releases.
note: this package is still under development and we make no commitment on backwards compatibility nor reproducibility on the main branch. If you need this, it is best to pin a single commit.


TODO: add integration example.
Expand All @@ -43,7 +44,9 @@ For an example, see the `test_dataset` at `test/test_dataset`.


### Labeling
If you want to label data, we provide integration with the [CVAT](https://github.com/opencv/cvat) labeling tool: You can annotate your data and export it in their custom format, which can then be converted to COCO format. Take a look [here](labeling/Readme.md) for more information on this workflow and an example. To visualize a given dataset, you can use the `keypoint_detection/utils/visualization.py` script.
If you want to label data, we use[CVAT](https://github.com/opencv/cvat) labeling tool. The flow and the code to create COCO keypoints datasets is all available in the [airo-dataset-tools](https://github.com/airo-ugent/airo-mono/tree/main) package.

It is best to label your data with floats that represent the subpixel location of the keypoints. This allows for more precise resizing of the images later on. The keypoint detector cast them to ints before training to obtain the pixel they belong to (it does not support sub-pixel detections).

## Training

Expand All @@ -57,6 +60,27 @@ A minimal sweep example is given in `test/configuration.py`. The same content s

To create your own configuration: run `python train.py -h` to see all parameter options and their documentation.

## Metrics

TO calculate AP, precision or recall, the detections need to be classified into False Positives and False negatives as for object detection or instance segmentation.

This package simply uses a number of euclidian pixel distance thresholds. You can set the euclidian distances for which you want to calculate the metrics in the hyperparameters.

Pixel perfect keypoints have a pixel distance of 0, so if you want a metric for pixel-perfect keypoints you should add a threshold distance of 0.

Usually it is best to calculate the real-world deviations (in cm) that are acceptable and then determine the threshold(s) (in pixels) you are interested in.

In general a lower threshold will result in a lower metric. The size of this gap is determined by the 'ambiguity' of your dataset and/or the accuracy of your labels.

#TODO: add a figure to illustrate this.


We do not use OKS as in COCO for 2 reasons:
1. it requires bbox annotations, which are not always required for keypoint detection itself and represent additional label effort.
2. More importantly, in robotics the size of an object does not always correlate with the required precision. If a large and a small mug stand on a table, they require the same precise localisation of keypoints for a robot to grasp them even though their apparent size is different.
3. (you need to estimate label variance, though you could simply set k=1 and skip this part)


## Using a trained model (Inference)
During training Pytorch Lightning will have saved checkpoints. See `scripts/checkpoint_inference.py` for a simple example to run inference with a checkpoint.
For benchmarking the inference (or training), see `scripts/benchmark.py`.
Expand All @@ -67,7 +91,20 @@ For benchmarking the inference (or training), see `scripts/benchmark.py`.


## Note on performance
- Keep in mind that the Average Precision is a very expensive operation, it can easily take as long to calculate the AP of a .1 data split as it takes to train on the remaining 90% of the data. Therefore it makes sense to use the metric sparsely. The AP will always be calculated at the final epoch, so for optimal train performance (w/o intermediate feedback), you can e.g. set the `ap_epoch_start` parameter to your max number of epochs + 1.
- Keep in mind that calculating the Average Precision is expensive operation, it can easily take as long to calculate the AP of a .1 data split as it takes to train on the remaining 90% of the data. Therefore it makes sense to use the metric sparsely, for which hyperparameters are available. The AP will always be calculated at the final epoch.

## Note on top-down vs. bottom-up keypoint detection.
There are 2 ways to do keypoint detection when multiple instances are present in an image:
1. first do instance detection and then detect keypoints on a crop of the bbox for each instance
2. detect keypoints on the full image.

Option 1 suffers from compounding errors (if the instance is not detected, no keypoints will be detected) and/or requires you to train (and hence label) an object detector.
Option 2 can have lower performance for the keypoints (more 'noise' in the image that can distract the detector) and if you have multiple keypoints / instance as well as multiple instances per image, you need to do keypoint association.

This repo is somewhat agnostic to that choice.
For 1: crop your dataset upfront and train the detector on those crops, at inference: chain the object detector and the keypoint detector.
for 2: If you can do the association manually, simply do it after inference. However this repo does not offer learning the associations as in the [Part Affinity Fields]() paper.


## Rationale:
TODO
Expand Down
13 changes: 13 additions & 0 deletions keypoint_detection/data/coco_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import math
import typing
from collections import defaultdict
from pathlib import Path
Expand Down Expand Up @@ -97,6 +98,18 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, IMG_KEYPOINTS_TYPE]:
image = image[..., :3]

keypoints = self.dataset[index][1]

# convert all keypoints to integers values.
# COCO keypoints can be floats if they specify the exact location of the keypoint (e.g. from CVAT)
# even though COCO format specifies zero-indexed integers (i.e. every keypoint in the [0,1]x [0.1] pixel box becomes (0,0)
# we convert them to ints here, as the heatmap generation will add a 0.5 offset to the keypoint location to center it in the pixel
# the distance metrics also operate on integer values.

# so basically from here on every keypoint is an int that represents the pixel-box in which the keypoint is located.
keypoints = [
[[math.floor(keypoint[0]), math.floor(keypoint[1])] for keypoint in channel_keypoints]
for channel_keypoints in keypoints
]
if self.transform:
transformed = self.transform(image=image, keypoints=keypoints)
image, keypoints = transformed["image"], transformed["keypoints"]
Expand Down
2 changes: 2 additions & 0 deletions keypoint_detection/data/coco_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class CocoKeypointAnnotation(BaseModel):
image_id: ImageID

num_keypoints: Optional[int]
# COCO keypoints can be floats if they specify the exact location of the keypoint (e.g. from CVAT)
# even though COCO format specifies zero-indexed integers (i.e. every keypoint in the [0,1]x [0.1] pixel box becomes (0,0)
keypoints: List[float]

# TODO: add checks.
Expand Down
4 changes: 2 additions & 2 deletions keypoint_detection/models/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
# parse the gt pixel distances
if isinstance(maximal_gt_keypoint_pixel_distances, str):
maximal_gt_keypoint_pixel_distances = [
float(val) for val in maximal_gt_keypoint_pixel_distances.strip().split(" ")
int(val) for val in maximal_gt_keypoint_pixel_distances.strip().split(" ")
]
self.maximal_gt_keypoint_pixel_distances = maximal_gt_keypoint_pixel_distances

Expand Down Expand Up @@ -395,7 +395,7 @@ def compute_and_log_metrics_for_channel(
ap_metrics = metrics.compute()
print(f"{ap_metrics=}")
for maximal_distance, ap in ap_metrics.items():
self.log(f"{training_mode}/{channel}_ap/d={maximal_distance}", ap)
self.log(f"{training_mode}/{channel}_ap/d={float(maximal_distance):.1f}", ap)

mean_ap = sum(ap_metrics.values()) / len(ap_metrics.values())

Expand Down
11 changes: 6 additions & 5 deletions keypoint_detection/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ class ClassifiedKeypoint(DetectedKeypoint):
unsafe_hash -> dirty fix to allow for hash w/o explictly telling python the object is immutable.
"""

threshold_distance: float
threshold_distance: int
true_positive: bool


def keypoint_classification(
detected_keypoints: List[DetectedKeypoint],
ground_truth_keypoints: List[Keypoint],
threshold_distance: float,
threshold_distance: int,
) -> List[ClassifiedKeypoint]:
"""Classifies keypoints of a **single** frame in True Positives or False Positives by searching for unused gt keypoints in prediction probability order
that are within distance d of the detected keypoint.
that are within distance d of the detected keypoint (greedy matching).
Args:
detected_keypoints (List[DetectedKeypoint]): The detected keypoints in the frame
Expand All @@ -73,7 +73,8 @@ def keypoint_classification(
matched = False
for gt_keypoint in ground_truth_keypoints:
distance = detected_keypoint.l2_distance(gt_keypoint)
if distance < threshold_distance:
# add small epsilon to avoid numerical errors
if distance <= threshold_distance + 1e-5:
classified_keypoint = ClassifiedKeypoint(
detected_keypoint.u,
detected_keypoint.v,
Expand Down Expand Up @@ -209,7 +210,7 @@ class KeypointAPMetrics(Metric):

full_state_update = False

def __init__(self, keypoint_threshold_distances: List[float], dist_sync_on_step=False):
def __init__(self, keypoint_threshold_distances: List[int], dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.ap_metrics = [KeypointAPMetric(dst, dist_sync_on_step) for dst in keypoint_threshold_distances]
Expand Down
4 changes: 0 additions & 4 deletions keypoint_detection/utils/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ def generate_channel_heatmap(
Torch.tensor: A Tensor with the combined heatmaps of all keypoints.
"""

# cast keypoints (center) to ints to make grid align with pixel raster.
# Otherwise, the AP metric for d = 1 will not result in 1
# if the gt_heatmaps are used as input.

assert isinstance(keypoints, torch.Tensor)

if keypoints.numel() == 0:
Expand Down
161 changes: 161 additions & 0 deletions scripts/generate_dataset.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate a COCO keypoints dataset of black images with circles on it for integration testing of the keypoint detector. \n"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: distinctipy in /fast_storage_2/symlinked_homes/tlips/conda/.conda/envs/keypoint-detection/lib/python3.9/site-packages (1.2.2)\n",
"Requirement already satisfied: numpy in /home/tlips/.local/lib/python3.9/site-packages (from distinctipy) (1.25.2)\n"
]
}
],
"source": [
"import cv2\n",
"import numpy as np \n",
"from airo_dataset_tools.data_parsers.coco import CocoKeypointAnnotation, CocoImage, CocoKeypointCategory, CocoKeypointsDataset\n",
"import pathlib\n",
"!pip install distinctipy\n",
"import distinctipy"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [],
"source": [
"n_images = 500\n",
"n_categories = 2\n",
"max_category_instances_per_image = 2\n",
"\n",
"image_resolution = (128, 128)\n",
"circle_radius = 1"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"DATA_DIR = pathlib.Path(\"./dummy_dataset\")\n",
"DATA_DIR.mkdir(exist_ok=True)\n",
"IMAGE_DIR = DATA_DIR / \"images\"\n",
"IMAGE_DIR.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"categories = []\n",
"for category_idx in range(n_categories):\n",
" coco_category = CocoKeypointCategory(\n",
" id=category_idx,\n",
" name=f\"dummy{category_idx}\",\n",
" supercategory=f\"dummy{category_idx}\",\n",
" keypoints=[f\"dummy{category_idx}\"]\n",
" )\n",
" categories.append(coco_category)"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"category_colors = distinctipy.get_colors(n_categories)\n",
"category_colors = [tuple([int(c * 255) for c in color]) for color in category_colors]"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [],
"source": [
"coco_images = []\n",
"cococ_annotations = []\n",
"\n",
"coco_instances_coutner = 0\n",
"for image_idx in range(n_images):\n",
" img = np.zeros((image_resolution[1],image_resolution[0],3), dtype=np.uint8)\n",
" coco_images.append(CocoImage(id=image_idx, file_name=f\"images/img_{image_idx}.png\", height=image_resolution[1], width=image_resolution[0]))\n",
" for category_idx in range(n_categories):\n",
" n_instances = np.random.randint(0, max_category_instances_per_image+1)\n",
" for instance_idx in range(n_instances):\n",
" x = np.random.randint(2, image_resolution[0])\n",
" y = np.random.randint(2, image_resolution[1])\n",
" img = cv2.circle(img, (x, y), circle_radius, category_colors[category_idx], -1)\n",
" cococ_annotations.append(CocoKeypointAnnotation(\n",
" id=coco_instances_coutner,\n",
" image_id=image_idx,\n",
" category_id=category_idx,\n",
" # as in coco datasets: zero-index, INT keypoints.\n",
" # but add some random noise (simulating dataset with the exact pixel location instead of the zero-index int location)\n",
" # to test if the detector can deal with this\n",
" keypoints=[x + np.random.rand(1).item(), y + np.random.rand(1).item(), 1],\n",
" num_keypoints=1,\n",
" ))\n",
" coco_instances_coutner += 1\n",
"\n",
" cv2.imwrite(str(DATA_DIR / \"images\"/f\"img_{image_idx}.png\"), img)\n",
"\n",
"coco_dataset = CocoKeypointsDataset(\n",
" images=coco_images,\n",
" annotations=cococ_annotations,\n",
" categories=categories,\n",
")\n",
"\n",
"with open(DATA_DIR / \"dummy_dataset.json\", \"w\") as f:\n",
" f.write(coco_dataset.json())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "keypoint-detection",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 9f39a8e

Please sign in to comment.