Skip to content

Commit

Permalink
fix: minimize() dont support parameter_list of type dict
Browse files Browse the repository at this point in the history
there are diffs that step()+update() and minimize().
this will be fixed in PaddlePaddle/Paddle#53773.
  • Loading branch information
root committed May 15, 2023
1 parent 7fd2857 commit 61adad6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions paddleseg/core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def train(model,
keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
test_config(dict, optional): Evaluation config.
precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision,
the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision,
the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp)
profiler_options (str, optional): The option of train profiler.
to_static_training (bool, optional): Whether to use @to_static for training.
Expand Down Expand Up @@ -220,9 +220,10 @@ def train(model,
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
if isinstance(optimizer, paddle.distributed.fleet.Fleet):
scaler.minimize(optimizer.user_defined_optimizer, scaled)
scaler.step(optimizer.user_defined_optimizer)
else:
scaler.minimize(optimizer, scaled) # update parameters
scaler.step(optimizer) # update parameters
scaler.update() # update parameters
else:
logits_list = ddp_model(images) if nranks > 1 else model(images)

Expand Down

0 comments on commit 61adad6

Please sign in to comment.