[Feature] Add AvoidOOM to avoid OOM (open-mmlab#7434)
* [Feature] Add AvoidOOM to avoid OOM

* support multiple outputs

* add docs in faq

* add docs in faq

* fix logic

* minor fix

* minor fix

* minor fix

* minor fix

* add the tutorials of using avoidoom as a decorator

* minor fix

* add convert tensor type test unit

* minor fix

* minor fix
BIGWangYuDong authored and ZwwWayne committed Jul 19, 2022
1 parent 97dc8e2 commit 9761de5
Showing 5 changed files with 356 additions and 1 deletion.
21 changes: 21 additions & 0 deletions docs/en/faq.md
@@ -150,9 +150,30 @@ We list some common troubles faced by many users and their corresponding solutions
- "GPU out of memory"

1. There are scenarios with a large number of ground truth boxes, which may cause OOM during target assignment. You can set `gpu_assign_thr=N` in the assigner config so that the assigner calculates box overlaps on the CPU when there are more than N GT boxes (see the config sketch after this list).

2. Set `with_cp=True` in the backbone. This uses the sublinear strategy in PyTorch to reduce GPU memory cost in the backbone.

3. Try mixed precision training following the examples in `config/fp16`. The `loss_scale` might need further tuning for different models.

4. Try `AvoidCUDAOOM` to avoid GPU out of memory. It first retries after calling `torch.cuda.empty_cache()`. If that still fails, it retries after converting the inputs to FP16. If that still fails, it tries to copy the inputs from the GPUs to the CPUs to continue computing. Use `AvoidCUDAOOM` in your code to keep it running when GPU memory runs out:
```python
from mmdet.utils import AvoidCUDAOOM

output = AvoidCUDAOOM.retry_if_cuda_oom(some_function)(input1, input2)
```
You can also try `AvoidCUDAOOM` as a decorator to make the code continue to run when GPU memory runs out:
```python
from mmdet.utils import AvoidCUDAOOM

@AvoidCUDAOOM.retry_if_cuda_oom
def function(*args, **kwargs):
...
return xxx
```
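The options in items 1–3 are ordinary config keys. A minimal sketch of where they might sit in a config file is shown below; apart from `gpu_assign_thr`, `with_cp`, and `fp16`, the surrounding keys (model structure, assigner type, IoU thresholds) are illustrative and depend on the detector you actually use:
```python
# Hypothetical config excerpt; only gpu_assign_thr, with_cp and fp16 are the
# options referenced above, the remaining keys depend on the chosen detector.
model = dict(
    backbone=dict(
        type='ResNet',
        depth=50,
        with_cp=True),  # item 2: checkpointing (sublinear strategy)
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                # item 1: assign on CPU when there are more than 500 GT boxes
                gpu_assign_thr=500))))

# item 3: mixed precision training, as in the configs under config/fp16
fp16 = dict(loss_scale=512.)
```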
- "RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one"
1. This error indicates that your module has parameters that were not used in producing the loss. This can be caused by running different branches of your code in DDP mode; a minimal PyTorch sketch of the usual workaround is shown below.
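As a minimal sketch in plain PyTorch (the toy module below is invented for illustration; `find_unused_parameters` is a standard `DistributedDataParallel` argument), the usual workaround is to let DDP tolerate parameters that did not contribute to the loss:
```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# A toy module whose second branch may be skipped in a given iteration,
# leaving the parameters of branch_b unused when computing the loss.
class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(8, 8)
        self.branch_b = nn.Linear(8, 8)

    def forward(self, x, use_b=False):
        return self.branch_b(x) if use_b else self.branch_a(x)

# Assumes torch.distributed.init_process_group() has already been called
# (e.g. by torchrun). find_unused_parameters=True makes DDP skip the
# reduction check for parameters that were not used in this iteration.
model = DDP(TwoBranch().cuda(), device_ids=[0], find_unused_parameters=True)
```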
21 changes: 21 additions & 0 deletions docs/zh_cn/faq.md
@@ -82,9 +82,30 @@
- "GPU out of memory"

1. In scenarios with a large number of ground truth boxes or a large number of anchors, OOM may occur in the assigner. You can set `gpu_assign_thr=N` in the assigner config so that the assigner computes IoU on the CPU when there are more than N GT boxes.

2. Set `with_cp=True` in the backbone. This uses PyTorch's `sublinear strategy` to reduce the GPU memory used by the backbone.

3. Try mixed precision training following the examples in `config/fp16`. `loss_scale` may need to be tuned for different models.

4. You can also try `AvoidCUDAOOM` to avoid this problem. It first retries after calling `torch.cuda.empty_cache()`. If that fails, it retries after converting the inputs to FP16. If that still fails, it moves the inputs from the GPUs to the CPUs and continues the computation. Two usage examples are given below:
```python
from mmdet.utils import AvoidCUDAOOM

output = AvoidCUDAOOM.retry_if_cuda_oom(some_function)(input1, input2)
```

You can also use `AvoidCUDAOOM` as a decorator to keep the code running when it hits OOM:

```python
from mmdet.utils import AvoidCUDAOOM
@AvoidCUDAOOM.retry_if_cuda_oom
def function(*args, **kwargs):
...
return xxx
```

- "RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one"

1. This error occurs when some parameters are not used in the forward pass; it typically happens when running different branches in DDP mode.
3 changes: 2 additions & 1 deletion mmdet/utils/__init__.py
@@ -2,6 +2,7 @@
from .collect_env import collect_env
from .compat_config import compat_cfg
from .logger import get_caller_name, get_root_logger, log_img_scale
from .memory import AvoidCUDAOOM, AvoidOOM
from .misc import find_latest_checkpoint, update_data_root
from .replace_cfg_vals import replace_cfg_vals
from .setup_env import setup_multi_processes
@@ -12,5 +13,5 @@
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'update_data_root', 'setup_multi_processes', 'get_caller_name',
'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp',
'get_device', 'replace_cfg_vals'
'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM'
]
214 changes: 214 additions & 0 deletions mmdet/utils/memory.py
@@ -0,0 +1,214 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import abc
from contextlib import contextmanager
from functools import wraps

import torch

from mmdet.utils import get_root_logger


def cast_tensor_type(inputs, src_type=None, dst_type=None):
"""Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.
Args:
inputs: Inputs that to be casted.
src_type (torch.dtype | torch.device): Source type.
src_type (torch.dtype | torch.device): Destination type.
Returns:
The same type with inputs, but all contained Tensors have been cast.
"""
assert dst_type is not None
if isinstance(inputs, torch.Tensor):
if isinstance(dst_type, torch.device):
# convert Tensor to dst_device
if hasattr(inputs, 'to') and \
hasattr(inputs, 'device') and \
(inputs.device == src_type or src_type is None):
return inputs.to(dst_type)
else:
return inputs
else:
# convert Tensor to dst_dtype
if hasattr(inputs, 'to') and \
hasattr(inputs, 'dtype') and \
(inputs.dtype == src_type or src_type is None):
return inputs.to(dst_type)
else:
return inputs
    # we need to ensure that the type of the inputs to be cast is the same
    # as the argument `src_type`.
elif isinstance(inputs, abc.Mapping):
return type(inputs)({
k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
for k, v in inputs.items()
})
    elif isinstance(inputs, str):
        # strings are Iterable, but they must not be unpacked and rebuilt
        return inputs
    elif isinstance(inputs, abc.Iterable):
return type(inputs)(
cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
for item in inputs)
# TODO: Currently not supported
# elif isinstance(inputs, InstanceData):
# for key, value in inputs.items():
# inputs[key] = cast_tensor_type(
# value, src_type=src_type, dst_type=dst_type)
# return inputs
else:
return inputs


@contextmanager
def _ignore_torch_cuda_oom():
"""A context which ignores CUDA OOM exception from pytorch.
Code is modified from
<https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py> # noqa: E501
"""
try:
yield
except RuntimeError as e:
# NOTE: the string may change?
if 'CUDA out of memory. ' in str(e):
pass
else:
raise


class AvoidOOM:
"""Try to convert inputs to FP16 and CPU if got a PyTorch's CUDA Out of
Memory error. It will do the following steps:
1. First retry after calling `torch.cuda.empty_cache()`.
2. If that still fails, it will then retry by converting inputs
to FP16.
3. If that still fails trying to convert inputs to CPUs.
In this case, it expects the function to dispatch to
CPU implementation.
Args:
        to_cpu (bool): Whether to fall back to computing on CPU if an OOM
            error is encountered. This will slow down the code
            significantly. Defaults to True.
        test (bool): Skip the `_ignore_torch_cuda_oom` fast path so that the
            FP16 and CPU fallbacks can be exercised with lightweight data.
            Only used in unit tests. Defaults to False.
Examples:
>>> from mmdet.utils.memory import AvoidOOM
>>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
>>> some_torch_function)(input1, input2)
>>> # To use as a decorator
>>> # from mmdet.utils import AvoidCUDAOOM
>>> @AvoidCUDAOOM.retry_if_cuda_oom
>>> def function(*args, **kwargs):
>>> return None
    Note:
1. The output may be on CPU even if inputs are on GPU. Processing
on CPU will slow down the code significantly.
        2. When converting inputs to CPU, it only looks at each argument and
           checks whether it has `.device` and `.to` for conversion; tensors
           nested in dicts, lists and tuples are handled by ``cast_tensor_type``.
3. Since the function might be called more than once, it has to be
stateless.
"""

def __init__(self, to_cpu=True, test=False):
self.logger = get_root_logger()
self.to_cpu = to_cpu
self.test = test

def retry_if_cuda_oom(self, func):
"""Makes a function retry itself after encountering pytorch's CUDA OOM
error.
        The implementation logic is adapted from
https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py
Args:
func: a stateless callable that takes tensor-like objects
as arguments.
Returns:
func: a callable which retries `func` if OOM is encountered.
""" # noqa: W605

@wraps(func)
def wrapped(*args, **kwargs):

# raw function
if not self.test:
with _ignore_torch_cuda_oom():
return func(*args, **kwargs)

# Clear cache and retry
torch.cuda.empty_cache()
with _ignore_torch_cuda_oom():
return func(*args, **kwargs)

# get the type and device of first tensor
dtype, device = None, None
values = args + tuple(kwargs.values())
for value in values:
if isinstance(value, torch.Tensor):
dtype = value.dtype
device = value.device
break
if dtype is None or device is None:
raise ValueError('There is no tensor in the inputs, '
'cannot get dtype and device.')

# Convert to FP16
fp16_args = cast_tensor_type(args, dst_type=torch.half)
fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            self.logger.info(f'Attempting to convert inputs of {str(func)} '
                             f'to FP16 due to CUDA OOM')

            # the output dtype will be cast back to match the dtype of the
            # first input tensor
with _ignore_torch_cuda_oom():
output = func(*fp16_args, **fp16_kwargs)
output = cast_tensor_type(
output, src_type=torch.half, dst_type=dtype)
if not self.test:
return output
            self.logger.info('Using FP16 still meets CUDA OOM')

# Try on CPU. This will slow down the code significantly,
# therefore print a notice.
if self.to_cpu:
self.logger.info(f'Attempting to copy inputs of {str(func)} '
f'to CPU due to CUDA OOM')
cpu_device = torch.empty(0).device
cpu_args = cast_tensor_type(args, dst_type=cpu_device)
cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # run on CPU and convert the outputs back to GPU
                with _ignore_torch_cuda_oom():
                    self.logger.info(f'Converting outputs back to GPU '
                                     f'(device={device})')
output = func(*cpu_args, **cpu_kwargs)
output = cast_tensor_type(
output, src_type=cpu_device, dst_type=device)
return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output needs to interact with GPU '
                              'data in subsequent operations')
self.logger.info('Cannot convert output to GPU due to '
'CUDA OOM, the output is on CPU now.')

return func(*cpu_args, **cpu_kwargs)
else:
# may still get CUDA OOM error
return func(*args, **kwargs)

return wrapped


# A default instance of AvoidOOM that can also be used as a decorator via
# `AvoidCUDAOOM.retry_if_cuda_oom`.
AvoidCUDAOOM = AvoidOOM()
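As a usage sketch of the utilities above (`pairwise_scores` and the tensor shapes are invented for illustration), the exported `AvoidCUDAOOM` instance wraps any stateless, tensor-in/tensor-out callable, and `cast_tensor_type` also recurses into containers of tensors:
```python
import torch

from mmdet.utils import AvoidCUDAOOM
from mmdet.utils.memory import cast_tensor_type


def pairwise_scores(feats_a, feats_b):
    # (N, C) x (M, C) -> (N, M); can OOM for large N and M
    return feats_a @ feats_b.t()


device = 'cuda' if torch.cuda.is_available() else 'cpu'
feats_a = torch.randn(4096, 256, device=device)
feats_b = torch.randn(8192, 256, device=device)

# Retries with empty_cache(), then FP16, then CPU if CUDA OOM is raised.
scores = AvoidCUDAOOM.retry_if_cuda_oom(pairwise_scores)(feats_a, feats_b)

# cast_tensor_type recurses into dicts, lists and tuples of tensors.
half_inputs = cast_tensor_type({'a': feats_a, 'b': [feats_b]}, dst_type=torch.half)
```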
98 changes: 98 additions & 0 deletions tests/test_utils/test_memory.py
@@ -0,0 +1,98 @@
import numpy as np
import pytest
import torch

from mmdet.utils import AvoidOOM
from mmdet.utils.memory import cast_tensor_type


def test_avoidoom():
tensor = torch.from_numpy(np.random.random((20, 20)))
if torch.cuda.is_available():
tensor = tensor.cuda()
# get default result
default_result = torch.mm(tensor, tensor.transpose(1, 0))

        # when no OOM error occurs
AvoidCudaOOM = AvoidOOM()
result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
tensor.transpose(
1, 0))
assert default_result.device == result.device and \
default_result.dtype == result.dtype and \
torch.equal(default_result, result)

# calculate with fp16 and convert back to source type
AvoidCudaOOM = AvoidOOM(test=True)
result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
tensor.transpose(
1, 0))
assert default_result.device == result.device and \
default_result.dtype == result.dtype and \
torch.allclose(default_result, result, 1e-3)

# calculate on cpu and convert back to source device
AvoidCudaOOM = AvoidOOM(test=True)
result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
tensor.transpose(
1, 0))
assert result.dtype == default_result.dtype and \
result.device == default_result.device and \
torch.allclose(default_result, result)

        # do not calculate on CPU; the outputs will be the same as the inputs
AvoidCudaOOM = AvoidOOM(test=True, to_cpu=False)
result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
tensor.transpose(
1, 0))
assert result.dtype == default_result.dtype and \
result.device == default_result.device

else:
default_result = torch.mm(tensor, tensor.transpose(1, 0))
AvoidCudaOOM = AvoidOOM()
result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
tensor.transpose(
1, 0))
assert default_result.device == result.device and \
default_result.dtype == result.dtype and \
torch.equal(default_result, result)


def test_cast_tensor_type():
inputs = torch.rand(10)
if torch.cuda.is_available():
inputs = inputs.cuda()
with pytest.raises(AssertionError):
cast_tensor_type(inputs, src_type=None, dst_type=None)
# input is a float
out = cast_tensor_type(10., dst_type=torch.half)
assert out == 10. and isinstance(out, float)
# convert Tensor to fp16 and re-convert to fp32
fp16_out = cast_tensor_type(inputs, dst_type=torch.half)
assert fp16_out.dtype == torch.half
fp32_out = cast_tensor_type(fp16_out, dst_type=torch.float32)
assert fp32_out.dtype == torch.float32

# input is a list
list_input = [inputs, inputs]
list_outs = cast_tensor_type(list_input, dst_type=torch.half)
assert len(list_outs) == len(list_input) and \
isinstance(list_outs, list)
for out in list_outs:
assert out.dtype == torch.half
# input is a dict
dict_input = {'test1': inputs, 'test2': inputs}
dict_outs = cast_tensor_type(dict_input, dst_type=torch.half)
assert len(dict_outs) == len(dict_input) and \
isinstance(dict_outs, dict)

# convert the input tensor to CPU and re-convert to GPU
if torch.cuda.is_available():
cpu_device = torch.empty(0).device
gpu_device = inputs.device
cpu_out = cast_tensor_type(inputs, dst_type=cpu_device)
assert cpu_out.device == cpu_device

gpu_out = cast_tensor_type(inputs, dst_type=gpu_device)
assert gpu_out.device == gpu_device
