Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add test time augmentation base model. #538

Merged
merged 28 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Model
BaseModel
BaseDataPreprocessor
ImgDataPreprocessor
BaseTTAModel

EMA
----------------
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/api/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Model
BaseModel
BaseDataPreprocessor
ImgDataPreprocessor
BaseTTAModel

EMA
----------------
Expand Down
4 changes: 2 additions & 2 deletions mmengine/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, update_init_info,
xavier_init)
from .wrappers import (MMDistributedDataParallel,
from .wrappers import (BaseTTAModel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, is_model_wrapper)

__all__ = [
Expand All @@ -29,7 +29,7 @@
'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
'Caffe2XavierInit', 'PretrainedInit', 'initialize',
'convert_sync_batchnorm'
'convert_sync_batchnorm', 'BaseTTAModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
Expand Down
3 changes: 2 additions & 1 deletion mmengine/model/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from mmengine.utils.version_utils import digit_version
from .distributed import MMDistributedDataParallel
from .seperate_distributed import MMSeparateDistributedDataParallel
from .test_time_aug import BaseTTAModel
from .utils import is_model_wrapper

__all__ = [
'MMDistributedDataParallel', 'is_model_wrapper',
'MMSeparateDistributedDataParallel'
'MMSeparateDistributedDataParallel', 'BaseTTAModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
Expand Down
6 changes: 2 additions & 4 deletions mmengine/model/wrappers/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
Returns:
list: The predictions of given data.
"""
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
return self.module.val_step(data)

def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of module during testing process.
Expand All @@ -146,8 +145,7 @@ def test_step(self, data: Union[dict, tuple, list]) -> list:
Returns:
list: The predictions of given data.
"""
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
return self.module.test_step(data)

def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
"""Unpacks data for :meth:`forward`
Expand Down
126 changes: 126 additions & 0 deletions mmengine/model/wrappers/test_time_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict, List, Union

import torch
import torch.nn as nn

from mmengine import MODELS
from mmengine.structures import BaseDataElement

# multi-batch inputs processed by different augmentations from the same batch.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]
# multi-batch data samples processed by different augmentations from the same
# batch. The inner list stands for different augmentations and the outer list
# stands for batch.
EnhancedBatchDataSamples = List[List[BaseDataElement]]
DATA_BATCH = Union[Dict[str, Union[EnhancedBatchInputs,
EnhancedBatchDataSamples]], tuple, dict]
MergedDataSamples = List[BaseDataElement]


@MODELS.register_module()
class BaseTTAModel:
"""Base model for inference with test-time augmentation.

``BaseTTAModel`` is a wrapper for inference given multi-batch data.
It implements the :meth:`test_step` for multi-batch data inference.
``multi-batch`` data means data processed by different augmentation
from the same batch.

During test time augmentation, the data processed by
:obj:`mmcv.transforms.TestTimeAug`, and then collated by
``pseudo_collate`` will have the following format:

.. code-block::

result = dict(
inputs=[
[image1_aug1, image2_aug1],
[image1_aug2, image2_aug2]
],
data_samples=[
[data_sample1_aug1, data_sample2_aug1],
[data_sample1_aug2, data_sample2_aug2],
]
)

``image{i}_aug{j}`` means the i-th image of the batch, which is
augmented by the j-th augmentation.

``BaseTTAModel`` will collate the data to:

.. code-block::

data1 = dict(
inputs=[image1_aug1, image2_aug1],
data_samples=[data_sample1_aug1, data_sample2_aug1]
)

data2 = dict(
inputs=[image1_aug2, image2_aug2],
data_samples=[data_sample1_aug2, data_sample2_aug2]
)

``data1`` and ``data2`` will be passed to model, and the results will be
merged by :meth:`merge_preds`.

Note:
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
:meth:`merge_results` is an abstract method, all subclasses should
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
implement it.

Args:
module (dict or nn.Module): Tested model.
"""

def __init__(self, module: Union[dict, nn.Module]):
if isinstance(module, nn.Module):
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
self.module = module
elif isinstance(module, dict):
self.module = MODELS.build(module)
else:
raise TypeError('The type of module should be a `nn.Module` '
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
f'instance or a dict, but got {module}')
assert hasattr(self.module, 'test_step'), (
'Model wrapped by BaseTTAModel must implement `test_step`!')

@abstractmethod
def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \
-> MergedDataSamples:
"""Merge predictions of enhanced data to one prediction.

Args:
data_samples_list (EnhancedBatchDataSamples): List of predictions
of all enhanced data.

Returns:
List[BaseDataElement]: Merged prediction.
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
"""

def test_step(self, data: DATA_BATCH) -> MergedDataSamples:
"""Get predictions of each enhanced data, a multiple predictions.

Args:
data (DataBatch): Enhanced data batch sampled from dataloader.

Returns:
MergedDataSamples: Merged prediction.
"""
data_list: Union[List[dict], List[list]]
if isinstance(data, dict):
num_augs = len(data[next(iter(data))])
data_list = [{key: value[idx]
for key, value in data.items()}
for idx in range(num_augs)]
elif isinstance(data, (tuple, list)):
num_augs = len(data[0])
data_list = [[_data[idx] for _data in data]
for idx in range(num_augs)]
else:
raise TypeError('data given by dataLoader should be a dict, '
f'tuple or a list, but got {type(data)}')

predictions = []
for data in data_list: # type: ignore
predictions.append(self.module.test_step(data))
return self.merge_preds(list(zip(*predictions))) # type: ignore
4 changes: 2 additions & 2 deletions mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mmengine.fileio import FileClient, get_file_backend
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import mkdir_or_exist
from mmengine.utils.dl_utils import load_url

Expand Down Expand Up @@ -73,7 +73,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_model_wrapper(module):
if is_model_wrapper(module) or isinstance(module, BaseTTAModel):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
Expand Down
40 changes: 29 additions & 11 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
from mmengine.fileio import FileClient, join_path
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (BaseModel, MMDistributedDataParallel,
convert_sync_batchnorm, is_model_wrapper,
revert_sync_batchnorm)
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
is_model_wrapper, revert_sync_batchnorm)
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
Expand Down Expand Up @@ -772,7 +771,7 @@ def build_visualizer(
'visualizer should be Visualizer object, a dict or None, '
f'but got {visualizer}')

def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel:
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
"""Build model.

If ``model`` is a dict, it will be used to build a nn.Module object.
Expand All @@ -783,14 +782,20 @@ def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel:
model = dict(type='ResNet')

Args:
model (BaseModel or dict): A nn.Module object or a dict to build
nn.Module object. If ``model`` is a nn.Module object, just
returns itself.
model (nn.Module or dict): A ``nn.Module`` object or a dict to
build nn.Module object. If ``model`` is a nn.Module object,
just returns itself.

Note:
The returned model must implement ``train_step``, ``test_step``
if ``runner.train`` or ``runner.test`` will be called. If
``runner.val`` will be called or ``val_cfg`` is configured,
model must implement `val_step`.

Returns:
nn.Module: Model build from ``model``.
"""
if isinstance(model, BaseModel):
if isinstance(model, nn.Module):
return model
elif isinstance(model, dict):
model = MODELS.build(model)
Expand All @@ -801,7 +806,7 @@ def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel:

def wrap_model(
self, model_wrapper_cfg: Optional[Dict],
model: BaseModel) -> Union[DistributedDataParallel, BaseModel]:
model: nn.Module) -> Union[DistributedDataParallel, nn.Module]:
"""Wrap the model to :obj:``MMDistributedDataParallel`` or other custom
distributed data-parallel module wrappers.

Expand All @@ -816,10 +821,10 @@ def wrap_model(
model_wrapper_cfg (dict, optional): Config to wrap model. If not
specified, ``DistributedDataParallel`` will be used in
distributed environment. Defaults to None.
model (BaseModel): Model to be wrapped.
model (nn.Module): Model to be wrapped.

Returns:
BaseModel or DistributedDataParallel: BaseModel or subclass of
nn.Module or DistributedDataParallel: nn.Module or subclass of
``DistributedDataParallel``.
"""
if is_model_wrapper(model):
Expand Down Expand Up @@ -1601,6 +1606,19 @@ def train(self) -> nn.Module:
Returns:
nn.Module: The model after training.
"""
if is_model_wrapper(self.model):
ori_model = self.model.module
else:
ori_model = self.model
assert hasattr(ori_model, 'train_step'), (
'If you want to train your model, please make sure your model '
'has implemented `train_step`.')

if self._val_loop is not None:
assert hasattr(ori_model, 'val_step'), (
'If you want to validate your model, please make sure your '
'model has implemented `val_step`.')

zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
if self._train_loop is None:
raise RuntimeError(
'`self._train_loop` should not be None when calling train '
Expand Down
14 changes: 7 additions & 7 deletions tests/test_hooks/test_ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mmengine.testing import assert_allclose


class ToyModel(nn.Module):
class ToyModel(BaseModel):

def __init__(self):
super().__init__()
Expand All @@ -39,33 +39,33 @@ def forward(self, inputs, data_sample, mode='tensor'):
return outputs


class ToyModel1(BaseModel, ToyModel):
class ToyModel1(ToyModel):

def __init__(self):
super().__init__()

def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)
return super().forward(*args, **kwargs)


class ToyModel2(BaseModel, ToyModel):
class ToyModel2(ToyModel):

def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 1)

def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)
return super().forward(*args, **kwargs)


class ToyModel3(BaseModel, ToyModel):
class ToyModel3(ToyModel):

def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2)

def forward(self, *args, **kwargs):
return super(BaseModel, self).forward(*args, **kwargs)
return super().forward(*args, **kwargs)


@DATASETS.register_module()
Expand Down
47 changes: 47 additions & 0 deletions tests/test_model/test_wrappers/test_test_aug_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from torch.utils.data import DataLoader

from mmengine.model import BaseModel, BaseTTAModel


class ToyTestTimeAugModel(BaseTTAModel):

def merge_preds(self, data_samples_list):
result = [sum(x) for x in data_samples_list]
return result


class ToyModel(BaseModel):

def forward(self, inputs, data_samples, mode='tensor'):
return data_samples


class TestBaseTTAModel(TestCase):

def setUp(self) -> None:
dict_dataset = [dict(inputs=[1, 2], data_samples=[3, 4])] * 10
tuple_dataset = [([1, 2], [3, 4])] * 10
self.model = ToyModel()
self.dict_dataloader = DataLoader(dict_dataset, batch_size=2)
self.tuple_dataloader = DataLoader(tuple_dataset, batch_size=2)

def test_test_step(self):
tta_model = ToyTestTimeAugModel(self.model)

# Test dict dataset

for data in self.dict_dataloader:
# Test step will call forward.
result = tta_model.test_step(data)
self.assertEqual(result, [7, 7])

for data in self.tuple_dataloader:
result = tta_model.test_step(data)
self.assertEqual(result, [7, 7])

def test_init(self):
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
tta_model = ToyTestTimeAugModel(self.model)
self.assertIs(tta_model.module, self.model)
Loading