diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index d78196d94451e..93f1a5a37a67f 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -80,6 +80,23 @@ def forward(self, inputs): params_info = paddle.summary(lenet, (1, 1, 28, 28)) print(params_info) + # multi input demo + class LeNetMultiInput(LeNet): + + def forward(self, inputs, y): + x = self.features(inputs) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + y) + return x + + lenet_multi_input = LeNetMultiInput() + + params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)], + ['float32', 'float32']) + print(params_info) + """ if isinstance(input_size, InputSpec): _input_size = tuple(input_size.shape)