This project aims to provide the necessary building blocks for easily creating detection and segmentation models using PyTorch 1.0.
- PyTorch 1.0: RPN, Faster R-CNN and Mask R-CNN implementations that match or exceed Detectron accuracies
- Very fast: up to 2x faster than Detectron and 30% faster than mmdetection during training. See MODEL_ZOO.md for more details.
- Memory efficient: uses roughly 500MB less GPU memory than mmdetection during training
- Multi-GPU training and inference
- Mixed precision training: trains faster with less GPU memory on NVIDIA tensor cores.
- Batched inference: can perform inference using multiple images per batch per GPU
- CPU support for inference: runs on CPU at inference time. See our webcam demo for an example
- Provides pre-trained models for almost all reference Mask R-CNN and Faster R-CNN configurations with 1x schedule.
We provide a simple webcam demo that illustrates how you can use maskrcnn_benchmark for inference:
```bash
cd demo
# by default, it runs on the GPU
# for best results, use min-image-size 800
python webcam.py --min-image-size 800
# can also run it on the CPU
python webcam.py --min-image-size 300 MODEL.DEVICE cpu
# or change the model that you want to use
python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
# in order to see the probability heatmaps, pass --show-mask-heatmaps
python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu
# for the keypoint demo
python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu
```
A notebook with the demo can be found in demo/Mask_R-CNN_demo.ipynb.
Check INSTALL.md for installation instructions.
Pre-trained models, baselines and comparison with Detectron and mmdetection can be found in MODEL_ZOO.md
We provide a helper class to simplify writing inference pipelines using pre-trained models. Here is how we would do it (run this from the demo folder):
```python
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo

config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"

# update the config options with the config file
cfg.merge_from_file(config_file)
# manually override some options
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])

coco_demo = COCODemo(
    cfg,
    min_image_size=800,
    confidence_threshold=0.7,
)
# load image (in OpenCV format) and then run prediction
image = ...
predictions = coco_demo.run_on_opencv_image(image)
```
For the following examples to work, you first need to install maskrcnn_benchmark. You will also need to download the COCO dataset. We recommend symlinking the path to the COCO dataset into datasets/ as follows.

We use the minival and valminusminival sets from Detectron.
```bash
# symlink the coco dataset
cd ~/github/maskrcnn-benchmark
mkdir -p datasets/coco
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014
ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014
ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014
# or use COCO 2017 version
ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017
ln -s /path_to_coco_dataset/test2017 datasets/coco/test2017
ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017

# for pascal voc dataset:
ln -s /path_to_VOCdevkit_dir datasets/voc
```
P.S. COCO_2017_train = COCO_2014_train + valminusminival, and COCO_2017_val = minival.
You can also configure your own paths to the datasets. To do so, modify maskrcnn_benchmark/config/paths_catalog.py to point to the location where your dataset is stored. Alternatively, you can create a new paths_catalog.py file that implements the same two classes and pass it as the config argument PATHS_CATALOG during training.
Most of the configuration files that we provide assume that we are running on 8 GPUs. To run them on fewer GPUs, there are a few options:
1. Run the following without modifications:

```bash
python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml"
```
This should work out of the box and is very similar to what we do for multi-GPU training. The drawback is that it uses much more GPU memory: the configuration files set a global batch size that is divided over the number of GPUs, so on a single GPU the per-GPU batch size is 8x larger, which might lead to out-of-memory errors.
If you have a lot of memory available, this is the easiest solution.
2. Modify the cfg parameters

If you experience out-of-memory errors, you can reduce the global batch size. However, you will then also need to adjust the learning rate, the number of iterations and the learning rate schedule.

Here is an example for Mask R-CNN R-50 FPN with the 1x schedule:

```bash
python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1 MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN 2000
```
This follows the scheduling rules from Detectron. Note that we have multiplied the number of iterations (as well as the learning rate decay steps in SOLVER.STEPS) by 8 and divided the learning rate by 8, matching the 8x reduction of the global batch size.
We also changed the batch size during testing, but that is generally not necessary because testing requires much less memory than training.
Furthermore, we set MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN 2000 because, in the default training setup, proposals are selected per batch rather than per image. The value is computed as 1000 x images per GPU: here we have 2 images per GPU, so we set it to 1000 x 2 = 2000; with 8 images per GPU, it should be 8000. Note that this does not apply if MODEL.RPN.FPN_POST_NMS_PER_BATCH is set to False during training. See #672 for more details.
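The arithmetic behind option 2 can be sketched in Python. `single_gpu_overrides` is a hypothetical helper, not part of maskrcnn_benchmark: with k = old global batch / new global batch, it divides the learning rate by k, multiplies MAX_ITER and STEPS by k, and sets the RPN post-NMS top-N to 1000 x images per GPU, assuming FPN_POST_NMS_PER_BATCH keeps its per-batch behaviour. The reference values 16 / 0.02 / 90000 / (60000, 80000) are just the single-GPU values above scaled back by 8.

```python
import shlex
from fractions import Fraction


def single_gpu_overrides(ims_per_batch, base_lr, max_iter, steps,
                         new_ims_per_batch, num_gpus=1):
    """Hypothetical helper: rescale solver options for a new global batch size.

    Returns a flat KEY VALUE list in the form accepted on the command line.
    """
    # k > 1 means a smaller batch: lower LR, proportionally more iterations
    k = Fraction(ims_per_batch, new_ims_per_batch)
    new_max_iter = max_iter * k
    new_steps = [s * k for s in steps]
    if new_max_iter.denominator != 1 or any(s.denominator != 1 for s in new_steps):
        raise ValueError("batch-size ratio does not give whole iteration counts")
    if new_ims_per_batch % num_gpus != 0:
        raise ValueError("global batch size must split evenly across GPUs")
    images_per_gpu = new_ims_per_batch // num_gpus
    steps_str = "({})".format(", ".join(str(int(s)) for s in new_steps))
    return [
        "SOLVER.IMS_PER_BATCH", str(new_ims_per_batch),
        "SOLVER.BASE_LR", str(base_lr / float(k)),
        "SOLVER.MAX_ITER", str(int(new_max_iter)),
        "SOLVER.STEPS", steps_str,
        # per-batch proposal selection: 1000 proposals per image on each GPU
        "MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN", str(1000 * images_per_gpu),
    ]


# Reproduces the solver and RPN overrides of the single-GPU command above;
# shlex.quote keeps the STEPS tuple as a single shell argument.
overrides = single_gpu_overrides(16, 0.02, 90000, (60000, 80000), new_ims_per_batch=2)
print(" ".join(shlex.quote(arg) for arg in overrides))
```

Calling it with `new_ims_per_batch=8` gives the 8000 value mentioned above, together with a halved schedule and learning rate.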
Internally, we use torch.distributed.launch to launch multi-GPU training. This PyTorch utility spawns as many Python processes as the number of GPUs we want to use, and each Python process uses only a single GPU.
```bash
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml" MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN images_per_gpu x 1000
```
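Note that MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN should be set following the rule from single-GPU training: replace images_per_gpu x 1000 with the actual number (for example, 2000 for 2 images per GPU).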
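For scripted runs, the multi-GPU command can also be assembled programmatically so that the post-NMS value follows from the GPU count. This is a sketch with a hypothetical `launch_command` helper; it assumes the global SOLVER.IMS_PER_BATCH must split evenly across GPUs and only builds the argument list (passing it to `subprocess.run` would start training).

```python
import shlex
import sys


def launch_command(ngpus, train_script, config_file, ims_per_batch, extra_opts=()):
    """Hypothetical helper: build a torch.distributed.launch argument list."""
    if ims_per_batch % ngpus != 0:
        raise ValueError("SOLVER.IMS_PER_BATCH must be divisible by the number of GPUs")
    images_per_gpu = ims_per_batch // ngpus
    return [
        sys.executable, "-m", "torch.distributed.launch",
        "--nproc_per_node={}".format(ngpus),
        train_script,
        "--config-file", config_file,
        "SOLVER.IMS_PER_BATCH", str(ims_per_batch),
        # rule from single-GPU training: 1000 x images per GPU
        "MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN", str(1000 * images_per_gpu),
        *extra_opts,
    ]


cmd = launch_command(
    8,
    "/path_to_maskrcnn_benchmark/tools/train_net.py",
    "path/to/config/file.yaml",
    ims_per_batch=16,
)
# print a copy-pasteable shell command
print(" ".join(shlex.quote(arg) for arg in cmd))
```

Further config overrides, such as the DTYPE option used for mixed precision, can be appended through `extra_opts`, e.g. `extra_opts=("DTYPE", "float16")`.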
We currently use APEX to add Automatic Mixed Precision support. To enable it, run single-GPU or multi-GPU training as above and set DTYPE "float16".
```bash
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml" MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN images_per_gpu x 1000 DTYPE "float16"
```
If you want more verbose logging, set AMP_VERBOSE True. See the Mixed Precision Training guide for more details.
You can test your model directly on a single GPU or on multiple GPUs. Here is an example for Mask R-CNN R-50 FPN with the 1x schedule on 8 GPUs:
```bash
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/test_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" TEST.IMS_PER_BATCH 16
```
To calculate mAP for each class, you can simply modify a few lines in coco_eval.py. See #524 for more details.
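One way to get per-class numbers, sketched under the assumption that evaluation goes through pycocotools' COCOeval (the change discussed in #524 may differ): after `accumulate()`, `coco_eval.eval["precision"]` holds an array indexed as [IoU threshold, recall threshold, category, area range, max detections], with -1 wherever a setting has no ground truth. Averaging the valid entries of one category, for the "all" area range and the last max-detections setting, mirrors how `COCOeval.summarize()` computes the overall AP. The function name `per_class_ap` is hypothetical.

```python
import numpy as np


def per_class_ap(precision, cat_ids, area_idx=0, max_det_idx=-1):
    """Per-category AP from a COCOeval eval["precision"] array.

    Assumes the pycocotools layout [T, R, K, A, M]. Entries equal to -1
    (no ground truth for that setting) are ignored; a category with no
    valid entries gets NaN.
    """
    results = {}
    for k, cat_id in enumerate(cat_ids):
        p = precision[:, :, k, area_idx, max_det_idx]
        valid = p[p > -1]
        results[cat_id] = float(np.mean(valid)) if valid.size else float("nan")
    return results
```

With a real evaluator this would be called as `per_class_ap(coco_eval.eval["precision"], coco_eval.params.catIds)`.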
For more information on some of the main abstractions in our implementation, see ABSTRACTIONS.md.
This implementation adds support for COCO-style datasets. Adding support for training on a new dataset can be done as follows:
```python
import torch

from maskrcnn_benchmark.structures.bounding_box import BoxList


class MyDataset(object):
    def __init__(self, transforms=None):
        # as you would do normally; keep the transforms used in __getitem__
        self.transforms = transforms

    def __getitem__(self, idx):
        # load the image as a PIL Image
        image = ...

        # load the bounding boxes as a list of list of boxes
        # in this case, for illustrative purposes, we use
        # x1, y1, x2, y2 order.
        boxes = [[0, 0, 10, 10], [10, 20, 50, 50]]
        # and labels
        labels = torch.tensor([10, 20])

        # create a BoxList from the boxes
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        # add the labels to the boxlist
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx

    def __len__(self):
        # number of images in the dataset (PyTorch samplers call len())
        return ...

    def get_img_info(self, idx):
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        img_height, img_width = ...
        return {"height": img_height, "width": img_width}
```
That's it. You can also add extra fields to the boxlist, such as segmentation masks (using structures.segmentation_mask.SegmentationMask), or even your own instance type.

For a full example of how the COCODataset is implemented, check maskrcnn_benchmark/data/datasets/coco.py.
Once you have created your dataset, it needs to be added in a couple of places:
- maskrcnn_benchmark/data/datasets/__init__.py: add it to __all__
- maskrcnn_benchmark/config/paths_catalog.py: add it to DatasetCatalog.DATASETS and add a corresponding if clause in DatasetCatalog.get()
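A rough stand-alone sketch of the paths_catalog.py change, for a hypothetical dataset registered as `my_dataset_train`. The directory layout, the `img_dir`/`ann_file` keys and the `factory`/`args` return shape are assumptions meant to mirror the existing COCO entries; check them against your copy of paths_catalog.py, where the real `get()` also handles the built-in datasets. We also assume the entries in `args` become keyword arguments of the dataset class named by `factory`, so `MyDataset` would need to accept `root` and `ann_file`.

```python
import os


class DatasetCatalog(object):
    """Stand-alone sketch; the real class lives in paths_catalog.py."""

    DATA_DIR = "datasets"
    DATASETS = {
        # hypothetical entry for a custom dataset
        "my_dataset_train": {
            "img_dir": "my_dataset/images",
            "ann_file": "my_dataset/annotations/train.json",
        },
    }

    @staticmethod
    def get(name):
        # the "corresponding if clause" for the new dataset
        if name.startswith("my_dataset"):
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=os.path.join(DatasetCatalog.DATA_DIR, attrs["img_dir"]),
                ann_file=os.path.join(DatasetCatalog.DATA_DIR, attrs["ann_file"]),
            )
            # "factory" names the class exported from data/datasets/__init__.py
            return dict(factory="MyDataset", args=args)
        raise RuntimeError("Dataset not available: {}".format(name))
```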
While the example above should work for training, we use the cocoApi to compute accuracies during testing. Thus, test datasets should currently follow the cocoApi format.

To enable your dataset for testing, add a corresponding if statement in maskrcnn_benchmark/data/datasets/evaluation/__init__.py:
```python
if isinstance(dataset, datasets.MyDataset):
    return coco_evaluation(**args)
```
Create a script tools/trim_detectron_model.py like here. You can decide which keys to remove and which keys to keep by modifying the script. Then you can simply point the config file to the converted model path by changing MODEL.WEIGHT.
For further information, please refer to #15.
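The heart of such a trimming script is a filter over the checkpoint's parameter names. Below is a minimal sketch, assuming the converted weights are already loaded as a dict mapping parameter names to tensors (for example with `torch.load`); `trim_state_dict` is a hypothetical helper, and the substrings passed to it are illustrative examples of class-specific output layers that you would typically drop when your dataset has a different number of classes. Saving the result and pointing MODEL.WEIGHT at it are left out.

```python
def trim_state_dict(state_dict, drop_substrings):
    """Split a name -> tensor mapping into kept entries and removed key names."""
    kept, removed = {}, []
    for name, value in state_dict.items():
        # drop any parameter whose name contains one of the given substrings
        if any(s in name for s in drop_substrings):
            removed.append(name)
        else:
            kept[name] = value
    return kept, removed


# Hypothetical parameter names, used only to illustrate the filtering.
weights = {
    "backbone.body.stem.conv1.weight": 0,
    "roi_heads.box.predictor.cls_score.weight": 1,
    "roi_heads.box.predictor.bbox_pred.bias": 2,
}
kept, removed = trim_state_dict(weights, ["cls_score", "bbox_pred", "mask_fcn_logits"])
print(sorted(kept), removed)
```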
If you have issues running or compiling this code, we have compiled a list of common issues in TROUBLESHOOTING.md. If your issue is not present there, please feel free to open a new issue.
Please consider citing this project in your publications if it helps your research. The following is a BibTeX reference. The BibTeX entry requires the url LaTeX package.
@misc{massa2018mrcnn,
author = {Massa, Francisco and Girshick, Ross},
title = {{maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch}},
year = {2018},
howpublished = {\url{https://github.com/facebookresearch/maskrcnn-benchmark}},
note = {Accessed: [Insert date here]}
}
Projects using maskrcnn-benchmark:
- RetinaMask: Learning to predict masks improves state-of-the-art single-shot detection for free. Cheng-Yang Fu, Mykhailo Shvets, and Alexander C. Berg. Tech report, arXiv:1901.03353.
- FCOS: Fully Convolutional One-Stage Object Detection. Zhi Tian, Chunhua Shen, Hao Chen and Tong He. Tech report, arXiv:1904.01355.
maskrcnn-benchmark is released under the MIT license. See LICENSE for additional details.