From 76e23d23d393791b0452d2cec3610be725c5baaf Mon Sep 17 00:00:00 2001 From: Jun0922 <89365607+Jun0922@users.noreply.github.com> Date: Thu, 1 Feb 2024 11:49:22 +0900 Subject: [PATCH 1/4] [Bug Fix] video_demo bug fixed for issue#11353 (#11451) --- demo/video_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/video_demo.py b/demo/video_demo.py index 6fc36319ae8..f72d617695c 100644 --- a/demo/video_demo.py +++ b/demo/video_demo.py @@ -58,7 +58,7 @@ def main(): args.out, fourcc, video_reader.fps, (video_reader.width, video_reader.height)) - for frame in track_iter_progress(video_reader): + for frame in track_iter_progress((video_reader, len(video_reader))): result = inference_detector(model, frame, test_pipeline=test_pipeline) visualizer.add_datasample( name='video', From 498295aad4e7081a1b8540c299c3ff9ff9a01956 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Thu, 1 Feb 2024 10:51:51 +0800 Subject: [PATCH 2/4] use `lerp` in ema (#11442) --- mmdet/models/layers/ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet/models/layers/ema.py b/mmdet/models/layers/ema.py index bce503c4641..73a0ca67c28 100644 --- a/mmdet/models/layers/ema.py +++ b/mmdet/models/layers/ema.py @@ -63,4 +63,4 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor, """ momentum = (1 - self.momentum) * math.exp( -float(1 + steps) / self.gamma) + self.momentum - averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) + averaged_param.lerp_(source_param, momentum) From 892e8ecbbe5f9b771a09832dab980b4727e2b31d Mon Sep 17 00:00:00 2001 From: Chen Qibo <69334887+Baboom-l@users.noreply.github.com> Date: Thu, 1 Feb 2024 10:52:22 +0800 Subject: [PATCH 3/4] fix:grounding_pretrain_text_trans_negative_label (#11404) --- mmdet/datasets/transforms/text_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet/datasets/transforms/text_transformers.py b/mmdet/datasets/transforms/text_transformers.py index 25304d5fe45..12a0e57db3d 100644 --- a/mmdet/datasets/transforms/text_transformers.py +++ b/mmdet/datasets/transforms/text_transformers.py @@ -199,7 +199,7 @@ def od_aug(self, results): for i in np.random.choice( valid_negative_indexes, size=num_negatives, replace=False): - if i not in positive_label_list: + if int(i) not in positive_label_list: negative_label_list.add(i) random.shuffle(positive_label_list) From 2390ebc32384512477c6c1dd51a452a71f45e908 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Mon, 5 Feb 2024 21:11:41 +0800 Subject: [PATCH 4/4] Release MM-GroundingDINO SwinB and SwinL Weights (#11458) --- README.md | 2 + README_zh-CN.md | 2 + configs/mm_grounding_dino/README.md | 44 ++- .../grounding_dino_swin-b_pretrain_all.py | 335 ++++++++++++++++++ ...dino_swin-b_pretrain_obj365_goldg_v3det.py | 143 ++++++++ .../grounding_dino_swin-l_pretrain_all.py | 18 +- ...nding_dino_swin-l_pretrain_obj365_goldg.py | 227 ++++++++++++ configs/mm_grounding_dino/metafile.yml | 36 ++ mmdet/models/detectors/grounding_dino.py | 15 +- 9 files changed, 799 insertions(+), 23 deletions(-) create mode 100644 configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_all.py create mode 100644 configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py create mode 100644 configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py diff --git a/README.md b/README.md index 15f71dad5fb..34f7f0b8f90 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,8 @@ Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmla ## What's New +💎 **We have released the pre-trained weights for MM-Grounding-DINO Swin-B and Swin-L, welcome to try and give feedback.** + ### Highlight **v3.3.0** was released in 5/1/2024: diff --git a/README_zh-CN.md b/README_zh-CN.md index 885d1f22617..8d7c060f6f9 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -100,6 +100,8 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope ## 最新进展 +💎 **我们已经发布了 MM-Grounding-DINO Swin-B 和 Swin-L 预训练权重,欢迎试用和反馈.** + ### 亮点 **v3.3.0** 版本已经在 2024.1.5 发布: diff --git a/configs/mm_grounding_dino/README.md b/configs/mm_grounding_dino/README.md index bcc913446dc..c88cb1c9026 100644 --- a/configs/mm_grounding_dino/README.md +++ b/configs/mm_grounding_dino/README.md @@ -20,22 +20,33 @@ Grounding-DINO is a state-of-the-art open-set detection model that tackles multi Please refer to [dataset_prepare.md](dataset_prepare.md) or [中文版数据准备](dataset_prepare_zh-CN.md) +## ✨ What's New + +💎 **We have released the pre-trained weights for Swin-B and Swin-L, welcome to try and give feedback.** + ## Usage Please refer to [usage.md](usage.md) or [中文版用法说明](usage_zh-CN.md) ## Zero-Shot COCO Results and Models -| Model | Backbone | Style | COCO mAP | Pre-Train Data | Config | Download | -| :--------: | :------: | :-------: | :--------: | :-------------------: | :------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| GDINO-T | Swin-T | Zero-shot | 46.7 | O365 | | | -| GDINO-T | Swin-T | Zero-shot | 48.1 | O365,GoldG | | | -| GDINO-T | Swin-T | Zero-shot | 48.4 | O365,GoldG,Cap4M | [config](../grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) | -| MM-GDINO-T | Swin-T | Zero-shot | 48.5(+1.8) | O365 | [config](grounding_dino_swin-t_pretrain_obj365.py) | | -| MM-GDINO-T | Swin-T | Zero-shot | 50.4(+2.3) | O365,GoldG | [config](grounding_dino_swin-t_pretrain_obj365_goldg.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg/grounding_dino_swin-t_pretrain_obj365_goldg_20231122_132602-4ea751ce.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg/grounding_dino_swin-t_pretrain_obj365_goldg_20231122_132602.log.json) | -| MM-GDINO-T | Swin-T | Zero-shot | 50.5(+2.1) | O365,GoldG,GRIT | [config](grounding_dino_swin-t_pretrain_obj365_goldg_grit9m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_20231128_200818-169cc352.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_20231128_200818.log.json) | -| MM-GDINO-T | Swin-T | Zero-shot | 50.6(+2.2) | O365,GoldG,V3Det | [config](grounding_dino_swin-t_pretrain_obj365_goldg_v3det.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741-e316e297.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741.log.json) | -| MM-GDINO-T | Swin-T | Zero-shot | 50.4(+2.0) | O365,GoldG,GRIT,V3Det | [config](grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det_20231204_095047-b448804b.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det_20231204_095047.log.json) | +| Model | Backbone | Style | COCO mAP | Pre-Train Data | Config | Download | +| :----------: | :------: | :-------: | :--------: | :----------------------: | :------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| GDINO-T | Swin-T | Zero-shot | 46.7 | O365 | | | +| GDINO-T | Swin-T | Zero-shot | 48.1 | O365,GoldG | | | +| GDINO-T | Swin-T | Zero-shot | 48.4 | O365,GoldG,Cap4M | [config](../grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) | +| MM-GDINO-T | Swin-T | Zero-shot | 48.5(+1.8) | O365 | [config](grounding_dino_swin-t_pretrain_obj365.py) | | +| MM-GDINO-T | Swin-T | Zero-shot | 50.4(+2.3) | O365,GoldG | [config](grounding_dino_swin-t_pretrain_obj365_goldg.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg/grounding_dino_swin-t_pretrain_obj365_goldg_20231122_132602-4ea751ce.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg/grounding_dino_swin-t_pretrain_obj365_goldg_20231122_132602.log.json) | +| MM-GDINO-T | Swin-T | Zero-shot | 50.5(+2.1) | O365,GoldG,GRIT | [config](grounding_dino_swin-t_pretrain_obj365_goldg_grit9m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_20231128_200818-169cc352.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_20231128_200818.log.json) | +| MM-GDINO-T | Swin-T | Zero-shot | 50.6(+2.2) | O365,GoldG,V3Det | [config](grounding_dino_swin-t_pretrain_obj365_goldg_v3det.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741-e316e297.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741.log.json) | +| MM-GDINO-T | Swin-T | Zero-shot | 50.4(+2.0) | O365,GoldG,GRIT,V3Det | [config](grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det_20231204_095047-b448804b.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det_20231204_095047.log.json) | +| MM-GDINO-B | Swin-B | Zero-shot | 52.5 | O365,GoldG,V3Det | [config](grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det/grounding_dino_swin-b_pretrain_obj365_goldg_v3de-f83eef00.pth) \| [log](<>) | +| MM-GDINO-B\* | Swin-B | - | 59.5 | O365,ALL | [config](grounding_dino_swin-b_pretrain_all.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_all/grounding_dino_swin-b_pretrain_all-f9818a7c.pth) \| [log](<>) | +| MM-GDINO-L | Swin-L | Zero-shot | 53.0 | O365V2,OpenImageV6,GoldG | [config](grounding_dino_swin-l_pretrain_obj365_goldg.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg/grounding_dino_swin-l_pretrain_obj365_goldg-34dcdc53.pth) \| [log](<>) | +| MM-GDINO-L\* | Swin-L | - | 60.3 | O365V2,OpenImageV6,ALL | [config](grounding_dino_swin-l_pretrain_all.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_all/grounding_dino_swin-l_pretrain_all-56d69e78.pth) \| [log](<>) | + +- This * indicates that the model has not been fully trained yet. We will release the final weights in the future. +- ALL: GoldG,V3det,COCO2017,LVISV1,COCO2014,GRIT,RefCOCO,RefCOCO+,RefCOCOg,gRefCOCO. ## Zero-Shot LVIS Results @@ -361,3 +372,16 @@ Note: | MM-GDINO | Swin-T | 5e | 45.1 | 64.7 | 42.5 | 65.5 | 40.3 | 63.2 | - The MM-GDINO-T config file is [here](refcoco/grounding_dino_swin-t_finetune_8xb4_5e_grefcoco.py) + +## Citation + +If you find this project useful in your research, please consider citing: + +```latex +@article{zhao2024open, + title={An Open and Comprehensive Pipeline for Unified Object Grounding and Detection}, + author={Zhao, Xiangyu and Chen, Yicheng and Xu, Shilin and Li, Xiangtai and Wang, Xinjiang and Li, Yining and Huang, Haian}, + journal={arXiv preprint arXiv:2401.02361}, + year={2024} +} +``` diff --git a/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_all.py b/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_all.py new file mode 100644 index 00000000000..eff58bba6b1 --- /dev/null +++ b/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_all.py @@ -0,0 +1,335 @@ +_base_ = 'grounding_dino_swin-t_pretrain_obj365.py' + +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det/grounding_dino_swin-b_pretrain_obj365_goldg_v3de-f83eef00.pth' # noqa + +model = dict( + use_autocast=True, + backbone=dict( + _delete_=True, + type='SwinTransformer', + pretrain_img_size=384, + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=True, + frozen_stages=-1, + init_cfg=None), + neck=dict(in_channels=[256, 512, 1024]), +) + +o365v1_od_dataset = dict( + type='ODVGDataset', + data_root='data/objects365v1/', + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None, +) + +flickr30k_dataset = dict( + type='ODVGDataset', + data_root='data/flickr30k_entities/', + ann_file='final_flickr_separateGT_train_vg.json', + label_map_file=None, + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +gqa_dataset = dict( + type='ODVGDataset', + data_root='data/gqa/', + ann_file='final_mixed_train_no_coco_vg.json', + label_map_file=None, + data_prefix=dict(img='images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +v3d_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPos', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/V3Det/annotations/v3det_2023_v1_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] +v3det_dataset = dict( + type='ODVGDataset', + data_root='data/V3Det/', + ann_file='annotations/v3det_2023_v1_train_od.json', + label_map_file='annotations/v3det_2023_v1_label_map.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, # change this + pipeline=v3d_train_pipeline, + return_classes=True, + backend_args=None) + +grit_dataset = dict( + type='ODVGDataset', + data_root='grit_processed/', + ann_file='grit20m_vg.json', + label_map_file=None, + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +# --------------------------- lvis od dataset--------------------------- +lvis_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPos', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/coco/annotations/lvis_v1_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] +lvis_dataset = dict( + type='ClassBalancedDataset', + oversample_thr=1e-3, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='annotations/lvis_v1_train_od.json', + label_map_file='annotations/lvis_v1_label_map.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, # change this + pipeline=lvis_train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- coco2017 od dataset--------------------------- +coco2017_train_dataset = dict( + type='RepeatDataset', + times=2, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='annotations/instance_train2017_norefval_od.json', + label_map_file='annotations/coco2017_label_map.json', + data_prefix=dict(img='train2017'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- coco2014 vg dataset--------------------------- +coco2014_vg_dataset = dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='mdetr_annotations/final_mixed_train_only_coco_vg.json', + label_map_file=None, + data_prefix=dict(img='train2014/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +# --------------------------- refcoco vg dataset--------------------------- +refcoco_dataset = dict( + type='RepeatDataset', + times=2, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='mdetr_annotations/finetune_refcoco_train_vg.json', + label_map_file=None, + data_prefix=dict(img='train2014'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- refcoco+ vg dataset--------------------------- +refcoco_plus_dataset = dict( + type='RepeatDataset', + times=2, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='mdetr_annotations/finetune_refcoco+_train_vg.json', + label_map_file=None, + data_prefix=dict(img='train2014'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- refcocog vg dataset--------------------------- +refcocog_dataset = dict( + type='RepeatDataset', + times=3, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='mdetr_annotations/finetune_refcocog_train_vg.json', + label_map_file=None, + data_prefix=dict(img='train2014'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- grefcoco vg dataset--------------------------- +grefcoco_dataset = dict( + type='RepeatDataset', + times=2, + dataset=dict( + type='ODVGDataset', + data_root='data/coco/', + ann_file='mdetr_annotations/finetune_grefcoco_train_vg.json', + label_map_file=None, + data_prefix=dict(img='train2014'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None)) + +# --------------------------- dataloader--------------------------- +train_dataloader = dict( + batch_size=4, + num_workers=4, + sampler=dict( + _delete_=True, + type='CustomSampleSizeSampler', + ratio_mode=True, + dataset_size=[-1, -1, 0.07, -1, -1, -1, -1, -1, -1, -1, -1, -1]), + dataset=dict(datasets=[ + o365v1_od_dataset, # 1.74M + v3det_dataset, # + grit_dataset, + lvis_dataset, + coco2017_train_dataset, # 0.12M + flickr30k_dataset, # 0.15M + gqa_dataset, # 0.62M + coco2014_vg_dataset, # 0.49M + refcoco_dataset, # 0.12M + refcoco_plus_dataset, # 0.12M + refcocog_dataset, # 0.08M + grefcoco_dataset, # 0.19M + ])) + +optim_wrapper = dict(optimizer=dict(lr=0.0001)) + +# learning policy +max_iter = 304680 +train_cfg = dict( + _delete_=True, + type='IterBasedTrainLoop', + max_iters=max_iter, + val_interval=10000) + +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_iter, + by_epoch=False, + milestones=[228510], + gamma=0.1) +] + +default_hooks = dict( + checkpoint=dict(by_epoch=False, interval=10000, max_keep_ckpts=20)) +log_processor = dict(by_epoch=False) diff --git a/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py b/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py new file mode 100644 index 00000000000..743d02cffbe --- /dev/null +++ b/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py @@ -0,0 +1,143 @@ +_base_ = 'grounding_dino_swin-t_pretrain_obj365.py' + +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth' # noqa +model = dict( + use_autocast=True, + backbone=dict( + _delete_=True, + type='SwinTransformer', + pretrain_img_size=384, + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=True, + frozen_stages=-1, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + neck=dict(in_channels=[256, 512, 1024]), +) + +o365v1_od_dataset = dict( + type='ODVGDataset', + data_root='data/objects365v1/', + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None, +) + +flickr30k_dataset = dict( + type='ODVGDataset', + data_root='data/flickr30k_entities/', + ann_file='final_flickr_separateGT_train_vg.json', + label_map_file=None, + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +gqa_dataset = dict( + type='ODVGDataset', + data_root='data/gqa/', + ann_file='final_mixed_train_no_coco_vg.json', + label_map_file=None, + data_prefix=dict(img='images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +v3d_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPos', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/V3Det/annotations/v3det_2023_v1_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] +v3det_dataset = dict( + type='ODVGDataset', + data_root='data/V3Det/', + ann_file='annotations/v3det_2023_v1_train_od.json', + label_map_file='annotations/v3det_2023_v1_label_map.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, # change this + pipeline=v3d_train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + dataset=dict(datasets=[ + o365v1_od_dataset, flickr30k_dataset, gqa_dataset, v3det_dataset + ])) + +# learning policy +max_epochs = 18 +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[13, 16], + gamma=0.1) +] + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) diff --git a/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py b/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py index 46241e2e03b..a17f2344e14 100644 --- a/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py +++ b/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py @@ -1,8 +1,10 @@ _base_ = 'grounding_dino_swin-t_pretrain_obj365.py' -pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg/grounding_dino_swin-l_pretrain_obj365_goldg-34dcdc53.pth' # noqa + num_levels = 5 model = dict( + use_autocast=True, num_feature_levels=num_levels, backbone=dict( _delete_=True, @@ -25,7 +27,7 @@ with_cp=True, convert_weights=True, frozen_stages=-1, - init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + init_cfg=None), neck=dict(in_channels=[192, 384, 768, 1536], num_outs=num_levels), encoder=dict(layer_cfg=dict(self_attn_cfg=dict(num_levels=num_levels))), decoder=dict(layer_cfg=dict(cross_attn_cfg=dict(num_levels=num_levels)))) @@ -512,14 +514,10 @@ grit_dataset # 9M ])) -# bs=256 -optim_wrapper = dict(optimizer=dict(lr=0.0008)) +# 4NODES * 8GPU +optim_wrapper = dict(optimizer=dict(lr=0.0001)) -# one epoch = (3.2+3.3)M/256 = 25390 iter -# 24e=609360 iter -# 16e=406240 iter -# 20e=507800 iter -max_iter = 609360 +max_iter = 250000 train_cfg = dict( _delete_=True, type='IterBasedTrainLoop', @@ -533,7 +531,7 @@ begin=0, end=max_iter, by_epoch=False, - milestones=[406240, 507800], + milestones=[210000], gamma=0.1) ] diff --git a/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py b/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py new file mode 100644 index 00000000000..85d43f96b3b --- /dev/null +++ b/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py @@ -0,0 +1,227 @@ +_base_ = 'grounding_dino_swin-t_pretrain_obj365.py' + +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa +num_levels = 5 +model = dict( + use_autocast=True, + num_feature_levels=num_levels, + backbone=dict( + _delete_=True, + type='SwinTransformer', + pretrain_img_size=384, + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(0, 1, 2, 3), + # Please only add indices that would be used + # in FPN, otherwise some parameter will not be used + with_cp=True, + convert_weights=True, + frozen_stages=-1, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + neck=dict(in_channels=[192, 384, 768, 1536], num_outs=num_levels), + encoder=dict(layer_cfg=dict(self_attn_cfg=dict(num_levels=num_levels))), + decoder=dict(layer_cfg=dict(cross_attn_cfg=dict(num_levels=num_levels)))) + +# --------------------------- object365v2 od dataset--------------------------- +# objv2_backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/objects365v2/': 'yudong:s3://wangyudong/obj365_v2/', +# 'data/objects365v2/': 'yudong:s3://wangyudong/obj365_v2/' +# })) +objv2_backend_args = None + +objv2_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=objv2_backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPos', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/objects365v2/annotations/o365v2_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] + +o365v2_dataset = dict( + type='ODVGDataset', + data_root='data/objects365v2/', + ann_file='annotations/zhiyuan_objv2_train_od.json', + label_map_file='annotations/o365v2_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=objv2_train_pipeline, + return_classes=True, + need_text=False, + backend_args=None, +) + +# --------------------------- openimagev6 od dataset--------------------------- +# oi_backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +oi_backend_args = None + +oi_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=oi_backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPos', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/OpenImages/annotations/openimages_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] + +oiv6_dataset = dict( + type='ODVGDataset', + data_root='data/OpenImages/', + ann_file='annotations/oidv6-train-annotations_od.json', + label_map_file='annotations/openimages_label_map.json', + data_prefix=dict(img='OpenImages/train/'), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, + pipeline=oi_train_pipeline, + return_classes=True, + backend_args=None) + +flickr30k_dataset = dict( + type='ODVGDataset', + data_root='data/flickr30k_entities/', + ann_file='final_flickr_separateGT_train_vg.json', + label_map_file=None, + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +gqa_dataset = dict( + type='ODVGDataset', + data_root='data/gqa/', + ann_file='final_mixed_train_no_coco_vg.json', + label_map_file=None, + data_prefix=dict(img='images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + dataset=dict(datasets=[ + o365v2_dataset, oiv6_dataset, flickr30k_dataset, gqa_dataset + ])) + +# 4Nodex8GPU +optim_wrapper = dict(optimizer=dict(lr=0.0002)) + +max_iter = 200000 +train_cfg = dict( + _delete_=True, + type='IterBasedTrainLoop', + max_iters=max_iter, + val_interval=13000) + +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_iter, + by_epoch=False, + milestones=[156100], + gamma=0.5) +] + +default_hooks = dict( + checkpoint=dict(by_epoch=False, interval=13000, max_keep_ckpts=30)) +log_processor = dict(by_epoch=False) diff --git a/configs/mm_grounding_dino/metafile.yml b/configs/mm_grounding_dino/metafile.yml index 3071686e7ac..c104ac05136 100644 --- a/configs/mm_grounding_dino/metafile.yml +++ b/configs/mm_grounding_dino/metafile.yml @@ -52,3 +52,39 @@ Models: Metrics: box AP: 50.4 Weights: https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det_20231204_095047-b448804b.pth + - Name: grounding_dino_swin-b_pretrain_obj365_goldg_v3det + In Collection: MM Grounding DINO + Config: configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 52.5 + Weights: https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det/grounding_dino_swin-b_pretrain_obj365_goldg_v3de-f83eef00.pth + - Name: grounding_dino_swin-b_pretrain_all + In Collection: MM Grounding DINO + Config: configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_all.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 59.5 + Weights: https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_all/grounding_dino_swin-b_pretrain_all-f9818a7c.pth + - Name: grounding_dino_swin-l_pretrain_obj365_goldg + In Collection: MM Grounding DINO + Config: configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 53.0 + Weights: https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg/grounding_dino_swin-l_pretrain_obj365_goldg-34dcdc53.pth + - Name: grounding_dino_swin-l_pretrain_all + In Collection: MM Grounding DINO + Config: configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 60.3 + Weights: https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_all/grounding_dino_swin-l_pretrain_all-56d69e78.pth diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py index 4ec9d14e634..b1ab7c2da16 100644 --- a/mmdet/models/detectors/grounding_dino.py +++ b/mmdet/models/detectors/grounding_dino.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +from mmengine.runner.amp import autocast from torch import Tensor from mmdet.registry import MODELS @@ -51,10 +52,15 @@ class GroundingDINO(DINO): `_. """ - def __init__(self, language_model, *args, **kwargs) -> None: + def __init__(self, + language_model, + *args, + use_autocast=False, + **kwargs) -> None: self.language_model_cfg = language_model self._special_tokens = '. ' + self.use_autocast = use_autocast super().__init__(*args, **kwargs) def _init_layers(self) -> None: @@ -483,8 +489,11 @@ def loss(self, batch_inputs: Tensor, data_samples.gt_instances.text_token_mask = \ text_token_mask.unsqueeze(0).repeat( len(positive_map), 1) - - visual_features = self.extract_feat(batch_inputs) + if self.use_autocast: + with autocast(enabled=True): + visual_features = self.extract_feat(batch_inputs) + else: + visual_features = self.extract_feat(batch_inputs) head_inputs_dict = self.forward_transformer(visual_features, text_dict, batch_data_samples)