Skip to content

Commit

Permalink
fix paddle.summary's bug when outputs contains non-tensor (#34160)
Browse files Browse the repository at this point in the history
* fix paddle.summary's bug when output contains non-tensor
  • Loading branch information
HydrogenSulfate committed Jul 29, 2021
1 parent 02cc3c5 commit b7fac0f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/paddle/hapi/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,10 @@ def _get_shape_from_tensor(x):
def _get_output_shape(output):
if isinstance(output, (list, tuple)):
output_shape = [_get_output_shape(o) for o in output]
else:
elif hasattr(output, 'shape'):
output_shape = list(output.shape)
else:
output_shape = []
return output_shape

def register_hook(layer):
Expand Down
25 changes: 25 additions & 0 deletions python/paddle/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,28 @@ def forward(self, inputs):
return x


class ModelInner(paddle.nn.Layer):
def __init__(self):
super(ModelInner, self).__init__()
self.fc = paddle.nn.Linear(3, 4)

def forward(self, x):
y = self.fc(x)
return y, 0


class ModelOutter(paddle.nn.Layer):
def __init__(self):
super(ModelOutter, self).__init__()
self.module1 = ModelInner()
self.module2 = paddle.nn.Linear(4, 5)

def forward(self, x):
y, dummpy = self.module1(x)
y = self.module2(y)
return y, 3


class LeNetListInput(LeNetDygraph):
def forward(self, inputs):
x = inputs[0]
Expand Down Expand Up @@ -607,6 +629,9 @@ def _get_param_from_state_dict(state_dict):
model.summary(input_size=[(20)])
model.summary(input_size=(20), dtype='float32')

def test_summary_non_tensor(self):
paddle.summary(ModelOutter(), input_size=(-1, 3))

def test_summary_nlp(self):
def _get_param_from_state_dict(state_dict):
params = 0
Expand Down

0 comments on commit b7fac0f

Please sign in to comment.