From da110e065cc2eebc33ed51caa48e2653362a0474 Mon Sep 17 00:00:00 2001 From: zhangting2020 Date: Wed, 31 May 2023 20:07:31 +0800 Subject: [PATCH] improve amp training and fix nan error --- configs/keypoint/tiny_pose/tinypose_128x96.yml | 3 +++ configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml | 3 +++ configs/ppyolo/ppyolo_mbv3_large_coco.yml | 3 +++ ppdet/engine/trainer.py | 18 ++++++++++++++---- ppdet/modeling/layers.py | 2 ++ ppdet/modeling/losses/yolo_loss.py | 5 +++-- ppdet/optimizer/ema.py | 10 ++++++---- 7 files changed, 34 insertions(+), 10 deletions(-) diff --git a/configs/keypoint/tiny_pose/tinypose_128x96.yml b/configs/keypoint/tiny_pose/tinypose_128x96.yml index e213c299020..42ac83eefa4 100644 --- a/configs/keypoint/tiny_pose/tinypose_128x96.yml +++ b/configs/keypoint/tiny_pose/tinypose_128x96.yml @@ -14,6 +14,9 @@ trainsize: &trainsize [*train_width, *train_height] hmsize: &hmsize [24, 32] flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] +# AMP training +init_loss_scaling: 32752 +master_grad: true #####model architecture: TopDownHRNet diff --git a/configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml b/configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml index 6288adeed8a..8a541c29d60 100644 --- a/configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml +++ b/configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml @@ -4,6 +4,9 @@ norm_type: sync_bn use_ema: true ema_decay: 0.9998 +# AMP training +master_grad: true + YOLOv3: backbone: ResNet neck: PPYOLOPAN diff --git a/configs/ppyolo/ppyolo_mbv3_large_coco.yml b/configs/ppyolo/ppyolo_mbv3_large_coco.yml index 01558786e5f..cdce0859b26 100644 --- a/configs/ppyolo/ppyolo_mbv3_large_coco.yml +++ b/configs/ppyolo/ppyolo_mbv3_large_coco.yml @@ -9,6 +9,9 @@ _BASE_: [ snapshot_epoch: 10 weights: output/ppyolo_mbv3_large_coco/model_final +# AMP training +master_grad: true + TrainReader: inputs_def: num_max_boxes: 90 diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index bfd92fd62fb..01e00b24b52 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -72,6 +72,7 @@ def __init__(self, cfg, mode='train'): self.amp_level = self.cfg.get('amp_level', 'O1') self.custom_white_list = self.cfg.get('custom_white_list', None) self.custom_black_list = self.cfg.get('custom_black_list', None) + self.use_master_grad = self.cfg.get('master_grad', False) if 'slim' in cfg and cfg['slim_type'] == 'PTQ': self.cfg['TestDataset'] = create('TestDataset')() @@ -179,10 +180,19 @@ def __init__(self, cfg, mode='train'): self.pruner = create('UnstructuredPruner')(self.model, steps_per_epoch) if self.use_amp and self.amp_level == 'O2': - self.model, self.optimizer = paddle.amp.decorate( - models=self.model, - optimizers=self.optimizer, - level=self.amp_level) + paddle_version = paddle.__version__[:3] + # paddle version >= 2.5.0 or develop + if paddle_version in ["2.5", "0.0"]: + self.model, self.optimizer = paddle.amp.decorate( + models=self.model, + optimizers=self.optimizer, + level=self.amp_level, + master_grad=self.use_master_grad) + else: + self.model, self.optimizer = paddle.amp.decorate( + models=self.model, + optimizers=self.optimizer, + level=self.amp_level) self.use_ema = ('use_ema' in cfg and cfg['use_ema']) if self.use_ema: ema_decay = self.cfg.get('ema_decay', 0.9998) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 86c6d9697ff..f91b840d6e7 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -362,6 +362,8 @@ def forward(self, x): padding=self.block_size // 2, 
data_format=self.data_format) mask = 1. - mask_inv + mask = mask.astype('float32') + x = x.astype('float32') y = x * mask * (mask.numel() / mask.sum()) return y diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 1ba05f2c8ea..fecef9ada3c 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -190,8 +190,9 @@ def forward(self, inputs, targets, anchors): self.distill_pairs.clear() for x, t, anchor, downsample in zip(inputs, gt_targets, anchors, self.downsample): - yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample, - self.scale_x_y) + yolo_loss = self.yolov3_loss( + x.astype('float32'), t, gt_box, anchor, downsample, + self.scale_x_y) for k, v in yolo_loss.items(): if k in yolo_losses: yolo_losses[k] += v diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index 70d006b8fe3..e81214f4705 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -69,9 +69,9 @@ def __init__(self, self.state_dict = dict() for k, v in model.state_dict().items(): if k in self.ema_black_list: - self.state_dict[k] = v + self.state_dict[k] = v.astype('float32') else: - self.state_dict[k] = paddle.zeros_like(v) + self.state_dict[k] = paddle.zeros_like(v, dtype='float32') self._model_state = { k: weakref.ref(p) @@ -114,7 +114,7 @@ def update(self, model=None): for k, v in self.state_dict.items(): if k not in self.ema_black_list: - v = decay * v + (1 - decay) * model_dict[k] + v = decay * v + (1 - decay) * model_dict[k].astype('float32') v.stop_gradient = True self.state_dict[k] = v self.step += 1 @@ -123,13 +123,15 @@ def apply(self): if self.step == 0: return self.state_dict state_dict = dict() + model_dict = {k: p() for k, p in self._model_state.items()} for k, v in self.state_dict.items(): if k in self.ema_black_list: v.stop_gradient = True - state_dict[k] = v + state_dict[k] = v.astype(model_dict[k].dtype) else: if self.ema_decay_type != 'exponential': v = v / (1 - self._decay**self.step) + v = v.astype(model_dict[k].dtype) v.stop_gradient = True state_dict[k] = v self.epoch += 1
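
Reviewer note: the sketch below shows, outside of ppdet's Trainer, how the two new YAML keys (`master_grad`, `init_loss_scaling`) would feed `paddle.amp.decorate` and `paddle.amp.GradScaler` in a bare O2 training step, and how the string-slice version gate in trainer.py could instead be written as an explicit version compare. It is a minimal sketch under stated assumptions, not the Trainer code in this patch: `supports_master_grad` is a hypothetical helper, the Linear model and Momentum optimizer are toy stand-ins for the detector, and passing `master_grad` to `paddle.amp.decorate` assumes Paddle >= 2.5 (develop builds report "0.0.0" and are treated as new enough, matching the check in the hunk above).

    import paddle

    def supports_master_grad(version_str: str) -> bool:
        """Hypothetical helper: True when paddle.amp.decorate accepts master_grad.

        The master_grad argument is available from Paddle 2.5; develop builds
        report "0.0.0" and are treated as new enough, as in the patch.
        """
        if version_str.startswith("0.0"):
            return True
        major, minor = (int(p) for p in version_str.split(".")[:2])
        return (major, minor) >= (2, 5)

    # Toy model/optimizer standing in for the detector built by the Trainer.
    # Intended to run on a CUDA device; O2 autocast is a no-op on CPU.
    model = paddle.nn.Linear(16, 4)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.01, parameters=model.parameters())

    # Values the new YAML keys would supply to the Trainer.
    init_loss_scaling = 32752   # tinypose_128x96.yml
    use_master_grad = True      # `master_grad: true`

    if supports_master_grad(paddle.__version__):
        # float32 master gradients: the unscaled grads are accumulated in
        # float32, which is what the patch enables to avoid NaN under O2.
        model, optimizer = paddle.amp.decorate(
            models=model, optimizers=optimizer, level='O2',
            master_grad=use_master_grad)
    else:
        model, optimizer = paddle.amp.decorate(
            models=model, optimizers=optimizer, level='O2')

    scaler = paddle.amp.GradScaler(init_loss_scaling=init_loss_scaling)

    x = paddle.randn([8, 16])
    target = paddle.randn([8, 4])
    with paddle.amp.auto_cast(level='O2'):
        loss = paddle.nn.functional.mse_loss(model(x), target)

    # Scaled backward, then unscale-and-step plus loss-scale update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()

The float32 master gradients, the float32 EMA state kept in ppdet/optimizer/ema.py, and the explicit .astype('float32') casts in DropBlock and the YOLOv3 loss all serve the same goal: keeping numerically fragile reductions out of float16 (for example, mask.sum() over a large feature map can exceed float16's ~65504 maximum), so the loss no longer turns NaN during O2 training.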