[CVPR'2022] SAM-DETR & SAM-DETR++: Official PyTorch Implementation

SAM-DETR (Semantic-Aligned-Matching DETR)

This repository is an official PyTorch implementation of the CVPR 2022 paper "Accelerating DETR Convergence via Semantic-Aligned Matching".

[UPDATE on 21 Apr 2022]   We found that with a very simple modification (with no extra computational cost), SAM-DETR can achieve better performance. On MS-COCO, SAM-DETR w/ SMCA can achieve 37.0 AP within 12 epochs, and 42.7 AP within 50 epochs. We will release the updated training scripts, model weights, and logs in the future. Please stay tuned!

Introduction

TL;DR   SAM-DETR is an efficient DETR-like object detector that can converge within 12 epochs and outperform the strong Faster R-CNN (w/ FPN) baseline.

The recently developed DEtection TRansformer (DETR) has established a new object detection paradigm by eliminating a series of hand-crafted components. However, DETR suffers from extremely slow convergence, which increases the training cost significantly. We observe that the slow convergence can be largely attributed to the complication in matching object queries to encoded image features in DETR's decoder cross-attention modules.

Motivated by this observation, in our paper, we propose SAM-DETR, a Semantic-Aligned-Matching DETR that greatly accelerates DETR's convergence without sacrificing its accuracy. SAM-DETR addresses the slow convergence issue from two perspectives. First, it projects object queries into the same embedding space as encoded image features, where the matching can be accomplished efficiently with aligned semantics. Second, it explicitly searches for salient points with the most discriminative features for semantic-aligned matching, which further speeds up convergence and also boosts detection accuracy. As a plug-and-play module, SAM-DETR complements existing convergence solutions well while introducing only slight computational overhead. Experiments show that the proposed SAM-DETR achieves superior convergence as well as competitive detection accuracy.
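
To make the matching intuition concrete, here is a minimal sketch of how attention weights arise from scaled dot products between one query and a set of feature vectors. It is plain Python on toy 2-D vectors; the function name `attention_weights` and all numbers are illustrative assumptions, not code from this repository.

```python
import math

def attention_weights(query, keys):
    """Softmax over scaled dot products between one query and each key."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy encoded features for three image locations (hypothetical values).
features = [[4.0, 0.0], [0.0, 4.0], [-4.0, 0.0]]

# A query living in the same space as the features, e.g. one equal to the
# feature at location 1.
aligned_query = [0.0, 4.0]
weights = attention_weights(aligned_query, features)
```

Because the query shares the features' embedding space, nearly all of the attention mass falls on the location it resembles. A query from an unrelated space carries no such guarantee, which is the matching difficulty described above.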

At the core of SAM-DETR is a plug-and-play module named "Semantics Aligner", inserted before the cross-attention module in each of DETR's decoder layers. SAM-DETR also models a learnable reference box for each object query, whose center location is used to generate the corresponding position embeddings.

The figure below illustrates the architecture of the appended "Semantics Aligner", which aligns the semantics of "encoded image features" and "object queries" by re-sampling features from multiple salient points as new object queries.
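
The re-sampling idea can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the feature map is a nested H x W x C list, the salient points are given as normalized (x, y) pairs, and the sampled features are simply concatenated. The helper names `bilinear_sample` and `resample_query` are hypothetical and do not correspond to functions in this repository.

```python
import math

def bilinear_sample(feature_map, x, y):
    """Sample a C-dim feature at normalized (x, y) in [0, 1] from an H x W x C nested list."""
    height, width = len(feature_map), len(feature_map[0])
    # Map normalized coordinates to continuous grid indices (corner-aligned convention).
    fx = min(max(x, 0.0), 1.0) * (width - 1)
    fy = min(max(y, 0.0), 1.0) * (height - 1)
    x0, y0 = int(math.floor(fx)), int(math.floor(fy))
    x1, y1 = min(x0 + 1, width - 1), min(y0 + 1, height - 1)
    wx, wy = fx - x0, fy - y0
    channels = len(feature_map[0][0])
    return [
        (1 - wy) * ((1 - wx) * feature_map[y0][x0][c] + wx * feature_map[y0][x1][c])
        + wy * ((1 - wx) * feature_map[y1][x0][c] + wx * feature_map[y1][x1][c])
        for c in range(channels)
    ]

def resample_query(feature_map, salient_points):
    """Concatenate features sampled at each salient point (assumed aggregation)."""
    query = []
    for x, y in salient_points:
        query.extend(bilinear_sample(feature_map, x, y))
    return query

# Toy 2 x 2 feature map with one channel.
toy_map = [[[0.0], [1.0]], [[2.0], [3.0]]]
new_query = resample_query(toy_map, [(0.0, 0.0), (1.0, 1.0)])
```

Since the new query is built directly from encoded image features, it lies in the same embedding space as the keys it is later matched against.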

Being plug-and-play, our approach can be easily integrated with existing convergence solutions (e.g., SMCA) in a complementary manner, further boosting detection accuracy and convergence speed.

Please check our CVPR 2022 paper for more details.

Installation

Pre-Requisites

You must have NVIDIA GPUs to run the code.

The code was developed and tested with the following environment:

  • Linux
  • 8x NVIDIA V100 GPUs (32GB)
  • CUDA 10.1
  • Python == 3.8
  • PyTorch == 1.8.1+cu101, TorchVision == 0.9.1+cu101
  • GCC == 7.5.0
  • cython, pycocotools, tqdm, scipy

We recommend using the exact setup above. However, other environments (Linux, Python>=3.7, CUDA>=9.2, GCC>=5.4, PyTorch>=1.5.1, TorchVision>=0.6.1) should also work.
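
As a quick sanity check, the stated minimum versions can be compared against an installed environment. The sketch below is a hypothetical helper (not part of this repository) that parses version strings such as `1.8.1+cu101` and reports which packages fall below the minimums listed above.

```python
import re

# Minimum versions quoted in the paragraph above.
MINIMUMS = {"python": "3.7", "torch": "1.5.1", "torchvision": "0.6.1"}

def parse_version(text):
    """Turn '1.8.1+cu101' into (1, 8, 1); the local build tag after '+' is ignored."""
    core = text.split("+", 1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", core)[:3])

def meets_minimum(installed, minimum):
    return parse_version(installed) >= parse_version(minimum)

def check_environment(installed_versions):
    """Return the names of packages that fall below the stated minimums."""
    return [name for name, minimum in MINIMUMS.items()
            if not meets_minimum(installed_versions.get(name, "0"), minimum)]

too_old = check_environment(
    {"python": "3.8.10", "torch": "1.8.1+cu101", "torchvision": "0.9.1+cu101"}
)
```

In practice the dictionary could be filled from `platform.python_version()`, `torch.__version__`, and `torchvision.__version__`.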

Code Installation

First, clone the repository locally:

git clone https://github.com/ZhangGongjie/SAM-DETR.git

We recommend using Anaconda to create a conda environment:

conda create -n sam_detr python=3.8 pip

Then, activate the environment:

conda activate sam_detr

Then, install PyTorch and TorchVision:

(preferably using our recommended setups; CUDA version should match your own local environment)

conda install pytorch=1.8.1 torchvision=0.9.1 cudatoolkit=10.1 -c pytorch

After that, install other requirements:

conda install cython scipy tqdm
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

[Optional]   If you wish to run the multi-scale version of SAM-DETR (results not reported in the CVPR paper), you need to compile Deformable Attention, which is used in the DETR encoder to generate the feature pyramid efficiently. If you don't need the multi-scale version of SAM-DETR, you may skip this step.

# Optionally compile CUDA operators of Deformable Attention for multi-scale SAM-DETR
cd SAM-DETR
cd ./models/ops
sh ./make.sh
python test.py  # unit test (should see all checking is True)

Data Preparation

Please download the COCO 2017 dataset and organize it as follows:

code_root/
└── data/
    └── coco/
        ├── train2017/
        ├── val2017/
        └── annotations/
            ├── instances_train2017.json
            └── instances_val2017.json
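
Before launching a long training run, it can help to confirm that the layout above is in place. The following is a minimal sketch of such a check; `missing_coco_entries` is a hypothetical helper, not a script shipped with this repository.

```python
from pathlib import Path

# Paths expected under <code_root>/data/coco, taken from the tree above.
EXPECTED = [
    "train2017",
    "val2017",
    "annotations/instances_train2017.json",
    "annotations/instances_val2017.json",
]

def missing_coco_entries(code_root):
    """List expected COCO paths that are absent under <code_root>/data/coco."""
    coco_root = Path(code_root) / "data" / "coco"
    return [entry for entry in EXPECTED if not (coco_root / entry).exists()]
```

Calling `missing_coco_entries(".")` from the repository root returns an empty list when everything is where the tree shows it.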

Usage

Reproducing Paper Results

All scripts to reproduce the results reported in our CVPR 2022 paper are stored in ./scripts. We also provide scripts for Slurm clusters, which are stored in ./scripts_slurm.

For example, to reproduce the results of SAM-DETR-R50 w/ SMCA (12 epochs), simply run:

bash scripts/r50_smca_e12_4gpu.sh

As another example, to reproduce the results of SAM-DETR-R50 multiscale w/ SMCA (50 epochs) on a Slurm cluster, simply run:

bash scripts_slurm/r50_ms_smca_e50_8gpu.sh

Reminder: To reproduce results, please make sure the total batch size matches the implementation details described in our paper. For R50 (single-scale) experiments, we use 4 GPUs with a batch size of 4 on each GPU. For R50 (multi-scale) experiments, we use 8 GPUs with a batch size of 2 on each GPU. For R50-DC5 (single-scale) experiments, we use 8 GPUs with a batch size of 1 on each GPU.
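
The total batch size is simply the number of GPUs times the per-GPU batch size. The sketch below works this out for the three setups in the reminder and shows how to pick a per-GPU batch size when running on a different number of GPUs; the helper names are illustrative assumptions.

```python
# (number of GPUs, per-GPU batch size) as stated in the reminder above.
SETUPS = {
    "R50 (single-scale)": (4, 4),
    "R50 (multi-scale)": (8, 2),
    "R50-DC5 (single-scale)": (8, 1),
}

def total_batch_size(num_gpus, per_gpu_batch):
    return num_gpus * per_gpu_batch

def per_gpu_batch_for(target_total, num_gpus):
    """Per-GPU batch size that keeps the total fixed on a different GPU count."""
    if target_total % num_gpus:
        raise ValueError(f"{target_total} is not divisible by {num_gpus} GPUs")
    return target_total // num_gpus

totals = {name: total_batch_size(*cfg) for name, cfg in SETUPS.items()}
```

For instance, keeping the single-scale R50 total of 16 on 2 GPUs would require `per_gpu_batch_for(16, 2)`, i.e. 8 images per GPU, provided they fit in memory.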

Training

To perform training on COCO train2017, adapt the arguments in the command below:

# --nproc_per_node : number of GPUs used for training
# --batch_size     : batch size on each individual GPU (this is *NOT* the total batch size)
# --smca           : integrate with SMCA (remove this line to disable SMCA)
# --dilation       : enable DC5 (remove this line to disable DC5)
# --multiscale     : enable multi-scale (remove this line to disable multi-scale)
# --epochs         : total number of epochs to train
# --lr_drop        : epoch at which the learning rate is dropped
# --output_dir     : where to store outputs (remove this line to store nothing)
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --use_env main.py \
    --batch_size 4 \
    --smca \
    --dilation \
    --multiscale \
    --epochs 50 \
    --lr_drop 40 \
    --output_dir output/xxxx

More arguments and their explanations can be found in main.py.
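
When sweeping over configurations, it can be convenient to assemble the training command programmatically. The sketch below builds the argument list using only the flags shown in the command above; `build_train_command` is a hypothetical helper, and the values passed to it are the placeholders from this README.

```python
import shlex

def build_train_command(num_gpus, batch_size, epochs, lr_drop,
                        smca=False, dilation=False, multiscale=False,
                        output_dir=None):
    """Assemble the launch command as an argument list (flags as shown in this README)."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={num_gpus}", "--use_env", "main.py",
           "--batch_size", str(batch_size),
           "--epochs", str(epochs),
           "--lr_drop", str(lr_drop)]
    # Boolean switches are appended only when enabled.
    for enabled, flag in ((smca, "--smca"), (dilation, "--dilation"),
                          (multiscale, "--multiscale")):
        if enabled:
            cmd.append(flag)
    if output_dir is not None:
        cmd += ["--output_dir", output_dir]
    return cmd

cmd = build_train_command(4, 4, epochs=50, lr_drop=40, smca=True,
                          output_dir="output/xxxx")
command_line = shlex.join(cmd)  # requires Python 3.8+
```

The list form can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting issues.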

Evaluation

To evaluate a model on COCO val2017, simply add --resume and --eval arguments to your training scripts:

# --resume : path to the trained model weights
# --eval   : perform evaluation only
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --use_env main.py \
    --batch_size 4 \
    --smca \
    --dilation \
    --multiscale \
    --epochs 50 \
    --lr_drop 40 \
    --resume <path/to/checkpoint.pth> \
    --eval \
    --output_dir output/xxxx

Visualize Detection Results

We provide demo.py, a minimal implementation that lets users visualize a model's detection predictions. It performs detection on the images inside the ./images folder and stores the detection visualizations in that same folder. For example, for SAM-DETR-R50 w/ SMCA (50 epochs), simply run:

# Do NOT use distributed mode.
# --epochs must be set correctly; see models/fast_detr.py L50-79 for details.
# --resume points to the trained model weights.
python demo.py \
    --smca \
    --epochs 50 \
    --resume <path/to/checkpoint.pth>

Model Zoo

Trained model weights are stored in Google Drive.

The original DETR models trained for 500 epochs:

| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|---|---|---|---|---|---|
| DETR-R50 | 500 | 41 | 86 | 42.0 | log |
| DETR-R50-DC5 | 500 | 41 | 187 | 43.3 | log |

Our proposed SAM-DETR models (results reported in our CVPR paper):

| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|---|---|---|---|---|---|
| SAM-DETR-R50 | 12 | 58 | 100 | 33.1 | model, log |
| SAM-DETR-R50 w/ SMCA | 12 | 58 | 100 | 36.0 | model, log |
| SAM-DETR-R50-DC5 | 12 | 58 | 210 | 38.3 | model, log |
| SAM-DETR-R50-DC5 w/ SMCA | 12 | 58 | 210 | 40.6 | model, log |
| SAM-DETR-R50 | 50 | 58 | 100 | 39.8 | model, log |
| SAM-DETR-R50 w/ SMCA | 50 | 58 | 100 | 41.8 | model, log |
| SAM-DETR-R50-DC5 | 50 | 58 | 210 | 43.3 | model, log |
| SAM-DETR-R50-DC5 w/ SMCA | 50 | 58 | 210 | 45.0 | model, log |

Our proposed multi-scale SAM-DETR models (results to appear in a journal extension):

| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|---|---|---|---|---|---|
| SAM-DETR-R50-MS | 12 | 55 | 203 | 41.1 | model, log |
| SAM-DETR-R50-MS w/ SMCA | 12 | 55 | 203 | 42.8 | model, log |
| SAM-DETR-R50-MS | 50 | 55 | 203 | 46.1 | model, log |
| SAM-DETR-R50-MS w/ SMCA | 50 | 55 | 203 | 47.1 | model, log |

Note:

  1. AP is computed on COCO val2017.
  2. "DC5" means removing the stride in the C5 stage of ResNet and adding a dilation of 2 instead.
  3. The GFLOPs of our models are estimated using fvcore on the first 100 images in COCO val2017. GFLOPs vary with input image size, so the reported numbers may differ slightly from actual values.
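
To illustrate note 2: a standard ResNet reduces spatial resolution by a factor of 32 at C5, so replacing the C5 stride with dilation leaves the overall stride at 16 and doubles each side of the final feature map. The sketch below computes the resulting sizes; `c5_feature_size` is a hypothetical helper, and the 800 x 1216 input is only an example resolution.

```python
import math

def c5_feature_size(height, width, dc5=False):
    """Approximate spatial size of the last ResNet stage's output.

    A standard ResNet has a total stride of 32 at C5; with DC5 the C5 stride
    is replaced by dilation, so the total stride stays at 16.
    """
    stride = 16 if dc5 else 32
    return math.ceil(height / stride), math.ceil(width / stride)

plain = c5_feature_size(800, 1216)            # standard C5
dilated = c5_feature_size(800, 1216, dc5=True)  # DC5
token_ratio = (dilated[0] * dilated[1]) / (plain[0] * plain[1])
```

With DC5 the Transformer processes about four times as many feature locations, which is in line with the higher GFLOPs reported for the DC5 models in the tables above.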

License

The code of SAM-DETR is released under the MIT license.

Please see the LICENSE file for more information.

However, prior works' licenses also apply. It is your responsibility to ensure you comply with all license requirements.

Citation

If you find SAM-DETR useful or inspiring, please consider citing:

@inproceedings{zhang2022-SAMDETR,
  title      = {Accelerating {DETR} Convergence via Semantic-Aligned Matching},
  author     = {Zhang, Gongjie and Luo, Zhipeng and Yu, Yingchen and Cui, Kaiwen and Lu, Shijian},
  booktitle  = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages      = {949-958},
  year       = {2022},
}

Acknowledgement

Our SAM-DETR is heavily inspired by many outstanding prior works, including DETR, Conditional-DETR, SMCA-DETR, and Deformable DETR. We thank the authors of these projects for open-sourcing their implementations!