Skip to content

Commit

Permalink
fix model bugs, inputs can be InputSpec instance
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuChiachi committed Oct 10, 2020
1 parent ad99e63 commit 7e8b2f9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
23 changes: 13 additions & 10 deletions python/paddle/hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@ def _init_context():


def _update_input_shapes(inputs):
"Get input shape list by given inputs in Model initialization."
shapes = None
if isinstance(inputs, list):
if isinstance(inputs, Input):
shapes = [list(inputs.shape)]
elif isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs]
elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs]
Expand Down Expand Up @@ -917,9 +920,7 @@ def train_batch(self, inputs, labels=None):
"""
loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss

def eval_batch(self, inputs, labels=None):
Expand Down Expand Up @@ -967,9 +968,7 @@ def eval_batch(self, inputs, labels=None):
"""
loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss

def test_batch(self, inputs):
Expand Down Expand Up @@ -1012,9 +1011,7 @@ def test_batch(self, inputs):
"""
loss = self._adapter.test_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss

def save(self, path, training=True):
Expand Down Expand Up @@ -1953,3 +1950,9 @@ def _len_data_loader(self, data_loader):
except Exception:
steps = None
return steps

def _update_inputs(self):
"Update self._inputs according to given inputs."
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
15 changes: 14 additions & 1 deletion python/paddle/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,10 @@ def test_export_deploy_model(self):
shutil.rmtree(save_dir)
paddle.enable_static()

def test_dygraph_export_deploy_model_without_inputs(self):
def test_dygraph_export_deploy_model_about_inputs(self):
mnist_data = MnistDataset(mode='train')
paddle.disable_static()
# without inputs
for initial in ["fit", "train_batch", "eval_batch", "test_batch"]:
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
Expand All @@ -579,6 +580,18 @@ def test_dygraph_export_deploy_model_without_inputs(self):

model.save(save_dir, training=False)
shutil.rmtree(save_dir)
# with inputs, and the type of inputs is InputSpec
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
net = LeNet()
inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')
model = Model(net, inputs)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(save_dir, training=False)
shutil.rmtree(save_dir)


class TestRaiseError(unittest.TestCase):
Expand Down

0 comments on commit 7e8b2f9

Please sign in to comment.