diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index 89c7be18a7dfa..e8820d5a8708e 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -25,6 +25,8 @@ from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import Program, program_guard, convert_np_dtype_to_dtype_
 import paddle
+from paddle.io import Dataset
+import numpy
 
 paddle.enable_static()
 
@@ -1113,5 +1115,90 @@ def test_float32(self):
         self.check_with_dtype('float32')
 
 
+class TestMasterWeightSaveForFP16(unittest.TestCase):
+    '''
+    For AMP-O2, some optimizers (Momentum, Adam, ...) create master weights for the parameters to improve accuracy.
+    Master weights are saved by optimizer.state_dict().
+    '''
+
+    def check_with_opt_state_dict(self, use_save_load=True):
+        paddle.seed(100)
+        numpy.random.seed(100)
+
+        class SimpleNet(paddle.nn.Layer):
+            def __init__(self, input_size, output_size):
+                super(SimpleNet, self).__init__()
+                self.linears = paddle.nn.LayerList([
+                    paddle.nn.Linear(input_size, output_size) for i in range(1)
+                ])
+
+            def forward(self, x):
+                for i, l in enumerate(self.linears):
+                    x = self.linears[i](x)
+                return x
+
+        input_size = 2  # can be set to a larger value
+        output_size = 2  # can be set to a larger value
+        batch_size = 2  # a multiple of 8 is preferred
+        nums_batch = 10
+
+        class RandomDataset(Dataset):
+            def __init__(self, num_samples):
+                self.num_samples = num_samples
+
+            def __getitem__(self, idx):
+                data = numpy.random.random([input_size]).astype('float16')
+                label = numpy.random.random([output_size]).astype('float16')
+                return data, label
+
+            def __len__(self):
+                return self.num_samples
+
+        dataset = RandomDataset(nums_batch * batch_size)
+        loader = paddle.io.DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=True,
+            num_workers=0)
+
+        mse = paddle.nn.MSELoss()
+        model = SimpleNet(input_size, output_size)  # define the model
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=0.0001,
+            parameters=model.parameters(),
+            multi_precision=True)  # define the optimizer
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        model = paddle.amp.decorate(models=model, level='O2')
+
+        for i, (data, label) in enumerate(loader):
+            with paddle.amp.auto_cast(level='O2'):
+                output = model(data)
+                loss = mse(output, label)
+            scaled = scaler.scale(loss)
+            scaled.backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.clear_grad(set_to_zero=False)
+
+            if use_save_load and i == 5:
+                paddle.save(model.state_dict(), "model.pdparams")
+                paddle.save(optimizer.state_dict(), "opt.pdopt")
+                model.set_state_dict(paddle.load("model.pdparams"))
+                optimizer.set_state_dict(paddle.load("opt.pdopt"))
+
+        return loss.numpy()
+
+    def test_with_state_dict(self):
+        if core.is_compiled_with_cuda():
+            with fluid.dygraph.guard():
+                out_use_state_dict = self.check_with_opt_state_dict(
+                    use_save_load=True)
+                out_no_state_dict = self.check_with_opt_state_dict(
+                    use_save_load=False)
+                self.assertTrue(
+                    np.array_equal(out_use_state_dict, out_no_state_dict))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index d433921e826b3..989da2db7184e 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -256,6 +256,10 @@ def state_dict(self):
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
                 state_dict[var_tmp.name] = var_tmp
+        # if master weights exist, save them in the state dict as well
+        if hasattr(self, "_master_weights"):
+            if len(self._master_weights) != 0:
+                state_dict["master_weights"] = self._master_weights
         # global step if use lr decay
         if isinstance(self._learning_rate, LRScheduler):
             state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -304,6 +308,10 @@ def set_state_dict(self, state_dict):
         state_dict = state_dict.copy()
         if "LR_Scheduler" in state_dict:
             state_dict.pop("LR_Scheduler")
+        if "master_weights" in state_dict:
+            if hasattr(self, "_master_weights"):
+                self._master_weights = state_dict["master_weights"]
+            state_dict.pop("master_weights")
         self._accumulators_holder = state_dict
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
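
Usage note (not part of the diff): the snippet below is a minimal sketch of the new behavior, assuming a GPU-enabled Paddle build running in dygraph mode with this patch applied. Once the first optimizer step has created the FP32 master weights, optimizer.state_dict() carries a "master_weights" entry and optimizer.set_state_dict() restores it; the variable names (x, opt_state) are illustrative only.

    import numpy
    import paddle

    paddle.seed(100)

    # a tiny model and a multi-precision Momentum optimizer (keeps FP32 master weights)
    model = paddle.nn.Linear(2, 2)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.0001,
        parameters=model.parameters(),
        multi_precision=True)
    model = paddle.amp.decorate(models=model, level='O2')  # parameters cast to FP16
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    # one training step so the optimizer materializes its master weights
    x = paddle.to_tensor(numpy.random.random([4, 2]).astype('float16'))
    with paddle.amp.auto_cast(level='O2'):
        loss = model(x).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()

    # with this patch, the FP32 master copies travel with the optimizer state
    opt_state = optimizer.state_dict()
    print("master_weights" in opt_state)  # expected: True once master weights exist

    # round-trip through disk; set_state_dict() puts the master weights back
    paddle.save(opt_state, "opt.pdopt")
    optimizer.set_state_dict(paddle.load("opt.pdopt"))

The unit test above exercises exactly this round-trip mid-training (at i == 5) and checks that the final loss is identical with and without the save/load step.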