Skip to content

Commit

Permalink
[Enhance] Make the parameters of get_model_complexity_info() friendly (
Browse files Browse the repository at this point in the history
…#1056)

* print_helper: optimize inputs of get_model_complexity_info

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>

* directly throw error

When "input_shape" and "inputs" are both `None` or both set,
throw ValueError.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>

---------

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
  • Loading branch information
sjiang95 committed Apr 10, 2023
1 parent 5e1ed7a commit f76218a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mmengine/analysis/print_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def complexity_stats_table(

def get_model_complexity_info(
model: nn.Module,
input_shape: tuple,
input_shape: tuple = None,
inputs: Optional[torch.Tensor] = None,
show_table: bool = True,
show_arch: bool = True,
Expand All @@ -696,6 +696,11 @@ def get_model_complexity_info(
Returns:
dict: The complexity information of the model.
"""
if input_shape is None and inputs is None:
raise ValueError('One of "input_shape" and "inputs" should be set.')
elif input_shape is not None and inputs is not None:
raise ValueError('"input_shape" and "inputs" cannot be both set.')

if inputs is None:
inputs = (torch.randn(1, *input_shape), )

Expand Down

0 comments on commit f76218a

Please sign in to comment.