[Enhancement] Enable timeout in dist training (#877)
* Adding missing pre-commit requirement to tests.txt

* Added support for setting a timeout for distributed learning

* Adding documentation about how to change the runtime timeout into the distributed manual.

* Fixed type in documentation to correctly specify an integer

* Removing type-cast after checking the correct type already before

* Update mmengine/dist/utils.py

Adding an explicit `is not None` to the check

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Removing explicit type check and replacing it with more pythonic way of assuming it is the right type and handling the exception if the type doesn't match.

* Removing pre-commit from test requirements again

* Simplified the code according to suggestions from PR

* Update distributed.md

---------

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
3 people committed Feb 3, 2023
1 parent d1d4609 commit 1aa14b4
Showing 3 changed files with 26 additions and 1 deletion.
11 changes: 11 additions & 0 deletions docs/en/advanced_tutorials/distributed.md
@@ -23,6 +23,17 @@ We will detail on these APIs in the following chapters.

- [init_dist](mmengine.dist.init_dist): Launch function of distributed training. Currently it supports 3 launchers: pytorch, slurm, and MPI. It also sets up the given communication backend, which defaults to NCCL.

If you need to change the runtime timeout (default: 30 minutes) for long-running distributed operations, you can specify a different timeout in the `env_cfg` configuration passed to [Runner](mmengine.runner.Runner) like this:

```python
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout=10800),  # sets the timeout to 3h (10800 seconds)
)
runner = Runner(xxx, env_cfg=env_cfg)
```
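For reference, the `timeout` value is a plain number of seconds. A minimal standard-library sketch of deriving it from a human-readable duration via `datetime.timedelta` (which is also the type `torch.distributed.init_process_group` ultimately expects):

```python
import datetime

# Desired timeout expressed as a human-readable duration.
three_hours = datetime.timedelta(hours=3)

# dist_cfg expects the timeout as a whole number of seconds.
timeout_seconds = int(three_hours.total_seconds())
print(timeout_seconds)  # 10800
```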

## Query and control

The query and control functions are all argument free.
14 changes: 14 additions & 0 deletions mmengine/dist/utils.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import functools
import os
import subprocess
@@ -50,6 +51,19 @@ def init_dist(launcher, backend='nccl', **kwargs) -> None:
            'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted
        # to a timedelta object before forwarding the call to
        # the respective backend, because they expect a timedelta object.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
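The added conversion can be exercised in isolation. Below is a hypothetical standalone sketch (the helper name `to_timedelta` is not part of mmengine) that mirrors the try/except logic from the diff: a numeric timeout is wrapped in a `datetime.timedelta`, while a non-numeric value raises a `TypeError` that points at the expected config format:

```python
import datetime


def to_timedelta(timeout):
    """Convert a timeout in seconds to a datetime.timedelta.

    Hypothetical helper mirroring the conversion in init_dist: rather
    than type-checking up front, it assumes the value is numeric and
    re-raises a more helpful TypeError when it is not.
    """
    try:
        return datetime.timedelta(seconds=timeout)
    except TypeError as exception:
        raise TypeError(
            'Timeout for distributed training must be provided in '
            f'seconds, but received the type {type(timeout)}. Please '
            "specify the timeout like this: dist_cfg=dict(backend='nccl', "
            'timeout=1800)') from exception


print(to_timedelta(1800))  # 0:30:00
```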
2 changes: 1 addition & 1 deletion mmengine/runner/runner.py
@@ -625,7 +625,7 @@ def setup_env(self, env_cfg: Dict) -> None:
             mp_start_method='fork',
             opencv_num_threads=0
         ),
-        dist_cfg=dict(backend='nccl'),
+        dist_cfg=dict(backend='nccl', timeout=1800),
         resource_limit=4096
     )
