diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index 5c48501c07d7e..077a70c91015c 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -181,7 +181,10 @@ def count_parameters(m, x, y): def count_io_info(m, x, y): m.register_buffer('input_shape', paddle.to_tensor(x[0].shape)) - m.register_buffer('output_shape', paddle.to_tensor(y.shape)) + if isinstance(y, (list, tuple)): + m.register_buffer('output_shape', paddle.to_tensor(y[0].shape)) + else: + m.register_buffer('output_shape', paddle.to_tensor(y.shape)) register_hooks = { @@ -258,8 +261,8 @@ def add_hooks(m): for m in model.sublayers(): if len(list(m.children())) > 0: continue - if set(['total_ops', 'total_params', 'input_shape', - 'output_shape']).issubset(set(list(m._buffers.keys()))): + if {'total_ops', 'total_params', 'input_shape', + 'output_shape'}.issubset(set(list(m._buffers.keys()))): total_ops += m.total_ops total_params += m.total_params @@ -274,8 +277,8 @@ def add_hooks(m): for n, m in model.named_sublayers(): if len(list(m.children())) > 0: continue - if set(['total_ops', 'total_params', 'input_shape', - 'output_shape']).issubset(set(list(m._buffers.keys()))): + if {'total_ops', 'total_params', 'input_shape', + 'output_shape'}.issubset(set(list(m._buffers.keys()))): table.add_row([ m.full_name(), list(m.input_shape.numpy()), list(m.output_shape.numpy()), int(m.total_params), diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index b9c6dafbb808b..53dce286b71e9 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -732,6 +732,18 @@ def customize_dropout(m, x, y): custom_ops={paddle.nn.Dropout: customize_dropout}, print_detail=True) + def test_dynamic_flops_with_multiple_outputs(self): + net = paddle.nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, return_mask=True) + + def customize_dropout(m, x, y): + m.total_ops += 0 + + paddle.flops( + net, [1, 2, 32, 32], + custom_ops={paddle.nn.Dropout: customize_dropout}, + print_detail=True) + def test_export_deploy_model(self): self.set_seed() np.random.seed(201)