add SQR for Deformable DETR (#8579)
* add sqr
flytocc authored Sep 24, 2023
1 parent bb0a42e commit efcb6ad
Showing 8 changed files with 278 additions and 6 deletions.
27 changes: 27 additions & 0 deletions configs/sqr/README.md
@@ -0,0 +1,27 @@
# Enhanced Training of Query-Based Object Detection via Selective Query Recollection


## Introduction
This paper investigates a phenomenon where query-based object detectors mispredict at the last decoding stage while predicting correctly at an intermediate stage. Based on this observation, the authors design and present Selective Query Recollection (SQR), a simple and effective training strategy for query-based object detectors: as decoding goes deeper, SQR cumulatively collects the intermediate queries and selectively forwards them to the downstream stages, alongside the usual sequential structure.
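The recollection schedule can be sketched in a few lines. The toy code below only illustrates the recollection rule, not the committed implementation (that is `QRDeformableTransformerDecoder` in `ppdet/modeling/transformers/deformable_transformer.py`); the `start_q`/`end_q` windows are taken from the config in this folder, and `sqr_decode` and the stand-in layers are hypothetical names used for illustration.

```python
def sqr_decode(init_queries, decoder_layers, start_q, end_q):
    """Toy Selective Query Recollection: each decoder layer refines a selected
    window of all earlier query sets, not only the previous stage's output."""
    collected = [init_queries]   # every query set produced so far
    aux_outputs = []             # refined sets, each supervised during training
    for lid, layer in enumerate(decoder_layers):
        selected = collected[start_q[lid]:end_q[lid]]   # selective recollection
        refined = [layer(q) for q in selected]
        collected.extend(refined)                       # forward to later stages
        aux_outputs.extend(refined)
    return aux_outputs

# Windows from _base_/deformable_detr_sqr_r50.yml (first 6 entries, 6 layers):
start_q, end_q = [0, 0, 1, 2, 4, 7], [1, 2, 4, 7, 12, 20]
layers = [lambda q: q for _ in range(6)]     # stand-in decoder layers
outs = sqr_decode("q0", layers, start_q, end_q)
print(len(outs))   # 32 refined query sets are supervised; 33 sets collected overall
```

At test time the recollection is skipped and the decoder runs sequentially as usual, so inference cost is unchanged.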


## Model Zoo

| Backbone | Model | Images/GPU | GPUs | Epochs | Box AP | Config | Download |
|:--------:|:-------------------:|:----------:|:----:|:------:|:------:|:------------------------------------------------:|:---------:|
| R-50 | Deformable DETR SQR | 1 | 4 | 12 | 32.9 | [config](./deformable_detr_sqr_r50_12e_coco.yml) |[model](https://bj.bcebos.com/v1/paddledet/models/deformable_detr_sqr_r50_12e_coco.pdparams) |

> We did not find a config for the 12-epoch experiment in the paper, so we wrote one ourselves with reference to the standard 12-epoch config in mmdetection. With this [config](./deformable_detr_sqr_r50_12e_coco.yml), the official project and this project reach the same accuracy. <br> We have not finished validating the 50-epoch experiment yet; if you need that config, please refer to [here](https://pan.baidu.com/s/1eWavnAiRoFXm3mMlpn9WPw?pwd=3z6m).
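The config works with the regular PaddleDetection tooling. As a hedged illustration, the snippet below evaluates the released checkpoint through the Python `Trainer` API; it assumes PaddleDetection is installed and the COCO dataset is laid out as in `configs/datasets/coco_detection.yml` (the command-line entry point `tools/eval.py` with `-o weights=...` does the same thing).

```python
# Sketch only: evaluate the SQR Deformable DETR weights from the Model Zoo table.
from ppdet.core.workspace import load_config
from ppdet.engine import Trainer

cfg = load_config('configs/sqr/deformable_detr_sqr_r50_12e_coco.yml')
trainer = Trainer(cfg, mode='eval')
# URL taken from the table above; a local .pdparams path also works.
trainer.load_weights(
    'https://bj.bcebos.com/v1/paddledet/models/deformable_detr_sqr_r50_12e_coco.pdparams')
trainer.evaluate()   # reports COCO mAP on the validation set
```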

## Citations
```
@InProceedings{Chen_2023_CVPR,
author = {Chen, Fangyi and Zhang, Han and Hu, Kai and Huang, Yu-Kai and Zhu, Chenchen and Savvides, Marios},
title = {Enhanced Training of Query-Based Object Detection via Selective Query Recollection},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {23756-23765}
}
```
50 changes: 50 additions & 0 deletions configs/sqr/_base_/deformable_detr_sqr_r50.yml
@@ -0,0 +1,50 @@
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vb_normal_pretrained.pdparams
hidden_dim: 256
use_focal_loss: True


DETR:
  backbone: ResNet
  transformer: QRDeformableTransformer
  detr_head: DeformableDETRHead
  post_process: DETRPostProcess


ResNet:
  # index 0 stands for res2
  depth: 50
  norm_type: bn
  freeze_at: 0
  return_idx: [1, 2, 3]
  lr_mult_list: [0.0, 0.1, 0.1, 0.1]
  num_stages: 4


QRDeformableTransformer:
  num_queries: 300
  position_embed_type: sine
  nhead: 8
  num_encoder_layers: 6
  num_decoder_layers: 6
  dim_feedforward: 1024
  dropout: 0.1
  activation: relu
  num_feature_levels: 4
  num_encoder_points: 4
  num_decoder_points: 4
  start_q: [0, 0, 1, 2, 4, 7, 12]
  end_q: [1, 2, 4, 7, 12, 20, 33]


DeformableDETRHead:
  num_mlp_layers: 3


DETRLoss:
  loss_coeff: {class: 2, bbox: 5, giou: 2}
  aux_loss: True


HungarianMatcher:
  matcher_coeff: {class: 2, bbox: 5, giou: 2}
44 changes: 44 additions & 0 deletions configs/sqr/_base_/deformable_detr_sqr_reader.yml
@@ -0,0 +1,44 @@
worker_num: 4
TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomFlip: {prob: 0.5}
  - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
                    transforms2: [
                      RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
                      RandomSizeCrop: { min_size: 384, max_size: 600 },
                      RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
                  }
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - NormalizeBox: {}
  - BboxXYXY2XYWH: {}
  - Permute: {}
  batch_transforms:
  - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
  batch_size: 4
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: false


EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_size: 1
  shuffle: false
  drop_last: false


TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_size: 1
  shuffle: false
  drop_last: false
16 changes: 16 additions & 0 deletions configs/sqr/_base_/deformable_sqr_optimizer_1x.yml
@@ -0,0 +1,16 @@
epoch: 50

LearningRate:
  base_lr: 0.0002
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [40]
    use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
27 changes: 27 additions & 0 deletions configs/sqr/deformable_detr_sqr_r50_12e_coco.yml
@@ -0,0 +1,27 @@
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/deformable_detr_sqr_r50.yml',
  '_base_/deformable_detr_sqr_reader.yml',
]
weights: output/deformable_detr_sqr_r50_12e_coco/model_final
find_unused_parameters: True


# a standard 1x schedule
epoch: 12

LearningRate:
  base_lr: 0.0002
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [8, 11]
    use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
8 changes: 3 additions & 5 deletions ppdet/data/transform/operators.py
@@ -956,13 +956,11 @@ def apply(self, sample, context=None):

resize_h = int(im_scale * float(im_shape[0]) + 0.5)
resize_w = int(im_scale * float(im_shape[1]) + 0.5)

im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / im_shape[0]
im_scale_x = resize_w / im_shape[1]

im_scale_y = resize_h / im_shape[0]
im_scale_x = resize_w / im_shape[1]

if len(im.shape) == 3:
im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
109 changes: 109 additions & 0 deletions ppdet/modeling/transformers/deformable_transformer.py
@@ -535,3 +535,112 @@ def forward(self, src_feats, src_mask=None, *args, **kwargs):
level_start_index, mask_flatten, query_embed)

return (hs, memory, reference_points)


class QRDeformableTransformerDecoder(DeformableTransformerDecoder):
    def __init__(self, decoder_layer, num_layers,
                 start_q=None, end_q=None, return_intermediate=False):
        super(QRDeformableTransformerDecoder, self).__init__(
            decoder_layer, num_layers, return_intermediate=return_intermediate)
        self.start_q = start_q
        self.end_q = end_q

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                memory_level_start_index,
                memory_mask=None,
                query_pos_embed=None):

        if not self.training:
            # SQR only changes training; inference uses the plain sequential decoder.
            return super(QRDeformableTransformerDecoder, self).forward(
                tgt, reference_points,
                memory, memory_spatial_shapes,
                memory_level_start_index,
                memory_mask=memory_mask,
                query_pos_embed=query_pos_embed)

        batchsize = tgt.shape[0]
        query_list_reserve = [tgt]
        intermediate = []
        for lid, layer in enumerate(self.layers):

            # select the window of recollected query sets for this stage
            start_q = self.start_q[lid]
            end_q = self.end_q[lid]
            query_list = query_list_reserve.copy()[start_q:end_q]

            # prepare for parallel process
            output = paddle.concat(query_list, axis=0)
            fakesetsize = int(output.shape[0] / batchsize)
            reference_points_tiled = reference_points.tile([fakesetsize, 1, 1, 1])

            memory_tiled = memory.tile([fakesetsize, 1, 1])
            query_pos_embed_tiled = query_pos_embed.tile([fakesetsize, 1, 1])
            memory_mask_tiled = memory_mask.tile([fakesetsize, 1])

            output = layer(output, reference_points_tiled, memory_tiled,
                           memory_spatial_shapes, memory_level_start_index,
                           memory_mask_tiled, query_pos_embed_tiled)

            # recollect the refined queries so later stages can revisit them
            for i in range(fakesetsize):
                query_list_reserve.append(output[batchsize*i:batchsize*(i+1)])

            if self.return_intermediate:
                for i in range(fakesetsize):
                    intermediate.append(output[batchsize*i:batchsize*(i+1)])

        if self.return_intermediate:
            return paddle.stack(intermediate)

        return output.unsqueeze(0)


@register
class QRDeformableTransformer(DeformableTransformer):

    def __init__(self,
                 num_queries=300,
                 position_embed_type='sine',
                 return_intermediate_dec=True,
                 in_feats_channel=[512, 1024, 2048],
                 num_feature_levels=4,
                 num_encoder_points=4,
                 num_decoder_points=4,
                 hidden_dim=256,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 dropout=0.1,
                 activation="relu",
                 lr_mult=0.1,
                 pe_temperature=10000,
                 pe_offset=-0.5,
                 start_q=None,
                 end_q=None):
        super(QRDeformableTransformer, self).__init__(
            num_queries=num_queries,
            position_embed_type=position_embed_type,
            return_intermediate_dec=return_intermediate_dec,
            in_feats_channel=in_feats_channel,
            num_feature_levels=num_feature_levels,
            num_encoder_points=num_encoder_points,
            num_decoder_points=num_decoder_points,
            hidden_dim=hidden_dim,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            lr_mult=lr_mult,
            pe_temperature=pe_temperature,
            pe_offset=pe_offset)

        decoder_layer = DeformableTransformerDecoderLayer(
            hidden_dim, nhead, dim_feedforward, dropout, activation,
            num_feature_levels, num_decoder_points)
        self.decoder = QRDeformableTransformerDecoder(
            decoder_layer, num_decoder_layers, start_q, end_q,
            return_intermediate_dec)
3 changes: 2 additions & 1 deletion ppdet/modeling/transformers/matchers.py
@@ -122,9 +122,10 @@ def forward(self,
out_bbox.unsqueeze(1) - tgt_bbox.unsqueeze(0)).abs().sum(-1)

# Compute the giou cost between boxes
cost_giou = self.giou_loss(
giou_loss = self.giou_loss(
bbox_cxcywh_to_xyxy(out_bbox.unsqueeze(1)),
bbox_cxcywh_to_xyxy(tgt_bbox.unsqueeze(0))).squeeze(-1)
cost_giou = giou_loss - 1

# Final cost matrix
C = self.matcher_coeff['class'] * cost_class + \
