From 36d865e7021ad6289b01760e79b74bedf7787393 Mon Sep 17 00:00:00 2001 From: Xuan Ju <89566272+juxuan27@users.noreply.github.com> Date: Tue, 26 Sep 2023 13:58:17 +0800 Subject: [PATCH] [Feature] Add detectors trained on humanart (#2724) --- projects/rtmpose/README.md | 15 ++ projects/rtmpose/README_CN.md | 15 ++ .../rtmdet/person/humanart_detection.py | 95 +++++++ .../person/rtmdet_l_8xb32-300e_humanart.py | 180 +++++++++++++ .../person/rtmdet_m_8xb32-300e_humanart.py | 6 + .../person/rtmdet_s_8xb32-300e_humanart.py | 62 +++++ .../person/rtmdet_tiny_8xb32-300e_humanart.py | 43 +++ .../person/rtmdet_x_8xb32-300e_humanart.py | 7 + .../humanart/yolox_l_8xb8-300e_humanart.py | 8 + .../humanart/yolox_m_8xb8-300e_humanart.py | 8 + .../humanart/yolox_nano_8xb8-300e_humanart.py | 11 + .../humanart/yolox_s_8xb8-300e_humanart.py | 250 ++++++++++++++++++ .../humanart/yolox_tiny_8xb8-300e_humanart.py | 54 ++++ .../humanart/yolox_x_8xb8-300e_humanart.py | 8 + 14 files changed, 762 insertions(+) create mode 100644 projects/rtmpose/rtmdet/person/humanart_detection.py create mode 100644 projects/rtmpose/rtmdet/person/rtmdet_l_8xb32-300e_humanart.py create mode 100644 projects/rtmpose/rtmdet/person/rtmdet_m_8xb32-300e_humanart.py create mode 100644 projects/rtmpose/rtmdet/person/rtmdet_s_8xb32-300e_humanart.py create mode 100644 projects/rtmpose/rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py create mode 100644 projects/rtmpose/rtmdet/person/rtmdet_x_8xb32-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_l_8xb8-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_m_8xb8-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_nano_8xb8-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_s_8xb8-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_tiny_8xb8-300e_humanart.py create mode 100644 projects/rtmpose/yolox/humanart/yolox_x_8xb8-300e_humanart.py diff --git a/projects/rtmpose/README.md b/projects/rtmpose/README.md index 59a7d2297d..27a8a90144 100644 --- a/projects/rtmpose/README.md +++ b/projects/rtmpose/README.md @@ -219,6 +219,8 @@ Feel free to join our community group for more help: - RTMPose for Human-Centric Artificial Scenes is supported by [Human-Art](https://github.com/IDEA-Research/HumanArt) - +Pose Estimators: + | Config | Input Size | AP
(Human-Art GT) | Params
(M) | FLOPS
(G) | ORT-Latency
(ms)
(i7-11700) | TRT-FP16-Latency
(ms)
(GTX 1660Ti) | ncnn-FP16-Latency
(ms)
(Snapdragon 865) | Download | | :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) | @@ -226,6 +228,19 @@ Feel free to join our community group for more help: | [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) | | [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) |

+Detectors:
+
+| Detection Config | Input Size | Model AP<br>
(OneHand10K) | Flops
(G) | ORT-Latency
(ms)
(i7-11700) | TRT-FP16-Latency
(ms)
(GTX 1660Ti) | Download | +| :---------------------------: | :--------: | :---------------------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :--------------------: | +| [RTMDet-tiny](./rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py) | 640x640 | 46.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_tiny_8xb32-300e_humanart-7da5554e.pth) | +| [RTMDet-s](./rtmdet/person/rtmdet_s_8xb32-300e_humanart.py) | 640x640 | 50.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_s_8xb32-300e_humanart-af5bd52d.pth) | +| [YOLOX-nano](./yolox/humanart/yolox_nano_8xb8-300e_humanart.py) | 640x640 | 38.9 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_nano_8xb8-300e_humanart-40f6f0d0.pth) | +| [YOLOX-tiny](./yolox/humanart/yolox_tiny_8xb8-300e_humanart.py) | 640x640 | 47.7 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_tiny_8xb8-300e_humanart-6f3252f9.pth) | +| [YOLOX-s](./yolox/humanart/yolox_s_8xb8-300e_humanart.py) | 640x640 | 54.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_s_8xb8-300e_humanart-3ef259a7.pth) | +| [YOLOX-m](./yolox/humanart/yolox_m_8xb8-300e_humanart.py) | 640x640 | 59.1 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_m_8xb8-300e_humanart-c2c7a14a.pth) | +| [YOLOX-l](./yolox/humanart/yolox_l_8xb8-300e_humanart.py) | 640x640 | 60.2 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_l_8xb8-300e_humanart-ce1d7a62.pth) | +| [YOLOX-x](./yolox/humanart/yolox_x_8xb8-300e_humanart.py) | 640x640 | 61.3 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_x_8xb8-300e_humanart-a39d44ed.pth) | + #### 26 Keypoints diff --git a/projects/rtmpose/README_CN.md b/projects/rtmpose/README_CN.md index 05092f2e46..a08c74283f 100644 --- a/projects/rtmpose/README_CN.md +++ b/projects/rtmpose/README_CN.md @@ -210,6 +210,8 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性 - 面向艺术图片的人体姿态估计 RTMPose 模型由 [Human-Art](https://github.com/IDEA-Research/HumanArt) 提供。 - +人体姿态估计模型: + | Config | Input Size | AP
(Human-Art GT) | Params
(M) | FLOPS
(G) | ORT-Latency
(ms)
(i7-11700) | TRT-FP16-Latency
(ms)
(GTX 1660Ti) | ncnn-FP16-Latency
(ms)
(Snapdragon 865) | Download | | :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.zip) | @@ -217,6 +219,19 @@ RTMPose 是一个长期优化迭代的项目,致力于业务场景下的高性 | [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.zip) | | [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [pth](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth)
[onnx](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.zip) | +人体检测模型: + +| Detection Config | Input Size | Model AP
(OneHand10K) | Flops
(G) | ORT-Latency
(ms)
(i7-11700) | TRT-FP16-Latency
(ms)
(GTX 1660Ti) | Download |
+| :---------------------------: | :--------: | :---------------------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :--------------------: |
+| [RTMDet-tiny](./rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py) | 640x640 | 46.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_tiny_8xb32-300e_humanart-7da5554e.pth) |
+| [RTMDet-s](./rtmdet/person/rtmdet_s_8xb32-300e_humanart.py) | 640x640 | 50.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_s_8xb32-300e_humanart-af5bd52d.pth) |
+| [YOLOX-nano](./yolox/humanart/yolox_nano_8xb8-300e_humanart.py) | 640x640 | 38.9 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_nano_8xb8-300e_humanart-40f6f0d0.pth) |
+| [YOLOX-tiny](./yolox/humanart/yolox_tiny_8xb8-300e_humanart.py) | 640x640 | 47.7 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_tiny_8xb8-300e_humanart-6f3252f9.pth) |
+| [YOLOX-s](./yolox/humanart/yolox_s_8xb8-300e_humanart.py) | 640x640 | 54.6 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_s_8xb8-300e_humanart-3ef259a7.pth) |
+| [YOLOX-m](./yolox/humanart/yolox_m_8xb8-300e_humanart.py) | 640x640 | 59.1 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_m_8xb8-300e_humanart-c2c7a14a.pth) |
+| [YOLOX-l](./yolox/humanart/yolox_l_8xb8-300e_humanart.py) | 640x640 | 60.2 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_l_8xb8-300e_humanart-ce1d7a62.pth) |
+| [YOLOX-x](./yolox/humanart/yolox_x_8xb8-300e_humanart.py) | 640x640 | 61.3 | - | - | - | [Det Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/yolox_x_8xb8-300e_humanart-a39d44ed.pth) |
+

#### 26 Keypoints

diff --git a/projects/rtmpose/rtmdet/person/humanart_detection.py b/projects/rtmpose/rtmdet/person/humanart_detection.py
new file mode 100644
index 0000000000..a07a2499ce
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/humanart_detection.py
@@ -0,0 +1,95 @@
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/'
+
+# Example of using a different file client
+# Method 1: simply set the data root and let the file I/O module
+# automatically infer from the prefix (LMDB and Memcache are not supported yet)
+
+# data_root = 's3://openmmlab/datasets/detection/coco/'
+
+# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
+# backend_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         './data/': 's3://openmmlab/datasets/detection/',
+#         'data/': 's3://openmmlab/datasets/detection/'
+#     }))
+backend_args = None
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=backend_args),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PackDetInputs')
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=backend_args),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    # If you don't have ground-truth annotations, delete this pipeline step
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    batch_sampler=dict(type='AspectRatioBatchSampler'),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='HumanArt/annotations/training_humanart_coco.json',
+        data_prefix=dict(img=''),
+        filter_cfg=dict(filter_empty_gt=True, min_size=32),
+        pipeline=train_pipeline,
+        backend_args=backend_args))
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='HumanArt/annotations/validation_humanart_coco.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=test_pipeline,
+        backend_args=backend_args))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    type='CocoMetric',
+    ann_file=data_root + 'HumanArt/annotations/validation_humanart_coco.json',
+    metric='bbox',
+    format_only=False,
+    backend_args=backend_args)
+test_evaluator = val_evaluator
+
+# Inference on the test dataset and
+# format the output results for submission.
+# test_dataloader = dict(
+#     batch_size=1,
+#     num_workers=2,
+#     persistent_workers=True,
+#     drop_last=False,
+#     sampler=dict(type='DefaultSampler', shuffle=False),
+#     dataset=dict(
+#         type=dataset_type,
+#         data_root=data_root,
+#         ann_file=data_root + 'annotations/image_info_test-dev2017.json',
+#         data_prefix=dict(img='test2017/'),
+#         test_mode=True,
+#         pipeline=test_pipeline))
+# test_evaluator = dict(
+#     type='CocoMetric',
+#     metric='bbox',
+#     format_only=True,
+#     ann_file=data_root + 'annotations/image_info_test-dev2017.json',
+#     outfile_prefix='./work_dirs/coco_detection/test')
diff --git a/projects/rtmpose/rtmdet/person/rtmdet_l_8xb32-300e_humanart.py b/projects/rtmpose/rtmdet/person/rtmdet_l_8xb32-300e_humanart.py
new file mode 100644
index 0000000000..7b009072c6
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/rtmdet_l_8xb32-300e_humanart.py
@@ -0,0 +1,180 @@
+_base_ = [
+    'mmdet::_base_/default_runtime.py',
+    'mmdet::_base_/schedules/schedule_1x.py', './humanart_detection.py',
+    'mmdet::rtmdet_tta.py'
+]
+model = dict(
+    type='RTMDet',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        mean=[103.53, 116.28, 123.675],
+        std=[57.375, 57.12, 58.395],
+        bgr_to_rgb=False,
+        batch_augments=None),
+    backbone=dict(
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=1,
+        widen_factor=1,
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU', inplace=True)),
+    neck=dict(
+        type='CSPNeXtPAFPN',
+        in_channels=[256, 512, 1024],
+        out_channels=256,
+        num_csp_blocks=3,
+        expand_ratio=0.5,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU', inplace=True)),
+    bbox_head=dict(
+        type='RTMDetSepBNHead',
+        num_classes=80,
+        in_channels=256,
+        stacked_convs=2,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
+        bbox_coder=dict(type='DistancePointBBoxCoder'),
+        loss_cls=dict(
+            type='QualityFocalLoss',
+            use_sigmoid=True,
+            beta=2.0,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        with_objectness=False,
+        exp_on_reg=True,
+        share_conv=True,
+        pred_kernel_size=1,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU', inplace=True)),
+    train_cfg=dict(
+        assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    test_cfg=dict(
+        nms_pre=30000,
+        min_bbox_size=0,
+        score_thr=0.001,
+        nms=dict(type='nms', iou_threshold=0.65),
+        max_per_img=300),
+)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
+    dict(
+        type='RandomResize',
+        scale=(1280, 1280),
+        ratio_range=(0.1, 2.0),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(
+        type='CachedMixUp',
+        img_scale=(640, 640),
+        ratio_range=(1.0, 1.0),
+        max_cached_images=20,
+        pad_val=(114, 114, 114)),
+    dict(type='PackDetInputs')
+]
+
+# second-stage pipeline: Mosaic and MixUp are dropped for the final epochs
+train_pipeline_stage2 = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='RandomResize',
+        scale=(640, 640),
+        ratio_range=(0.1, 2.0),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(type='PackDetInputs')
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='Resize', scale=(640, 640), keep_ratio=True),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+train_dataloader = dict(
+    batch_size=32,
+    num_workers=10,
+    batch_sampler=None,
+    pin_memory=True,
+    dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(
+    batch_size=5, num_workers=10, dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+max_epochs = 300
+stage2_num_epochs = 20
+base_lr = 0.0005
+interval = 10
+
+train_cfg = dict(
+    max_epochs=max_epochs,
+    val_interval=interval,
+    dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])

+val_evaluator = dict(proposal_nums=(100, 1, 10))
+test_evaluator = val_evaluator
+
+# optimizer
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+    paramwise_cfg=dict(
+        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0e-5,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        # use cosine lr from epoch 150 to 300
+        type='CosineAnnealingLR',
+        eta_min=base_lr * 0.05,
+        begin=max_epochs // 2,
+        end=max_epochs,
+        T_max=max_epochs // 2,
+        by_epoch=True,
+        convert_to_iter_based=True),
+]
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(
+        interval=interval,
+        max_keep_ckpts=3  # only keep latest 3 checkpoints
+    ))
+custom_hooks = [
+    dict(
+        type='EMAHook',
+        ema_type='ExpMomentumEMA',
+        momentum=0.0002,
+        update_buffers=True,
+        priority=49),
+    dict(
+        type='PipelineSwitchHook',
+        switch_epoch=max_epochs - stage2_num_epochs,
+        switch_pipeline=train_pipeline_stage2)
+]
diff --git a/projects/rtmpose/rtmdet/person/rtmdet_m_8xb32-300e_humanart.py b/projects/rtmpose/rtmdet/person/rtmdet_m_8xb32-300e_humanart.py
new file mode 100644
index 0000000000..263ec89347
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/rtmdet_m_8xb32-300e_humanart.py
@@ -0,0 +1,6 @@
+_base_ = './rtmdet_l_8xb32-300e_humanart.py'
+
+model = dict(
+    backbone=dict(deepen_factor=0.67, widen_factor=0.75),
+    neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
+    bbox_head=dict(in_channels=192, feat_channels=192))
diff --git a/projects/rtmpose/rtmdet/person/rtmdet_s_8xb32-300e_humanart.py b/projects/rtmpose/rtmdet/person/rtmdet_s_8xb32-300e_humanart.py
new file mode 100644
index 0000000000..927cbf7555
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/rtmdet_s_8xb32-300e_humanart.py
@@ -0,0 +1,62 @@
+_base_ = './rtmdet_l_8xb32-300e_humanart.py'
+checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth'  # noqa
+model = dict(
+    backbone=dict(
+        deepen_factor=0.33,
+        widen_factor=0.5,
+        init_cfg=dict(
+            type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
+    neck=dict(in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
+    bbox_head=dict(in_channels=128, feat_channels=128, exp_on_reg=False))
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
+    dict(
+        type='RandomResize',
+        scale=(1280, 1280),
+        ratio_range=(0.5, 2.0),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(
+        type='CachedMixUp',
+        img_scale=(640, 640),
+        ratio_range=(1.0, 1.0),
+        max_cached_images=20,
+        pad_val=(114, 114, 114)),
+    dict(type='PackDetInputs')
+]
+
+train_pipeline_stage2 = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='RandomResize',
+        scale=(640, 640),
+        ratio_range=(0.5, 2.0),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+
+custom_hooks = [
+    dict(
+        type='EMAHook',
+        ema_type='ExpMomentumEMA',
+        momentum=0.0002,
+        update_buffers=True,
+        priority=49),
+    dict(
+        type='PipelineSwitchHook',
+        switch_epoch=280,
+        switch_pipeline=train_pipeline_stage2)
+]
diff --git a/projects/rtmpose/rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py b/projects/rtmpose/rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py
new file mode 100644
index 0000000000..c92442fa8d
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py
@@ -0,0 +1,43 @@
+_base_ = './rtmdet_s_8xb32-300e_humanart.py'
+
+checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'  # noqa
+
+model = dict(
+    backbone=dict(
+        deepen_factor=0.167,
+        widen_factor=0.375,
+        init_cfg=dict(
+            type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
+    neck=dict(in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1),
+    bbox_head=dict(in_channels=96, feat_channels=96, exp_on_reg=False))
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='CachedMosaic',
+        img_scale=(640, 640),
+        pad_val=114.0,
+        max_cached_images=20,
+        random_pop=False),
+    dict(
+        type='RandomResize',
+        scale=(1280, 1280),
+        ratio_range=(0.5, 2.0),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+    dict(
+        type='CachedMixUp',
+        img_scale=(640, 640),
+        ratio_range=(1.0, 1.0),
+        max_cached_images=10,
+        random_pop=False,
+        pad_val=(114, 114, 114),
+        prob=0.5),
+    dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
diff --git a/projects/rtmpose/rtmdet/person/rtmdet_x_8xb32-300e_humanart.py b/projects/rtmpose/rtmdet/person/rtmdet_x_8xb32-300e_humanart.py
new file mode 100644
index 0000000000..60fd09c866
--- /dev/null
+++ b/projects/rtmpose/rtmdet/person/rtmdet_x_8xb32-300e_humanart.py
@@ -0,0 +1,7 @@
+_base_ = './rtmdet_l_8xb32-300e_humanart.py'
+
+model = dict(
+    backbone=dict(deepen_factor=1.33, widen_factor=1.25),
+    neck=dict(
+        in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4),
+    bbox_head=dict(in_channels=320, feat_channels=320))
diff --git a/projects/rtmpose/yolox/humanart/yolox_l_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_l_8xb8-300e_humanart.py
new file mode 100644
index 0000000000..6fd4354cec
--- /dev/null
+++ b/projects/rtmpose/yolox/humanart/yolox_l_8xb8-300e_humanart.py
@@ -0,0 +1,8 @@
+_base_ = './yolox_s_8xb8-300e_humanart.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=1.0, widen_factor=1.0),
+    neck=dict(
+        in_channels=[256, 512, 1024], out_channels=256, num_csp_blocks=3),
+    bbox_head=dict(in_channels=256, feat_channels=256))
diff --git a/projects/rtmpose/yolox/humanart/yolox_m_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_m_8xb8-300e_humanart.py
new file mode 100644
index 0000000000..e74e2bb99c
--- /dev/null
+++ b/projects/rtmpose/yolox/humanart/yolox_m_8xb8-300e_humanart.py
@@ -0,0 +1,8 @@
+_base_ = './yolox_s_8xb8-300e_humanart.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=0.67, widen_factor=0.75),
+    neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
+    bbox_head=dict(in_channels=192, feat_channels=192),
+)
diff --git a/projects/rtmpose/yolox/humanart/yolox_nano_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_nano_8xb8-300e_humanart.py
new file mode 100644
index 0000000000..96a363abec
--- /dev/null
+++ b/projects/rtmpose/yolox/humanart/yolox_nano_8xb8-300e_humanart.py
@@ -0,0 +1,11 @@
+_base_ = './yolox_tiny_8xb8-300e_humanart.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=0.33, widen_factor=0.25, use_depthwise=True),
+    neck=dict(
+        in_channels=[64, 128, 256],
+        out_channels=64,
+        num_csp_blocks=1,
+        use_depthwise=True),
+    bbox_head=dict(in_channels=64, feat_channels=64, use_depthwise=True))
diff --git a/projects/rtmpose/yolox/humanart/yolox_s_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_s_8xb8-300e_humanart.py
new file mode 100644
index 0000000000..a7992b076d
--- /dev/null
+++ b/projects/rtmpose/yolox/humanart/yolox_s_8xb8-300e_humanart.py
@@ -0,0 +1,250 @@
+_base_ = [
+    'mmdet::_base_/schedules/schedule_1x.py',
+    'mmdet::_base_/default_runtime.py', 'mmdet::yolox/yolox_tta.py'
+]
+
+img_scale = (640, 640)  # width, height
+
+# model settings
+model = dict(
+    type='YOLOX',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        pad_size_divisor=32,
+        batch_augments=[
+            dict(
+                type='BatchSyncRandomResize',
+                random_size_range=(480, 800),
+                size_divisor=32,
+                interval=10)
+        ]),
+    backbone=dict(
+        type='CSPDarknet',
+        deepen_factor=0.33,
+        widen_factor=0.5,
+        out_indices=(2, 3, 4),
+        use_depthwise=False,
+        spp_kernal_sizes=(5, 9, 13),
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish'),
+    ),
+    neck=dict(
+        type='YOLOXPAFPN',
+        in_channels=[128, 256, 512],
+        out_channels=128,
+        num_csp_blocks=1,
+        use_depthwise=False,
+        upsample_cfg=dict(scale_factor=2, mode='nearest'),
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish')),
+    bbox_head=dict(
+        type='YOLOXHead',
+        num_classes=80,
+        in_channels=128,
+        feat_channels=128,
+        stacked_convs=2,
+        strides=(8, 16, 32),
+        use_depthwise=False,
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish'),
+        loss_cls=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_bbox=dict(
+            type='IoULoss',
+            mode='square',
+            eps=1e-16,
+            reduction='sum',
+            loss_weight=5.0),
+        loss_obj=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
+    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+    # To align with the official source code, the score threshold is 0.01
+    # in the val phase and 0.001 in the test phase.
+    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+# dataset settings
+data_root = 'data/'
+dataset_type = 'CocoDataset'
+
+# Example of using a different file client
+# Method 1: simply set the data root and let the file I/O module
+# automatically infer from the prefix (LMDB and Memcache are not supported yet)
+
+# data_root = 's3://openmmlab/datasets/detection/coco/'
+
+# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
+# backend_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         './data/': 's3://openmmlab/datasets/detection/',
+#         'data/': 's3://openmmlab/datasets/detection/'
+#     }))
+backend_args = None
+
+train_pipeline = [
+    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
+    dict(
+        type='RandomAffine',
+        scaling_ratio_range=(0.1, 2),
+        # img_scale is (width, height)
+        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
+    dict(
+        type='MixUp',
+        img_scale=img_scale,
+        ratio_range=(0.8, 1.6),
+        pad_val=114.0),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    # According to the official implementation, multi-scale training is
+    # handled not here but in 'mmdet/models/detectors/yolox.py'.
+    # Resize and Pad are for the last 15 epochs when Mosaic,
+    # RandomAffine, and MixUp are disabled by YOLOXModeSwitchHook.
+    dict(type='Resize', scale=img_scale, keep_ratio=True),
+    dict(
+        type='Pad',
+        pad_to_square=True,
+        # If the image is three-channel, the pad value needs
+        # to be set separately for each channel.
+        pad_val=dict(img=(114.0, 114.0, 114.0))),
+    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
+    dict(type='PackDetInputs')
+]
+
+train_dataset = dict(
+    # use MultiImageMixDataset wrapper to support mosaic and mixup
+    type='MultiImageMixDataset',
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='HumanArt/annotations/training_humanart_coco.json',
+        data_prefix=dict(img=''),
+        pipeline=[
+            dict(type='LoadImageFromFile', backend_args=backend_args),
+            dict(type='LoadAnnotations', with_bbox=True)
+        ],
+        filter_cfg=dict(filter_empty_gt=False, min_size=32),
+        backend_args=backend_args),
+    pipeline=train_pipeline)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=backend_args),
+    dict(type='Resize', scale=img_scale, keep_ratio=True),
+    dict(
+        type='Pad',
+        pad_to_square=True,
+        pad_val=dict(img=(114.0, 114.0, 114.0))),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=train_dataset)
+val_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='HumanArt/annotations/validation_humanart_coco.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=test_pipeline,
+        backend_args=backend_args))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    type='CocoMetric',
+    ann_file=data_root + 'HumanArt/annotations/validation_humanart_coco.json',
+    metric='bbox',
+    backend_args=backend_args)
+test_evaluator = val_evaluator
+
+# training settings
+max_epochs = 300
+num_last_epochs = 15
+interval = 10
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
+
+# optimizer
+# default 8 gpu
+base_lr = 0.01
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(
+        type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
+        nesterov=True),
+    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
+
+# learning rate
+param_scheduler = [
+    dict(
+        # use quadratic formula to warm up 5 epochs
+        # and lr is updated by iteration
+        # TODO: fix default scope in get function
+        type='mmdet.QuadraticWarmupLR',
+        by_epoch=True,
+        begin=0,
+        end=5,
+        convert_to_iter_based=True),
+    dict(
+        # use cosine lr from epoch 5 to 285
+        type='CosineAnnealingLR',
+        eta_min=base_lr * 0.05,
+        begin=5,
+        T_max=max_epochs - num_last_epochs,
+        end=max_epochs - num_last_epochs,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    dict(
+        # use fixed lr during last 15 epochs
+        type='ConstantLR',
+        by_epoch=True,
+        factor=1,
+        begin=max_epochs - num_last_epochs,
+        end=max_epochs,
+    )
+]
+
+default_hooks = dict(
+    checkpoint=dict(
+        interval=interval,
+        max_keep_ckpts=3  # only keep latest 3 checkpoints
+    ))
+
+custom_hooks = [
+    dict(
+        type='YOLOXModeSwitchHook',
+        num_last_epochs=num_last_epochs,
+        priority=48),
+    dict(type='SyncNormHook', priority=48),
+    dict(
+        type='EMAHook',
+        ema_type='ExpMomentumEMA',
+        momentum=0.0001,
+        update_buffers=True,
+        priority=49)
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# USER SHOULD NOT CHANGE ITS VALUES.
+# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/projects/rtmpose/yolox/humanart/yolox_tiny_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_tiny_8xb8-300e_humanart.py new file mode 100644 index 0000000000..71971e2ddc --- /dev/null +++ b/projects/rtmpose/yolox/humanart/yolox_tiny_8xb8-300e_humanart.py @@ -0,0 +1,54 @@ +_base_ = './yolox_s_8xb8-300e_humanart.py' + +# model settings +model = dict( + data_preprocessor=dict(batch_augments=[ + dict( + type='BatchSyncRandomResize', + random_size_range=(320, 640), + size_divisor=32, + interval=10) + ]), + backbone=dict(deepen_factor=0.33, widen_factor=0.375), + neck=dict(in_channels=[96, 192, 384], out_channels=96), + bbox_head=dict(in_channels=96, feat_channels=96)) + +img_scale = (640, 640) # width, height + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.5, 1.5), + # img_scale is (width, height) + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + # Resize and Pad are for the last 15 epochs when Mosaic and + # RandomAffine are closed by YOLOXModeSwitchHook. + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='PackDetInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='Resize', scale=(416, 416), keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/rtmpose/yolox/humanart/yolox_x_8xb8-300e_humanart.py b/projects/rtmpose/yolox/humanart/yolox_x_8xb8-300e_humanart.py new file mode 100644 index 0000000000..6e03ffefb6 --- /dev/null +++ b/projects/rtmpose/yolox/humanart/yolox_x_8xb8-300e_humanart.py @@ -0,0 +1,8 @@ +_base_ = './yolox_s_8xb8-300e_humanart.py' + +# model settings +model = dict( + backbone=dict(deepen_factor=1.33, widen_factor=1.25), + neck=dict( + in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), + bbox_head=dict(in_channels=320, feat_channels=320))
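
Usage note (beyond this patch): the sketch below pairs one of the HumanArt-trained detectors added above with an RTMPose model for top-down pose estimation through MMPose's `MMPoseInferencer`. The config paths and checkpoint URLs come from the tables above; the input image and output directory are placeholders, and the argument names follow the MMPose 1.x inferencer API and may differ in other versions.

```python
# A minimal sketch, assuming MMPose 1.x and MMDetection 3.x are installed.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(
    # RTMPose-m body model from the pose-estimator table above
    pose2d='projects/rtmpose/rtmpose/body_2d_keypoint/'
    'rtmpose-m_8xb256-420e_coco-256x192.py',
    pose2d_weights='https://download.openmmlab.com/mmpose/v1/projects/'
    'rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth',
    # HumanArt-trained person detector added by this patch
    det_model='projects/rtmpose/rtmdet/person/'
    'rtmdet_tiny_8xb32-300e_humanart.py',
    det_weights='https://download.openmmlab.com/mmpose/v1/projects/'
    'rtmposev1/rtmdet_tiny_8xb32-300e_humanart-7da5554e.pth',
    det_cat_ids=[0],  # 'person' is category 0 in these detector configs
)

# 'artwork.jpg' is a placeholder input; visualizations go to `out_dir`.
for result in inferencer('artwork.jpg', out_dir='vis_results'):
    # each result holds per-image predictions (keypoints, scores, bboxes)
    print(result['predictions'])
```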
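Training note: each detector config here is a standard MMDetection config, so it can be trained with MMDetection's usual entry points (e.g. `mim train mmdet <config>` or `tools/train.py`). The sketch below does the same programmatically through MMEngine's `Runner`; the `work_dir` value is illustrative, and it assumes `mmdet` is installed so that the `mmdet::` base configs resolve.

```python
# A hedged sketch of launching training programmatically with MMEngine.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'projects/rtmpose/rtmdet/person/rtmdet_tiny_8xb32-300e_humanart.py')
cfg.work_dir = 'work_dirs/rtmdet_tiny_humanart'  # illustrative output dir
runner = Runner.from_cfg(cfg)
runner.train()
```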