[Optimizer] Add master weight for opt state_dict #39121

Merged
87 changes: 87 additions & 0 deletions python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -25,6 +25,8 @@
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard, convert_np_dtype_to_dtype_
import paddle
from paddle.io import Dataset
import numpy
paddle.enable_static()


@@ -1113,5 +1115,90 @@ def test_float32(self):
        self.check_with_dtype('float32')


class TestMasterWeightSaveForFP16(unittest.TestCase):
    '''
    For AMP-O2, some optimizers (Momentum, Adam, ...) create master weights
    for the parameters to improve accuracy. The master weights are saved by
    Optimizer.state_dict().
    '''

    def check_with_opt_state_dict(self, use_save_load=True):
        paddle.seed(100)
        numpy.random.seed(100)

        class SimpleNet(paddle.nn.Layer):
            def __init__(self, input_size, output_size):
                super(SimpleNet, self).__init__()
                self.linears = paddle.nn.LayerList([
                    paddle.nn.Linear(input_size, output_size) for i in range(1)
                ])

            def forward(self, x):
                for i, l in enumerate(self.linears):
                    x = self.linears[i](x)
                return x

        input_size = 2  # set to a larger value for a real workload
        output_size = 2  # set to a larger value for a real workload
        batch_size = 2  # batch_size should be a multiple of 8
        nums_batch = 10

        class RandomDataset(Dataset):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                data = numpy.random.random([input_size]).astype('float16')
                label = numpy.random.random([output_size]).astype('float16')
                return data, label

            def __len__(self):
                return self.num_samples

        dataset = RandomDataset(nums_batch * batch_size)
        loader = paddle.io.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=0)

        mse = paddle.nn.MSELoss()
        model = SimpleNet(input_size, output_size)  # define the model
        optimizer = paddle.optimizer.Momentum(
            learning_rate=0.0001,
            parameters=model.parameters(),
            multi_precision=True)  # define the optimizer
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        model = paddle.amp.decorate(models=model, level='O2')

        for i, (data, label) in enumerate(loader):
            with paddle.amp.auto_cast(level='O2'):
                output = model(data)
                loss = mse(output, label)
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad(set_to_zero=False)

            if use_save_load and i == 5:
                paddle.save(model.state_dict(), "model.pdparams")
                paddle.save(optimizer.state_dict(), "opt.pdopt")
                model.set_state_dict(paddle.load("model.pdparams"))
                optimizer.set_state_dict(paddle.load("opt.pdopt"))

        return loss.numpy()

    def test_with_state_dict(self):
        if core.is_compiled_with_cuda():
            with fluid.dygraph.guard():
                out_use_state_dict = self.check_with_opt_state_dict(
                    use_save_load=True)
                out_no_state_dict = self.check_with_opt_state_dict(
                    use_save_load=False)
                self.assertTrue(
                    np.array_equal(out_use_state_dict, out_no_state_dict))


if __name__ == '__main__':
    unittest.main()
8 changes: 8 additions & 0 deletions python/paddle/optimizer/optimizer.py
@@ -256,6 +256,10 @@ def state_dict(self):
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
                state_dict[var_tmp.name] = var_tmp
        # if master weights exist, save them as well
        if hasattr(self, "_master_weights"):
            if len(self._master_weights) != 0:
                state_dict["master_weights"] = self._master_weights
        # save the global step if an LR scheduler is used
        if isinstance(self._learning_rate, LRScheduler):
            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -304,6 +308,10 @@ def set_state_dict(self, state_dict):
        state_dict = state_dict.copy()
        if "LR_Scheduler" in state_dict:
            state_dict.pop("LR_Scheduler")
        # restore master weights saved in the optimizer state_dict, if any
        if "master_weights" in state_dict:
            if hasattr(self, "_master_weights"):
                self._master_weights = state_dict["master_weights"]
            state_dict.pop("master_weights")
        self._accumulators_holder = state_dict
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
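Taken together, the two diffs mean that an optimizer created with multi_precision=True under AMP-O2 now carries its float32 master weights in state_dict() and restores them through set_state_dict(). Below is a minimal usage sketch, assuming a CUDA build of Paddle; the model, tensor shapes, and file name are illustrative and not part of this PR.

import paddle

# A tiny FP16 model under AMP-O2; the optimizer keeps FP32 master weights
# because multi_precision=True.
model = paddle.nn.Linear(2, 2)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.0001, parameters=model.parameters(), multi_precision=True)
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 2])
with paddle.amp.auto_cast(level='O2'):
    loss = model(data).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()

# With this PR, the optimizer state carries the master weights ...
state = optimizer.state_dict()
print("master_weights" in state)  # expected: True after an optimizer step

# ... and they are restored on load, so resumed training matches
# uninterrupted training (this is what test_with_state_dict checks).
paddle.save(state, "opt.pdopt")
optimizer.set_state_dict(paddle.load("opt.pdopt"))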