diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index da14534459..b075164c40 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -99,8 +99,8 @@ def train(model,
         keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
         test_config(dict, optional): Evaluation config.
         precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
-        amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision,
-            the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
+        amp_level (str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision,
+            where the input data type of each operator is cast according to white_list and black_list; O2 represents pure fp16, all operators'
             parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp)
         profiler_options (str, optional): The option of train profiler.
         to_static_training (bool, optional): Whether to use @to_static for training.
@@ -220,9 +220,10 @@ def train(model,
                 scaled = scaler.scale(loss)  # scale the loss
                 scaled.backward()  # do backward
                 if isinstance(optimizer, paddle.distributed.fleet.Fleet):
-                    scaler.minimize(optimizer.user_defined_optimizer, scaled)
+                    scaler.step(optimizer.user_defined_optimizer)
                 else:
-                    scaler.minimize(optimizer, scaled)  # update parameters
+                    scaler.step(optimizer)  # update parameters
+                scaler.update()  # update the dynamic loss scaling factor
             else:
                 logits_list = ddp_model(images) if nranks > 1 else model(images)
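
Note: `scaler.minimize(optimizer, scaled)` bundles the parameter update and the loss-scale adjustment into one call; the diff splits it into `scaler.step(optimizer)` (unscale gradients, then apply the optimizer update) followed by a single `scaler.update()` shared by both branches, which adjusts the dynamic loss-scaling factor for the next iteration. A minimal sketch of the resulting AMP training pattern is below, assuming Paddle >= 2.2 for `paddle.amp.auto_cast(level=...)`; the toy model, optimizer, and data are placeholders, not taken from train.py:

```python
import paddle

# Placeholder model/optimizer/data, purely for illustration.
model = paddle.nn.Linear(10, 2)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 10])

# Forward pass under auto_cast; ops run in fp16 where supported
# (on CPU-only builds this effectively falls back to fp32).
with paddle.amp.auto_cast(level='O1'):
    loss = model(x).mean()

scaled = scaler.scale(loss)  # scale the loss to avoid fp16 gradient underflow
scaled.backward()            # backward on the scaled loss
scaler.step(optimizer)       # unscale gradients and run the optimizer step
scaler.update()              # adjust the loss scaling factor for the next step
optimizer.clear_grad()
```

Splitting the call also means `scaler.update()` runs exactly once per iteration regardless of which optimizer branch (Fleet or plain) was taken.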