Make TTAModel compatible with FSDP. #611

Merged
merged 50 commits into from
Dec 27, 2022
Changes from 46 commits
Commits
0bdb83c
Add build_runner_with_tta and PrepareTTAHook
HAOCHENYE Oct 14, 2022
848d0df
rename hook file
HAOCHENYE Oct 14, 2022
922bbd8
support build tta runner with runner type
HAOCHENYE Oct 14, 2022
9586139
add unit test
HAOCHENYE Oct 18, 2022
95811c8
Add build_runner_with_tta to index.rst
HAOCHENYE Oct 18, 2022
fb29cdd
minor refine
HAOCHENYE Oct 18, 2022
a1bd441
Add runner test cast
HAOCHENYE Oct 18, 2022
05110e0
Fix unit test
HAOCHENYE Oct 18, 2022
d7d0e77
fix unit test
HAOCHENYE Oct 20, 2022
bdf728c
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Oct 25, 2022
bf8d377
tmp save
HAOCHENYE Oct 25, 2022
909cc9f
pop None if key does not exist
HAOCHENYE Oct 25, 2022
c1d8c84
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Oct 25, 2022
869efa2
Merge branch 'main' into HAOCHENYE/add_runner_testcase
HAOCHENYE Oct 25, 2022
6977ec2
Fix is_model_wrapper and force register class in test_runner
HAOCHENYE Oct 25, 2022
9f69839
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Oct 25, 2022
e0ca06d
[Fix] Fix is_model_wrapper
HAOCHENYE Oct 25, 2022
2636c19
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Oct 25, 2022
9ecd4c7
destroy group after ut
HAOCHENYE Oct 25, 2022
4f97a3b
register module in testcase
HAOCHENYE Oct 25, 2022
34b85d9
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Oct 25, 2022
c4b76e6
pass through unit test
HAOCHENYE Oct 25, 2022
b6f8e2b
fix as comment
HAOCHENYE Oct 26, 2022
141e9ed
remove breakpoint
HAOCHENYE Oct 31, 2022
7b885d2
remove mmengine/testing/runner_test_cast.py
HAOCHENYE Nov 3, 2022
4aa8e3c
minor refine
HAOCHENYE Nov 3, 2022
d2ddda5
Merge branch 'main' into HAOCHENYE/compatible_with_fsdp
HAOCHENYE Nov 3, 2022
499e1dd
minor refine
HAOCHENYE Nov 3, 2022
7bdea0d
minor refine
HAOCHENYE Nov 3, 2022
7c37ca0
set default data preprocessor for model
HAOCHENYE Nov 3, 2022
2ffadcd
minor refine
HAOCHENYE Nov 3, 2022
94b3be0
minor refine
HAOCHENYE Nov 4, 2022
89cdaaa
fix lint
HAOCHENYE Nov 20, 2022
34aa808
Merge branch 'HAOCHENYE/add_runner_testcase' into HAOCHENYE/compatibl…
HAOCHENYE Nov 20, 2022
2d9f8eb
Merge branch 'main' of github.com:open-mmlab/mmengine into HAOCHENYE/…
HAOCHENYE Nov 22, 2022
f43e996
Fix unit test
HAOCHENYE Nov 22, 2022
40efc16
replace with in ImgDataPreprocessor
HAOCHENYE Dec 6, 2022
4687d6b
Fix as comment
HAOCHENYE Dec 12, 2022
0020c77
add inference tutorial in advanced tutorial
HAOCHENYE Dec 12, 2022
20f1ee5
update index.rst
HAOCHENYE Dec 12, 2022
af6ae94
add tta example
HAOCHENYE Dec 12, 2022
37896f5
refine tta tutorial
HAOCHENYE Dec 12, 2022
a660141
Add english tutorial
HAOCHENYE Dec 12, 2022
181cc4d
add note for build_runner_with_tta
HAOCHENYE Dec 13, 2022
1157a9b
Merge branch 'main' into HAOCHENYE/compatible_with_fsdp
HAOCHENYE Dec 23, 2022
c21fd8b
Fix as comment
HAOCHENYE Dec 25, 2022
a60d8dc
add examples
HAOCHENYE Dec 26, 2022
d0ca8f5
remove chinese comment
HAOCHENYE Dec 26, 2022
7182fc5
Update docs/en/advanced_tutorials/test_time_augmentation.md
HAOCHENYE Dec 26, 2022
14382d2
Merge branch 'main' into HAOCHENYE/compatible_with_fsdp
HAOCHENYE Dec 27, 2022
161 changes: 161 additions & 0 deletions docs/en/advanced_tutorials/test_time_augmentation.md
@@ -0,0 +1,161 @@
# Test time augmentation

Test time augmentation (TTA) is a data augmentation strategy used during the testing phase. It applies several augmentations, such as flipping and scaling, to the same image and then merges the predictions for the augmented images to produce a more accurate prediction. To make TTA easier to use, MMEngine provides the [BaseTTAModel](mmengine.model.BaseTTAModel) class, which lets users implement different TTA strategies by extending `BaseTTAModel` according to their needs.

The core implementation of TTA is usually divided into two parts:

1. Data augmentation: This part is implemented in MMCV; see the API docs of [TestTimeAug](mmcv.transform.TestTimeAug) for more information.
2. Merging the predictions: `BaseTTAModel.test_step` parses the augmented data and runs inference on it, and subclasses of `BaseTTAModel` merge the per-augmentation predictions into a single, more accurate prediction.

## Get started

A simple example of TTA is given in [examples/test_time_augmentation.py](https://github.com/open-mmlab/mmengine/blob/main/examples/test_time_augmentation.py).

### Prepare test time augmentation pipeline

`BaseTTAModel` needs to be used with `TestTimeAug` implemented in MMCV:

```python
tta_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='TestTimeAug',
        transforms=[
            [dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)],
            [dict(type='RandomFlip', flip_ratio=0.),
             dict(type='RandomFlip', flip_ratio=1.)],
            [dict(type='PackXXXInputs', keys=['img'])],
        ])
]
```

The pipeline above first resizes the image and then produces two augmented views of it: one unflipped (`flip_ratio=0.`) and one flipped (`flip_ratio=1.`). Finally, each view is packed into the final result by `PackXXXInputs`.
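As a rough illustration of why this configuration yields two views per image, the sketch below enumerates every combination of one option per stage with `itertools.product`. It uses plain dictionaries only and does not call MMCV; the assumption that `TestTimeAug` combines the nested `transforms` lists in this way is inferred from the description above, so treat it as a mental model, not as the library's implementation.

```python
from itertools import product

# Each inner list holds the alternatives for one stage of the pipeline.
transforms = [
    [dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)],
    [dict(type='RandomFlip', flip_ratio=0.),
     dict(type='RandomFlip', flip_ratio=1.)],
    [dict(type='PackXXXInputs', keys=['img'])],
]

# One augmented view per combination: 1 resize x 2 flips x 1 pack = 2 views.
sub_pipelines = [list(combo) for combo in product(*transforms)]

for idx, sub_pipeline in enumerate(sub_pipelines):
    step_types = [step['type'] for step in sub_pipeline]
    print(idx, step_types, 'flip_ratio =', sub_pipeline[1]['flip_ratio'])
```

Adding a second entry to the resize stage would double the number of views, which is worth keeping in mind because every view costs one extra forward pass.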

### Define the merge strategy

In most cases, users only need to inherit from `BaseTTAModel` and override `BaseTTAModel.merge_preds` to merge the predictions for the augmented data.

`BaseTTAModel` runs inference on every augmented view of an image (here, the flipped and unflipped versions) and then merges the results. `merge_preds` accepts a list in which each element holds the predictions for all augmented views of a single sample in the batch. For example, if `batch_size` is 3 and each image in the batch is flipped as an augmentation, `merge_preds` receives an argument like the following:

```python
# `data_{i}_{j}` represents the result of applying the jth data augmentation to
# the ith image in the batch. So, if batch_size is 3, i can take on values of
# 0, 1, and 2. If there are 2 augmentation methods
# (such as flipping the image), then j can take on values of 0 and 1.
# For example, data_2_1 would represent the result of applying the second
# augmentation method (flipping) to the third image in the batch.

demo_results = [
    [data_0_0, data_0_1],
    [data_1_0, data_1_1],
    [data_2_0, data_2_1],
]
```

The `merge_preds` method merges the predictions in `demo_results` into one result per image in the batch. For example, to merge multiple classification results:

```python
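from typing import List

from mmengine.model import BaseTTAModel
# Assumes MMClassification 1.x, where `ClsDataSample` lives in `mmcls.structures`.
from mmcls.structures import ClsDataSample


class AverageClsScoreTTA(BaseTTAModel):
    def merge_preds(
        self,
        data_samples_list: List[List[ClsDataSample]],
    ) -> List[ClsDataSample]:
        # Average the class scores over all augmented views of each image.
        merged_data_samples = []
        for data_samples in data_samples_list:
            merged_data_sample: ClsDataSample = data_samples[0].new()
            merged_score = sum(data_sample.pred_label.score
                               for data_sample in data_samples) / len(data_samples)
            merged_data_sample.set_pred_score(merged_score)
            merged_data_samples.append(merged_data_sample)
        return merged_data_samples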
```
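The averaging done in `merge_preds` can also be tried without MMEngine or MMClassification installed. The following is a minimal sketch that replaces `ClsDataSample` with plain lists of class scores; the function name `average_scores` and the sample values are made up for illustration.

```python
from typing import List


def average_scores(
        scores_per_image: List[List[List[float]]]) -> List[List[float]]:
    """Average class scores over the augmented views of each image.

    `scores_per_image[i][j]` is the score vector predicted for the j-th
    augmented view of the i-th image, mirroring `demo_results` above.
    """
    merged = []
    for aug_scores in scores_per_image:
        num_augs = len(aug_scores)
        # zip(*aug_scores) groups the scores of one class across all views.
        merged.append(
            [sum(class_scores) / num_augs for class_scores in zip(*aug_scores)])
    return merged


# batch_size = 3, two views per image (unflipped, flipped), two classes.
demo_scores = [
    [[0.75, 0.25], [0.25, 0.75]],
    [[1.0, 0.0], [0.5, 0.5]],
    [[0.25, 0.75], [0.25, 0.75]],
]
merged_scores = average_scores(demo_scores)
print(merged_scores)
```

Averaging is only one possible strategy; a subclass could just as well take a maximum or a weighted mean, as long as it returns one merged result per image.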

The configuration file for the above example is as follows:

```python
tta_model = dict(type='AverageClsScoreTTA')
```

### Changes to test script

```python
cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
```
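To make the effect of these two lines concrete, here is a small sketch that uses plain dictionaries in place of the real `Config` object. The model and dataset type names (`ImageClassifier`, `DemoDataset`) are placeholders invented for this example; only the wrapping pattern is taken from the snippet above.

```python
# Placeholder configuration; in a real script these values come from `cfg`.
model = dict(type='ImageClassifier', backbone=dict(type='ResNet', depth=50))
tta_model = dict(type='AverageClsScoreTTA')
tta_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TestTimeAug', transforms=[[dict(type='PackXXXInputs')]]),
]
test_dataloader = dict(
    dataset=dict(type='DemoDataset',
                 pipeline=[dict(type='LoadImageFromFile')]))

# The original model config becomes the `module` argument of the TTA model.
wrapped_model = dict(**tta_model, module=model)
# The test dataset switches to the TTA pipeline.
test_dataloader['dataset']['pipeline'] = tta_pipeline

print(wrapped_model)
```

After the rewrite, building the model from `wrapped_model` yields the TTA model with the original model nested inside it, which is why no other part of the test script has to change.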

## Advanced usage

In general, users who inherit the `BaseTTAModel` class only need to implement the merge_preds method to perform result fusion. However, for more complex cases, such as fusing the results of a multi-stage detector, it may be necessary to override the test_step method. This requires an understanding of the data flow in the BaseTTAModel class and its relationship with other components.

### The relationship between BaseTTAModel and other components

The BaseTTAModel class acts as an intermediary between the DDPWrapper and Model classes. When the Runner.test() method is executed, it will first call DDPWrapper.test_step(), followed by TTAModel.test_step(), and finally model.test_step().

<div align=center><img src=https://user-images.githubusercontent.com/57566630/206969103-43ef8cb9-b649-4b38-a441-f489a41269b3.png></div>

The following diagram illustrates this sequence of method calls:

<div align=center><img src=https://user-images.githubusercontent.com/57566630/206969958-3b4d296b-9f50-4098-a6fe-756c686db86d.png></div>
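The call order described in this section can be mimicked with three toy classes. None of them is a real MMEngine class: `ToyDDPWrapper`, `ToyTTAModel`, and `ToyModel` are hypothetical stand-ins that only record when their `test_step` runs.

```python
calls = []  # records the order in which the test_step methods run


class ToyModel:
    """Stand-in for the real model (hypothetical)."""

    def test_step(self, data):
        calls.append('model.test_step')
        return [f'pred({item})' for item in data['inputs']]


class ToyTTAModel:
    """Stand-in for a TTA model that wraps the real model (hypothetical)."""

    def __init__(self, module):
        self.module = module

    def test_step(self, data):
        calls.append('TTAModel.test_step')
        # One forward pass per augmentation.
        preds_per_aug = [
            self.module.test_step(dict(inputs=aug_inputs))
            for aug_inputs in data['inputs']
        ]
        # Regroup so that each entry holds all predictions of one image.
        return [list(preds) for preds in zip(*preds_per_aug)]


class ToyDDPWrapper:
    """Stand-in for the distributed wrapper (hypothetical)."""

    def __init__(self, module):
        self.module = module

    def test_step(self, data):
        calls.append('DDPWrapper.test_step')
        return self.module.test_step(data)


wrapped = ToyDDPWrapper(ToyTTAModel(ToyModel()))
data_batch = dict(inputs=[('img1', 'img2'), ('img1_flip', 'img2_flip')])
outputs = wrapped.test_step(data_batch)
print(calls)
print(outputs)
```

The inner model is called once per augmentation, so the cost of a test step grows linearly with the number of augmented views.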

### Data flow

After data augmentation with TestTimeAug, the resulting data will have the following format:

```python
image1 = dict(
    inputs=[data_1_1, data_1_2],
    data_sample=[data_sample1_1, data_sample1_2])

image2 = dict(
    inputs=[data_2_1, data_2_2],
    data_sample=[data_sample2_1, data_sample2_2])

image3 = dict(
    inputs=[data_3_1, data_3_2],
    data_sample=[data_sample3_1, data_sample3_2])
```

where `data_{i}_{j}` is the augmented data and `data_sample{i}_{j}` is the ground truth of the augmented data, with `i` indexing the image and `j` the augmentation. The `DataLoader` then collates these samples into the following format:

```python
data_batch = dict(
    inputs=[
        (data_1_1, data_2_1, data_3_1),
        (data_1_2, data_2_2, data_3_2),
    ],
    data_samples=[
        (data_samples1_1, data_samples2_1, data_samples3_1),
        (data_samples1_2, data_samples2_2, data_samples3_2),
    ]
)
```
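The collation step can be pictured as a transpose: per-image lists of augmented views become per-augmentation tuples over the batch. The sketch below is a simplified illustration with strings as placeholder data; `toy_collate` is a hypothetical helper, not the collate function MMEngine actually uses, and it uses the key `data_samples` throughout for simplicity.

```python
def toy_collate(samples):
    """Group the j-th augmented view of every image into one tuple."""
    batch = {}
    for key in samples[0]:
        per_image = [sample[key] for sample in samples]
        # zip(*per_image) transposes [image][aug] into [aug][image].
        batch[key] = [tuple(group) for group in zip(*per_image)]
    return batch


samples = [
    dict(inputs=['data_1_1', 'data_1_2'],
         data_samples=['data_samples1_1', 'data_samples1_2']),
    dict(inputs=['data_2_1', 'data_2_2'],
         data_samples=['data_samples2_1', 'data_samples2_2']),
    dict(inputs=['data_3_1', 'data_3_2'],
         data_samples=['data_samples3_1', 'data_samples3_2']),
]
data_batch = toy_collate(samples)
print(data_batch['inputs'])
```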

To simplify model inference, `BaseTTAModel` converts the data into one batch per augmentation:

```python
data_batch_aug1 = dict(
    inputs=(data_1_1, data_2_1, data_3_1),
    data_samples=(data_samples1_1, data_samples2_1, data_samples3_1)
)

data_batch_aug2 = dict(
    inputs=(data_1_2, data_2_2, data_3_2),
    data_samples=(data_samples1_2, data_samples2_2, data_samples3_2)
)
```
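One way to picture this conversion is to slice every value of the collated batch at the same augmentation index. The helper below is a sketch under that assumption; `split_by_augmentation` is a made-up name, and strings stand in for tensors and data samples.

```python
def split_by_augmentation(data_batch):
    """Turn a dict of per-augmentation lists into a list of per-aug dicts."""
    num_augs = len(next(iter(data_batch.values())))
    return [{key: value[idx] for key, value in data_batch.items()}
            for idx in range(num_augs)]


data_batch = dict(
    inputs=[
        ('data_1_1', 'data_2_1', 'data_3_1'),
        ('data_1_2', 'data_2_2', 'data_3_2'),
    ],
    data_samples=[
        ('data_samples1_1', 'data_samples2_1', 'data_samples3_1'),
        ('data_samples1_2', 'data_samples2_2', 'data_samples3_2'),
    ],
)
data_batch_aug1, data_batch_aug2 = split_by_augmentation(data_batch)
print(data_batch_aug1)
print(data_batch_aug2)
```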

At this point, each `data_batch_aug` can be passed directly to the model for inference. After inference, `BaseTTAModel` reorganizes the predictions as follows so that they are easy to merge:

```python
preds = [
    [data_samples1_1, data_samples1_2],
    [data_samples2_1, data_samples2_2],
    [data_samples3_1, data_samples3_2],
]
```
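The regrouping of predictions is again a transpose, this time from one list per augmentation to one list per image. A minimal sketch, assuming the model returns one prediction per image for each augmented batch (`regroup_predictions` is a hypothetical helper and the strings are placeholders):

```python
def regroup_predictions(preds_per_aug):
    """Transpose [aug][image] predictions into [image][aug]."""
    return [list(image_preds) for image_preds in zip(*preds_per_aug)]


# Output of the model for each augmented batch (placeholders).
preds_aug1 = ['data_samples1_1', 'data_samples2_1', 'data_samples3_1']
preds_aug2 = ['data_samples1_2', 'data_samples2_2', 'data_samples3_2']

preds = regroup_predictions([preds_aug1, preds_aug2])
print(preds)
```

The resulting nested list has the same layout as the argument of `merge_preds`, so a custom `test_step` that keeps this layout can still reuse an existing merge strategy.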

Now that we understand the data flow in TTA, we can override the BaseTTAModel.test_step() method to implement more complex fusion strategies based on specific requirements.
1 change: 1 addition & 0 deletions docs/en/api/hooks.rst
@@ -22,3 +22,4 @@ mmengine.hooks
IterTimerHook
SyncBuffersHook
EmptyCacheHook
PrepareTTAHook
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -47,6 +47,7 @@ You can switch between Chinese and English documents in the lower-left corner of
advanced_tutorials/fileio.md
advanced_tutorials/manager_mixin.md
advanced_tutorials/cross_library.md
advanced_tutorials/test_time_augmentation.md

.. toctree::
:maxdepth: 1
157 changes: 157 additions & 0 deletions docs/zh_cn/advanced_tutorials/test_time_augmentation.md
@@ -0,0 +1,157 @@
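# Test time augmentation

Test time augmentation (TTA) is a data augmentation strategy for the testing phase. During testing, the same image is augmented in several ways, such as flipping and scaling; the prediction for each augmented image is mapped back to the original size, and the predictions are merged to obtain a more accurate result. To make TTA easier to use, MMEngine provides the [BaseTTAModel](mmengine.model.BaseTTAModel) class. Users only need to inherit from `BaseTTAModel` and implement the TTA strategy their task requires.

The core implementation of TTA usually consists of two parts:

1. Data augmentation at test time: this is mainly implemented in MMCV. See the [TestTimeAug API documentation](mmcv.transform.TestTimeAug); it is not covered further here.
2. Model inference and result merging: this is the main job of `BaseTTAModel`. `BaseTTAModel.test_step` parses the augmented data and runs inference. After inheriting from `BaseTTAModel`, users only need to implement the corresponding merge strategy.

## Get started

A simple example that supports TTA can be found in [examples/test_time_augmentation.py](https://github.com/open-mmlab/mmengine/blob/main/examples/test_time_augmentation.py).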

### Prepare the TTA data augmentation

`BaseTTAModel` needs to be used together with `TestTimeAug` implemented in MMCV. A simple example configuration:

```python
tta_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='TestTimeAug',
        transforms=[
            [dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)],
            [dict(type='RandomFlip', flip_ratio=0.),
             dict(type='RandomFlip', flip_ratio=1.)],
            [dict(type='PackXXXInputs', keys=['img'])],
        ])
]
```

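With this configuration, each image is resized at test time and then flip-augmented, producing two images.

### Define the TTA merge strategy

`BaseTTAModel` runs inference on the images before and after flipping and merges the results. The `merge_preds` method accepts a list in which each element holds the results of the repeated augmentations of one sample in the batch. For example, with batch_size=3 and flip augmentation applied to every image in the batch, `merge_preds` receives: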

```python
# data_{i}_{j} is the result of applying the j-th augmentation to the i-th image.
# For example, with batch_size=3, i takes the values 0, 1, 2;
# with 2 augmentations (flipping), j takes the values 0, 1.

demo_results = [
    [data_0_0, data_0_1],
    [data_1_0, data_1_1],
    [data_2_0, data_2_1],
]
```

`merge_preds` needs to merge `demo_results` into the inference results for the whole batch. Taking classification results as an example:

```python
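from typing import List

from mmengine.model import BaseTTAModel
# Assumes MMClassification 1.x, where `ClsDataSample` lives in `mmcls.structures`.
from mmcls.structures import ClsDataSample


class AverageClsScoreTTA(BaseTTAModel):
    def merge_preds(
        self,
        data_samples_list: List[List[ClsDataSample]],
    ) -> List[ClsDataSample]:
        # Average the class scores over all augmented views of each image.
        merged_data_samples = []
        for data_samples in data_samples_list:
            merged_data_sample: ClsDataSample = data_samples[0].new()
            merged_score = sum(data_sample.pred_label.score
                               for data_sample in data_samples) / len(data_samples)
            merged_data_sample.set_pred_score(merged_score)
            merged_data_samples.append(merged_data_sample)
        return merged_data_samples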
```

The corresponding configuration file is:

```python
tta_model = dict(type='AverageClsScoreTTA')
```

### Modify the test script

```python
cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
```

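## Advanced usage

In general, after inheriting from `BaseTTAModel`, users only need to implement the `merge_preds` method to merge the results. For more complex cases, such as merging the predictions of a multi-stage detector, it may be necessary to override the `test_step` method. This requires a closer look at the data flow of `BaseTTAModel` and its relationship with the other components.

### The relationship between BaseTTAModel and other components

`BaseTTAModel` is an intermediate layer between `DDPWrapper` and `Model`. When `Runner.test()` runs, `DDPWrapper.test_step()` is executed first, then `TTAModel.test_step()`, and finally `model.test_step()`: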

<div align=center><img src=https://user-images.githubusercontent.com/57566630/206969103-43ef8cb9-b649-4b38-a441-f489a41269b3.png></div>

The concrete call stack at runtime is shown below:

<div align=center><img src=https://user-images.githubusercontent.com/57566630/206969958-3b4d296b-9f50-4098-a6fe-756c686db86d.png></div>

### Data flow

After augmentation by `TestTimeAug`, the data has the following format:

```python
image1 = dict(
    inputs=[data_1_1, data_1_2],
    data_sample=[data_sample1_1, data_sample1_2])

image2 = dict(
    inputs=[data_2_1, data_2_2],
    data_sample=[data_sample2_1, data_sample2_2])

image3 = dict(
    inputs=[data_3_1, data_3_2],
    data_sample=[data_sample3_1, data_sample3_2])
```

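where `data_{i}_{j}` is the augmented data and `data_sample{i}_{j}` is the label information of the augmented data.
After being processed by the DataLoader, the data takes the following format: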

```python
data_batch = dict(
    inputs=[
        (data_1_1, data_2_1, data_3_1),
        (data_1_2, data_2_2, data_3_2),
    ],
    data_samples=[
        (data_samples1_1, data_samples2_1, data_samples3_1),
        (data_samples1_2, data_samples2_2, data_samples3_2),
    ]
)
```

To simplify model inference, `BaseTTAModel` converts the data into the following format before inference:

```python
data_batch_aug1 = dict(
    inputs=(data_1_1, data_2_1, data_3_1),
    data_samples=(data_samples1_1, data_samples2_1, data_samples3_1)
)

data_batch_aug2 = dict(
    inputs=(data_1_2, data_2_2, data_3_2),
    data_samples=(data_samples1_2, data_samples2_2, data_samples3_2)
)
```

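At this point, each `data_batch_aug` can be passed directly to the model for inference. After inference, `BaseTTAModel` reorganizes the predictions into the following format: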

```python
preds = [
    [data_samples1_1, data_samples1_2],
    [data_samples2_1, data_samples2_2],
    [data_samples3_1, data_samples3_2],
]
```

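This layout makes it convenient to merge the results. With an understanding of the TTA data flow, we can override `BaseTTAModel.test_step()` to implement more complex merge strategies for specific requirements.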
1 change: 1 addition & 0 deletions docs/zh_cn/api/hooks.rst
@@ -22,3 +22,4 @@ mmengine.hooks
IterTimerHook
SyncBuffersHook
EmptyCacheHook
PrepareTTAHook
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -47,6 +47,7 @@
advanced_tutorials/fileio.md
advanced_tutorials/manager_mixin.md
advanced_tutorials/cross_library.md
advanced_tutorials/test_time_augmentation.md

.. toctree::
:maxdepth: 1
3 changes: 2 additions & 1 deletion mmengine/hooks/__init__.py
@@ -10,9 +10,10 @@
 from .runtime_info_hook import RuntimeInfoHook
 from .sampler_seed_hook import DistSamplerSeedHook
 from .sync_buffer_hook import SyncBuffersHook
+from .test_time_aug_hook import PrepareTTAHook

 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
-    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
+    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'PrepareTTAHook'
 ]