[Enhance] Enable full precision training on Ascend NPU. #1109

Merged 1 commit on May 6, 2023.

6 changes: 4 additions & 2 deletions mmengine/device/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available)
+                    is_mlu_available, is_mps_available, is_npu_available,
+                    is_npu_support_full_precision)

 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available', 'is_npu_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available',
+    'is_npu_support_full_precision'
 ]

8 changes: 8 additions & 0 deletions mmengine/device/utils.py
@@ -6,6 +6,7 @@

 try:
     import torch_npu  # noqa: F401
+    import torch_npu.npu.utils as npu_utils

     # Enable operator support for dynamic shape and
     # binary operator support on the NPU.
@@ -62,6 +63,13 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()


+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    version_of_support_full_precision = 220
+    return IS_NPU_AVAILABLE and npu_utils.get_soc_version(
+    ) >= version_of_support_full_precision
+
+
 DEVICE = 'cpu'
 if is_npu_available():
     DEVICE = 'npu'
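
The new helper gates on the Ascend SoC version: generations reported as 220 or newer support full precision training, while older ones (such as the Ascend 910) do not. A minimal usage sketch, not part of the diff, assuming torch_npu is installed and mmengine is built from this branch:

from mmengine.device import is_npu_available, is_npu_support_full_precision

if is_npu_available() and not is_npu_support_full_precision():
    # Older SoC generations only support mixed precision training,
    # so callers should fall back to an AMP-style optimizer wrapper.
    print('NPU without full-precision support detected; using AMP.')
else:
    print('Full precision training is available on this device.')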

8 changes: 4 additions & 4 deletions mmengine/optim/optimizer/builder.py
@@ -7,7 +7,7 @@
 import torch.nn as nn

 from mmengine.config import Config, ConfigDict
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_npu_support_full_precision
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper

@@ -128,9 +128,9 @@ def build_optim_wrapper(model: nn.Module,
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

     # Since the current generation of NPU(Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision by default
-    # on the NPU to make the training normal
-    if is_npu_available():
+    # mixed precision training, here we turn on mixed precision
+    # to make the training normal
+    if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
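
For context, a hedged sketch (not taken from the repository) of how the change affects optimizer-wrapper construction; the model and optimizer settings below are made-up values:

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(8, 4)
cfg = dict(type='OptimWrapper',
           optimizer=dict(type='SGD', lr=0.01, momentum=0.9))

# Before this PR, any available NPU rewrote cfg['type'] to 'AmpOptimWrapper'.
# With this PR, the override applies only when the NPU lacks full-precision
# support; newer SoCs keep the plain OptimWrapper requested above.
optim_wrapper = build_optim_wrapper(model, cfg)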