diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f4dd84c0b4..2c300d103d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,6 +29,7 @@ repos:
     rev: v2.1.0
     hooks:
       - id: codespell
+        exclude: ^README.md
   - repo: https://github.com/executablebooks/mdformat
     rev: 0.7.9
     hooks:
diff --git a/README.md b/README.md
index 1a6541a689..57247eb8b0 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,39 @@ c. Install full version with custom operators for onnxruntime
 
 If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/get_started/build.html).
 
+## NPU Build and Installation
+
+If you want to run mmcv on your NPU device, you can build and install mmcv-npu with the following steps.
+
+a. Install the **ascend-toolkit**
+
+```bash
+Ascend-cann-toolkit_{version}_linux-{arch}.run
+```
+
+- You can download the ascend-toolkit package from https://www.hiascend.com/software/cann/community. Choose the **"Ascend-cann-toolkit\_{xxx.xxx}.run"** that fits your development environment.
+- To install **CANN** quickly, refer to the documentation at https://www.hiascend.com/document/detail/zh/canncommercial/51RC2/envdeployment/instg/instg_000052.html
+
+b. Install **torch_npu**
+
+- As the dispatch mechanism is based on torch, you have to install torch_npu before running mmcv.ops on an NPU device.
+- You can download the torch_npu code from https://gitee.com/ascend/pytorch and install it following the steps in its README.
+- torch_npu depends on the ascend-toolkit, so install the ascend-toolkit first and set up the Ascend environment:
+- ```bash
+  source /usr/local/Ascend/ascend-toolkit/set_env.sh
+  ```
+
+c. Build and install mmcv-npu
+
+- ```bash
+  MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py build_ext
+  MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py develop
+  ```
+- or
+- ```bash
+  MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install
+  ```
+
 ## FAQ
 
 If you face some installation issues, CUDA related issues or RuntimeErrors,
diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
index bd82824689..030fa02fb6 100644
--- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -1,10 +1,20 @@
 #include "pytorch_npu_helper.hpp"
 
 using namespace NPU_NAME_SPACE;
+using namespace std;
 
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
-  at::Tensor target_y = at::reshape(target, input.sizes());
+  int64_t n_class = input.size(1);
+  at::Tensor target_y = at::ones_like(input);
+  if(n_class == 1) {
+    target_y = at::reshape(target, input.sizes());
+    target_y = at::mul(target_y, -1.0);
+    target_y = at::add(target_y, 1.0);
+  }
+  else {
+    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+  }
   target_y =
       at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
   int64_t weight_size = weight.size(0);
@@ -14,6 +24,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                                                  input.sizes());
   }
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
       .Input(input)
       .Input(target_y)
@@ -21,7 +32,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
 
@@ -31,7 +42,16 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma, float alpha) {
-  at::Tensor target_y = at::reshape(target, input.sizes());
+  int64_t n_class = input.size(1);
+  at::Tensor target_y = at::ones_like(input);
+  if(n_class == 1) {
+    target_y = at::reshape(target, input.sizes());
+  }
+  else {
+    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::mul(target_y, -1.0);
+    target_y = at::add(target_y, 1.0);
+  }
   target_y =
       at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
   at::Tensor grad_up = at::ones_like(input);
@@ -42,6 +62,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                                                  input.sizes());
   }
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
       .Input(input)
       .Input(target_y)
@@ -50,7 +71,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(grad_input)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
 
@@ -71,16 +92,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                  input.sizes());
   }
+  at::Tensor op_output = at::ones_like(input);
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
       .Input(input)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
+  int64_t n_batch = input.size(0);
+  c10::SmallVector<int64_t, 2> offsets = {0, 0};
+  c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
+  at::IntArrayRef offset = at::IntArrayRef(offsets);
+  at::IntArrayRef size = at::IntArrayRef(sizes);
+  at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset,
+                                                    size, output);
 }
 
 void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -102,8 +132,8 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                  input.sizes());
   }
-
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
       .Input(input)
       .Input(target_y)
@@ -112,7 +142,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(grad_input)
      .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
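
A minimal smoke test for the NPU focal-loss path touched by this patch (not part of the diff above). It is a hedged sketch, assuming mmcv was built with `MMCV_WITH_OPS=1 FORCE_NPU=1` as described in the README section, that `torch_npu` is installed and registers the `npu` device on import, and that tensors can be moved with `Tensor.npu()`; `sigmoid_focal_loss` is the existing `mmcv.ops` wrapper that dispatches to the SigmoidFocalLoss / SigmoidFocalLossGrad kernels adapted here.

```python
# Hedged sketch: exercise mmcv.ops.sigmoid_focal_loss on an NPU device.
# Assumptions (not in this PR): torch_npu is installed and registers the
# "npu" backend on import, and mmcv was built with MMCV_WITH_OPS=1 FORCE_NPU=1.
import torch
import torch_npu  # noqa: F401  (import registers the NPU device with torch)
from mmcv.ops import sigmoid_focal_loss

pred = torch.randn(4, 10).npu()            # (N, num_classes) logits on the NPU
pred.requires_grad_()
target = torch.randint(0, 10, (4,)).npu()  # (N,) ground-truth class indices

# Forward exercises the SigmoidFocalLoss kernel; backward exercises
# SigmoidFocalLossGrad through the dispatch mechanism mentioned in the README.
loss = sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25)
loss.backward()
print(loss.item(), pred.grad.shape)
```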