From 2df5bc137ba82c6af31d6f25b07297b7db695ac1 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 11 Oct 2022 09:49:50 +0800 Subject: [PATCH] [Fix] Fix base tta model (#593) Co-authored-by: ubuntu --- mmengine/model/wrappers/test_time_aug.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/model/wrappers/test_time_aug.py b/mmengine/model/wrappers/test_time_aug.py index 677bc3ca44..d99919df9e 100644 --- a/mmengine/model/wrappers/test_time_aug.py +++ b/mmengine/model/wrappers/test_time_aug.py @@ -20,7 +20,7 @@ @MODELS.register_module() -class BaseTTAModel: +class BaseTTAModel(nn.Module): """Base model for inference with test-time augmentation. ``BaseTTAModel`` is a wrapper for inference given multi-batch data. @@ -74,6 +74,7 @@ class BaseTTAModel: """ def __init__(self, module: Union[dict, nn.Module]): + super().__init__() if isinstance(module, nn.Module): self.module = module elif isinstance(module, dict):