Skip to content

Commit

Permalink
Correct 2.0 API usage in hapi.model.load (#26829) (#26927)
Browse files Browse the repository at this point in the history
* replace fluid.optimizer.set_dict with optimizer.set_state_dict

* replace fluid.optimizer.set_dict with optimizer.set_state_dict

* add coverage rate

* Increase coverage rate, fix code style

* Increase coverage rate, fix code style

* add fit to generate optimizer.state_dict() to save .pdopt to increase coverage rate

* delete http.log
  • Loading branch information
LiuChiachi authored Sep 3, 2020
1 parent d29bda3 commit 64a118f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/paddle/hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ def load(self, param_state_pairs, optim_state):
if not self.model._optimizer or not optim_state:
return

# If optimizer performs set_dict when state vars haven't been created,
# which would happen when set_dict before minimize, the state would be
# If optimizer performs set_state_dict when state vars haven't been created,
# which would happen when set_state_dict before minimize, the state would be
# stored in optimizer._accumulators_holder and loaded lazily.
# To contrive this when loading from static-graph saved states, extend
# state dict to include keys named accoring to dygraph naming rules.
Expand Down Expand Up @@ -776,7 +776,13 @@ def load(self, param_state_pairs, optim_state):
accum_name + "_0")
converted_state[dy_state_name] = state_var

self.model._optimizer.set_dict(converted_state)
if not hasattr(self.model._optimizer, 'set_state_dict'):
warnings.warn(
"paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead"
)
self.model._optimizer.set_dict(converted_state)
else:
self.model._optimizer.set_state_dict(converted_state)


class Model(object):
Expand Down
23 changes: 23 additions & 0 deletions python/paddle/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,29 @@ def test_save_load(self):
shutil.rmtree(path)
fluid.disable_dygraph() if dynamic else None

def test_dynamic_load(self):
mnist_data = MnistDataset(mode='train')
for new_optimizer in [True, False]:
path = tempfile.mkdtemp()
paddle.disable_static()
net = LeNet()
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
if new_optimizer:
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=net.parameters())
else:
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=net.parameters())
model = Model(net, inputs, labels)
model.prepare(
optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.fit(mnist_data, batch_size=64, verbose=0)
model.save(path + '/test')
model.load(path + '/test')
shutil.rmtree(path)
paddle.enable_static()

def test_dynamic_save_static_load(self):
path = tempfile.mkdtemp()
# dynamic saving
Expand Down

0 comments on commit 64a118f

Please sign in to comment.