[Feature] Add test time augmentation base model. #538

Merged: 28 commits, Oct 10, 2022. Changes from 8 commits.
4 changes: 2 additions & 2 deletions mmengine/model/__init__.py
@@ -14,7 +14,7 @@
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, update_init_info,
xavier_init)
from .wrappers import (MMDistributedDataParallel,
from .wrappers import (BaseTestTimeAugModel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, is_model_wrapper)

__all__ = [
@@ -29,7 +29,7 @@
'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
'Caffe2XavierInit', 'PretrainedInit', 'initialize',
'convert_sync_batchnorm'
'convert_sync_batchnorm', 'BaseTestTimeAugModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
3 changes: 2 additions & 1 deletion mmengine/model/wrappers/__init__.py
@@ -3,11 +3,12 @@
from mmengine.utils.version_utils import digit_version
from .distributed import MMDistributedDataParallel
from .seperate_distributed import MMSeparateDistributedDataParallel
from .test_time_aug import BaseTestTimeAugModel
from .utils import is_model_wrapper

__all__ = [
'MMDistributedDataParallel', 'is_model_wrapper',
'MMSeparateDistributedDataParallel'
'MMSeparateDistributedDataParallel', 'BaseTestTimeAugModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
6 changes: 2 additions & 4 deletions mmengine/model/wrappers/distributed.py
@@ -133,8 +133,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
Returns:
list: The predictions of given data.
"""
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
return self.module.val_step(data)

def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of module during testing process.
@@ -145,8 +144,7 @@ def test_step(self, data: Union[dict, tuple, list]) -> list:
Returns:
list: The predictions of given data.
"""
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
return self.module.test_step(data)

def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
"""Unpacks data for :meth:`forward`
152 changes: 152 additions & 0 deletions mmengine/model/wrappers/test_time_aug.py
@@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict, List, Optional, Union

import torch

from mmengine import MODELS
from mmengine.optim import OptimWrapper
from mmengine.structures import BaseDataElement
from ..base_model import BaseDataPreprocessor, BaseModel

EnhancedInputs = List[Union[torch.Tensor, List[torch.Tensor]]]
EnhancedDataSamples = List[List[BaseDataElement]]
MergedDataSamples = List[BaseDataElement]


@MODELS.register_module()
class BaseTestTimeAugModel(BaseModel):
"""Base model for test time augmentation.

``BaseTestTimeAugModel`` is a wrapper for a specific algorithm
model. It implements :meth:`test_step` for multi-batch
data inference. `multi-batch` data means different enhanced versions of
the same batch.

All subclasses should implement :meth:`merge_results` for results fusion.

During test time augmentation, the data processed by
:obj:`mmcv.transforms.TestTimeAug` and then collated by
``pseudo_collate`` will have the following format:

.. code-block::

result = dict(
inputs=[
[image1_aug1, image2_aug1],
[image1_aug2, image2_aug2]
],
data_samples=[
[data_sample1_aug1, data_sample2_aug1],
[data_sample1_aug2, data_sample2_aug2],
]
)

``image{i}_aug{j}`` means the i-th image of the batch augmented by the
j-th augmentation.

``BaseTestTimeAugModel`` will collate the data to:

.. code-block::

data1 = dict(
inputs=[image1_aug1, image2_aug1],
data_samples=[data_sample1_aug1, data_sample2_aug1]
)

data2 = dict(
inputs=[image1_aug2, image2_aug2],
data_samples=[data_sample1_aug2, data_sample2_aug2]
)

``data1`` and ``data2`` will be passed to the model, and the results will
be merged by :meth:`merge_results`.

Note:
:meth:`merge_results` is an abstract method; all subclasses must
implement it.

Args:
module (BaseModel): Tested model.
data_preprocessor (BaseDataPreprocessor or dict, optional): The
pre-process config for :class:`BaseDataPreprocessor`.
"""

def __init__(
self,
module: BaseModel,
data_preprocessor: Optional[Union[dict,
BaseDataPreprocessor]] = None):

super().__init__(data_preprocessor=data_preprocessor)
if isinstance(module, BaseModel):
self.module = module
elif isinstance(module, dict):
self.module = MODELS.build(module)
else:
raise TypeError('The type of module should be `BaseModel` or a '
f'dict, but got {module}')

@abstractmethod
def merge_results(self, data_samples_list: EnhancedDataSamples) \
-> List[BaseDataElement]:
"""Merge predictions of enhanced data to one prediction.

Args:
data_samples_list (EnhancedDataSamples): List of predictions of
all enhanced data.

Returns:
List[BaseDataElement]: Merged prediction.
"""

def forward(self,
inputs: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
"""``BaseTestTimeAugModel`` will directly call ``test_step`` of
corresponding algorithm, therefore its forward should not be called."""
raise NotImplementedError(
'`BaseTestTimeAugModel` will directly call '
f'{self.module.__class__.__name__}.test_step, its `forward` '
f'should not be called')

def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Get predictions of each enhanced data, a multiple predictionsa.
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved

Args:
inputs (EnhancedInputs): List of enhanced batch data from single
batch data.
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
data_samples (EnhancedDataSamples): List of enhanced data
samples from single batch data sample.
mode (str): Current mode of model, see more information in
:meth:`mmengine.model.BaseModel.forward`.

Returns:
MergedDataSamples: Merged prediction.
"""
data_list: Union[List[dict], List[list]]
if isinstance(data, dict):
num_augs = len(data[next(iter(data))])
data_list = [{key: value[idx]
for key, value in data.items()}
for idx in range(num_augs)]
elif isinstance(data, (tuple, list)):
num_augs = len(data[0])
data_list = [[_data[idx] for _data in data]
for idx in range(num_augs)]
else:
raise TypeError('data should be a dict, tuple or list, but got '
f'{type(data)}')

predictions = []
for data in data_list:
predictions.append(self.module.test_step(data))
return self.merge_results(predictions)

def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""``BaseTestTimeAugModel`` is only for testing or validation,
therefore ``train_step`` should not be called."""
raise NotImplementedError('train_step should not be called! '
f'{self.__class__.__name__} should only be '
'used for testing.')
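
A minimal sketch of a concrete subclass, assuming a classification-style prediction that exposes a ``score`` field; the ``MeanScoreTTAModel`` name and the ``score`` field are illustrative only and not part of this PR:

.. code-block:: python

    from typing import List

    from mmengine.model import BaseTestTimeAugModel
    from mmengine.structures import BaseDataElement


    class MeanScoreTTAModel(BaseTestTimeAugModel):

        def merge_results(self, data_samples_list) -> List[BaseDataElement]:
            # data_samples_list[i][j] is the prediction of the j-th sample
            # under the i-th augmentation; zip(*...) regroups the predictions
            # by sample.
            merged = []
            for per_sample_preds in zip(*data_samples_list):
                scores = [pred.score for pred in per_sample_preds]
                result = BaseDataElement()
                result.score = sum(scores) / len(scores)
                merged.append(result)
            return merged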
5 changes: 3 additions & 2 deletions mmengine/runner/checkpoint.py
@@ -18,7 +18,7 @@
from mmengine.fileio import FileClient
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.model import BaseTestTimeAugModel, is_model_wrapper
from mmengine.utils import mkdir_or_exist
from mmengine.utils.dl_utils import load_url

@@ -73,7 +73,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_model_wrapper(module):
if is_model_wrapper(module) or isinstance(module,
BaseTestTimeAugModel):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
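
The unwrap added above matters because ``BaseTestTimeAugModel`` stores the actual network as ``self.module``, so its state dict keys carry a ``module.`` prefix. A minimal sketch of that prefix behavior with plain ``torch.nn`` stand-ins (the toy classes are illustrative, not part of this PR):

.. code-block:: python

    import torch.nn as nn


    class InnerModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)


    class Wrapper(nn.Module):

        def __init__(self):
            super().__init__()
            self.module = InnerModel()


    # Wrapper keys are prefixed with 'module.', so ``load`` descends into
    # ``module`` to match checkpoints saved from the bare model.
    print(list(Wrapper().state_dict()))     # ['module.linear.weight', 'module.linear.bias']
    print(list(InnerModel().state_dict()))  # ['linear.weight', 'linear.bias']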
53 changes: 53 additions & 0 deletions tests/test_model/test_wrappers/test_test_aug_time.py
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from torch.utils.data import DataLoader

from mmengine.model import BaseModel, BaseTestTimeAugModel


class ToyTestTimeAugModel(BaseTestTimeAugModel):

def merge_results(self, data_samples_list):
result = [sum(x) for x in zip(*data_samples_list)]
return result


class ToyModel(BaseModel):

def forward(self, inputs, data_samples, mode='tensor'):
return data_samples


class TestBaseTestTimeAugModel(TestCase):

def setUp(self) -> None:
dict_dataset = [dict(inputs=[1, 2], data_samples=[3, 4])] * 10
tuple_dataset = [([1, 2], [3, 4])] * 10
self.model = ToyModel()
self.dict_dataloader = DataLoader(dict_dataset, batch_size=2)
self.tuple_dataloader = DataLoader(tuple_dataset, batch_size=2)

def test_init(self):
tta_model = ToyTestTimeAugModel(self.model)
self.assertIs(tta_model.module, self.model)

def test_test_step(self):
tta_model = ToyTestTimeAugModel(self.model)

# Test dict dataset

for data in self.dict_dataloader:
# Test step will call forward.
result = tta_model.test_step(data)
self.assertEqual(result, [7, 7])

for data in self.tuple_dataloader:
result = tta_model.test_step(data)
self.assertEqual(result, [7, 7])

def test_train_step(self):
tta_model = ToyTestTimeAugModel(self.model)
with self.assertRaisesRegex(NotImplementedError,
'train_step should not be called'):
tta_model.train_step(None, None)
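
A hedged usage sketch built on the toy classes above, showing how ``test_step`` splits a dict batch per augmentation and then merges the per-augmentation predictions. It assumes the default data preprocessor passes plain Python lists through unchanged; the output in the final comment is illustrative:

.. code-block:: python

    tta_model = ToyTestTimeAugModel(ToyModel())

    # Two augmentations of a batch of two samples, in the layout produced
    # by TestTimeAug followed by pseudo_collate.
    data = dict(
        inputs=[[1, 2], [3, 4]],
        data_samples=[[3, 4], [5, 6]],
    )

    # test_step splits this into per-augmentation batches:
    #   aug1: data_samples [3, 4], aug2: data_samples [5, 6]
    # ToyModel returns data_samples unchanged, so merge_results sums the
    # predictions per sample: [3 + 5, 4 + 6].
    print(tta_model.test_step(data))  # [8, 10]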