Skip to content

Commit

Permalink
Complement type hint of get_model_complexity_info()
Browse files Browse the repository at this point in the history
The type of `inputs` should be one of `torch.Tensor`,
`tuple[torch.Tensor, ...]` and `None`.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
  • Loading branch information
sjiang95 committed Apr 10, 2023
1 parent f76218a commit ec64498
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mmengine/analysis/print_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py

from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from rich import box
Expand Down Expand Up @@ -676,7 +676,7 @@ def complexity_stats_table(
def get_model_complexity_info(
model: nn.Module,
input_shape: tuple = None,
inputs: Optional[torch.Tensor] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
show_table: bool = True,
show_arch: bool = True,
):
Expand All @@ -685,9 +685,9 @@ def get_model_complexity_info(
Args:
model (nn.Module): The model to analyze.
input_shape (tuple): The input shape of the model.
inputs (torch.Tensor, optional): The input tensor of the model.
If not given the input tensor will be generated automatically
with the given input_shape.
inputs (Union[torch.Tensor, tuple[torch.Tensor, ...], None]):\
The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape.
show_table (bool): Whether to show the complexity table.
Defaults to True.
show_arch (bool): Whether to show the complexity arch.
Expand Down

0 comments on commit ec64498

Please sign in to comment.