Skip to content

Commit

Permalink
[Feature] Add the support of arf op for ascend device (#2789)
Browse files Browse the repository at this point in the history
  • Loading branch information
dflhw committed May 11, 2023
1 parent 883d339 commit e197eff
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/en/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ At the same time, MMCV released [2.x](https://github.com/open-mmlab/mmcv/tree/2.
- `mmcv.fileio` module, removed in PR [#2179](https://github.com/open-mmlab/mmcv/pull/2179). FileIO module from mmengine will be used wherever required.
- `mmcv.runner`, `mmcv.parallel`, `mmcv. engine` and `mmcv.device`, removed in PR [#2216](https://github.com/open-mmlab/mmcv/pull/2216).
- All classes in `mmcv.utils` (eg `Config` and `Registry`) and many functions, removed in PR [#2217](https://github.com/open-mmlab/mmcv/pull/2217). Only a few functions related to mmcv are reserved.
- `mmcv.onnex`, `mmcv.tensorrt` modules and related functions, removed in PR [#2225](https://github.com/open-mmlab/mmcv/pull/2225).
- `mmcv.onnx`, `mmcv.tensorrt` modules and related functions, removed in PR [#2225](https://github.com/open-mmlab/mmcv/pull/2225).

(2) It added the [`mmcv.transforms`](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/transforms) data transformation module.

Expand Down
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ We implement common ops used in detection, segmentation, etc.

| Device | CPU | CUDA | MLU | MPS | Ascend |
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | | |
| ActiveRotatedFilter ||| | | |
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ OpenMMLab 团队于 2022 年 9 月 1 日在世界人工智能大会发布了新
- `mmcv.fileio` 模块,删除于 PR [#2179](https://github.com/open-mmlab/mmcv/pull/2179)。在需要使用 FileIO 的地方使用 mmengine 中的 FileIO 模块
- `mmcv.runner``mmcv.parallel``mmcv.engine``mmcv.device`,删除于 PR [#2216](https://github.com/open-mmlab/mmcv/pull/2216)
- `mmcv.utils` 的所有类(例如 `Config``Registry`)和大部分函数,删除于 PR [#2217](https://github.com/open-mmlab/mmcv/pull/2217),只保留少数和 mmcv 相关的函数
- `mmcv.onnex``mmcv.tensorrt` 模块以及相关的函数,删除于 PR [#2225](https://github.com/open-mmlab/mmcv/pull/2225)
- `mmcv.onnx``mmcv.tensorrt` 模块以及相关的函数,删除于 PR [#2225](https://github.com/open-mmlab/mmcv/pull/2225)

(2)新增了 [`mmcv.transforms`](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/transforms) 数据变换模块

Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ MMCV 提供了检测、分割等任务中常用的算子

| Device | CPU | CUDA | MLU | MPS | Ascend |
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | | |
| ActiveRotatedFilter ||| | | |
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
Expand Down
36 changes: 36 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/active_rotated_filter_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void active_rotated_filter_forward_impl(const Tensor input,
const Tensor indices, Tensor output);

void active_rotated_filter_backward_impl(const Tensor grad_out,
const Tensor indices, Tensor grad_in);

void active_rotated_filter_forward_npu(const Tensor input, const Tensor indices,
Tensor output) {
OpCommand cmd;
cmd.Name("ActiveRotatedFilter")
.Input(input)
.Input(indices)
.Output(output)
.Run();
}

void active_rotated_filter_backward_npu(const Tensor grad_out,
const Tensor indices, Tensor grad_in) {
OpCommand cmd;
cmd.Name("ActiveRotatedFilterGrad")
.Input(grad_out)
.Input(indices)
.Output(grad_in)
.Run();
}

REGISTER_NPU_IMPL(active_rotated_filter_forward_impl,
active_rotated_filter_forward_npu);

REGISTER_NPU_IMPL(active_rotated_filter_backward_impl,
active_rotated_filter_backward_npu);
7 changes: 6 additions & 1 deletion tests/test_ops/test_active_rotated_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from mmcv.ops import active_rotated_filter
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE

np_feature = np.array([[[[[-1.4934e-01, 1.1341e+00, -1.6241e-01],
[-1.0986e+00, -1.1463e+00, -1.3176e+00],
Expand Down Expand Up @@ -245,7 +246,11 @@
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')),
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_active_rotated_filter(device):
feature = torch.tensor(
Expand Down

0 comments on commit e197eff

Please sign in to comment.