diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 8b30fa15df..4f5c32dbec 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -5,19 +5,34 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { + bool is_half = XYZ1.scalar_type() == at::kHalf; at::Tensor xyz1 = at::ones_like(XYZ1); at::Tensor xyz2 = at::ones_like(XYZ2); + at::Tensor distf1 = at::ones_like(dist1); + at::Tensor distf2 = at::ones_like(dist2); xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); + if (is_half) { + xyz1 = xyz1.to(at::kFloat); + xyz2 = xyz2.to(at::kFloat); + distf1 = dist1.to(at::kFloat); + distf2 = dist2.to(at::kFloat); + } OpCommand cmd; cmd.Name("ChamferDistance") .Input(xyz1) .Input(xyz2) - .Output(dist1) - .Output(dist2) + .Output(distf1) + .Output(distf2) .Output(idx1) .Output(idx2) .Run(); + if (is_half) { + distf1 = distf1.to(at::kHalf); + distf2 = distf2.to(at::kHalf); + } + dist1.copy_(distf1); + dist2.copy_(distf2); } void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 5030fed0e7..3f3bc5a047 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -4,6 +4,21 @@ using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor output_y = output; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + output_y = output.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { @@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, target_y = at::add(target_y, 1.0); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } target_y = target_y.to(at::kInt); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(output_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + output_y = output_y.to(at::kHalf); + } + output.copy_(output_y); } void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { target_y = at::reshape(target, input.sizes()); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); } target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, @@ -74,19 +108,30 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } - at::Tensor op_output = at::ones_like(input); + at::Tensor op_output = at::ones_like(input_y); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) .Output(op_output) @@ -94,6 +139,9 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + op_output = op_output.to(at::kHalf); + } int64_t n_batch = input.size(0); c10::SmallVector offsets = {0, 0}; c10::SmallVector sizes = {n_batch, 1}; @@ -124,27 +172,44 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void softmax_focal_loss_backward_impl(Tensor input, Tensor target, diff --git a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp index 747380fb09..279f14008b 100644 --- a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp @@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints, void gather_points_backward_npu(int b, int c, int n, int npoints, const Tensor grad_out, const Tensor idx, Tensor grad_points) { + at::Tensor grad_out_cast = grad_out; + at::Tensor grad_points_cast = grad_points; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_out_cast = grad_out.to(at::kFloat); + grad_points_cast = grad_points.to(at::kFloat); + } at::Tensor indices = idx; if (idx.scalar_type() != at::ScalarType::Int) { indices = idx.to(at::kInt); @@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, for (uint64_t i = 0; i < shape.size(); i++) { pad_size.emplace_back(shape[i]); } - at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous(); + at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous(); at::Tensor grad_points_view = trans_grad_points.view( {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1], trans_grad_points.sizes()[2]}); - at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous(); + at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous(); trans_grad_out = trans_grad_out.view( {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1], trans_grad_out.sizes()[2]}); @@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, at::Tensor grad_points_result = grad_points_view.view(trans_grad_points.sizes()); grad_points_result = grad_points_result.transpose(1, 2); - grad_points.copy_(grad_points_result); + at::Tensor grad_points_result_cast = grad_points_result; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_points_result_cast = grad_points_result.to(at::kHalf); + } + grad_points.copy_(grad_points_result_cast); } void gather_points_forward_impl(int b, int c, int n, int npoints, diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp index f25f9cf623..c4a1bcbd25 100644 --- a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp @@ -8,11 +8,11 @@ using namespace std; void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz, const Tensor new_xyz, Tensor idx, Tensor dist2) { // transpose known from [B, N, 3] to [B, 3, N] - at::Tensor source = xyz.transpose(1, 2).contiguous(); + at::Tensor source = xyz.transpose(2, 1).contiguous(); at::Tensor target = new_xyz.contiguous(); bool is_from_knn = true; - EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2); + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2); } void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz, diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index c7a11e8c6d..b7015439b9 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -50,23 +50,29 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, int64_t pooled_height_64 = pooled_height; int64_t pooled_width_64 = pooled_width; int64_t pooled_channel = 1; + at::Tensor argmax_trans = argmax.transpose(1, 2).transpose(2, 3); + at::Tensor grad_output_trans = grad_output.transpose(1, 2).transpose(2, 3); at::Tensor roi_actual_num = at::empty_like(rois, rois.options().dtype(at::kInt)); - at::Tensor x = at::ones_like(grad_input); + at::Tensor x = at::ones_like(grad_input).transpose(1, 2).transpose(2, 3); + at::Tensor y = at::zeros_like(x); OpCommand cmd; cmd.Name("RoiPoolingGradWithArgMax") - .Input(grad_output) + .Input(grad_output_trans) .Input(x) .Input(rois) .Input(roi_actual_num) - .Input(argmax) - .Output(grad_input) + .Input(argmax_trans) + .Output(y) .Attr("pooled_h", pooled_height_64) .Attr("pooled_w", pooled_width_64) .Attr("spatial_scale_h", spatial_scale) .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); + at::Tensor result = y.transpose(2, 3).transpose(1, 2); + at::Tensor res = result.contiguous(); + grad_input.copy_(res); } void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, diff --git a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp index cd8c3ad8c9..92627df6e3 100644 --- a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp @@ -8,9 +8,10 @@ void stack_ball_query_forward_npu(float max_radius, int nsample, const Tensor new_xyz_batch_cnt, const Tensor xyz, const Tensor xyz_batch_cnt, Tensor idx) { - at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous(); + at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous().to(at::kFloat); + at::Tensor new_xyz_fp32 = new_xyz.to(at::kFloat); double max_radius_double = double(max_radius); - EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz, xyz_batch_cnt, + EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz_fp32, xyz_batch_cnt, new_xyz_batch_cnt, max_radius_double, nsample, idx); } diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index f908755478..42d346f7d2 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - auto point_c_trans = points.transpose(1, 2); - + auto point_c_trans = points.transpose(1, 2).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + auto out_cast = out.to(at::kFloat); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weight) - .Output(out) + .Input(weight_cast) + .Output(out_cast) .Run(); - auto output = out.view({b, n, c}).transpose(1, 2); + if (originDtype == at::kHalf) { + out_cast = out_cast.to(at::kHalf); + } + auto output = out_cast.view({b, n, c}).transpose(1, 2); auto res = output.contiguous(); out.copy_(res); } @@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - auto grad_x = at::unsqueeze(grad_out, 3); - auto grad_y = at::unsqueeze(grad_points, 3); - - EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y); + auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat); + auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m, + grad_y); - auto output = at::squeeze(grad_y, 3); + auto grad_y_cast = grad_y; + if (originDtype == at::kHalf) { + grad_y_cast = grad_y.to(at::kHalf); + } + auto output = at::squeeze(grad_y_cast, 3); auto res = output.contiguous(); grad_points.copy_(res); } diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp index 9766816f6c..6740a731bc 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp @@ -7,21 +7,11 @@ using namespace std; void three_nn_forward_npu(int b, int n, int m, const Tensor unknown, const Tensor known, Tensor dist2, Tensor idx) { - // transpose known [B, N, 3] -> [B, 3, N] - at::Tensor source = known.transpose(1, 2).contiguous(); + at::Tensor source = known.contiguous(); at::Tensor target = unknown.contiguous(); - auto originDtype = source.scalar_type(); - if (originDtype == at::kHalf) { - source = source.to(at::kFloat); - target = target.to(at::kFloat); - } bool is_from_knn = false; - uint32_t nsample = 3; - EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2); - if (originDtype == at::kHalf) { - dist2 = dist2.to(at::kHalf); - } + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2); } void three_nn_forward_impl(int b, int n, int m, const Tensor unknown, diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index c6cbba6779..856e3e36b6 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -57,7 +57,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py index e23617fb3a..fe17d2db7b 100644 --- a/mmcv/ops/fused_bias_leakyrelu.py +++ b/mmcv/ops/fused_bias_leakyrelu.py @@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor, torch.Tensor: Feature map after non-linear activation. """ - if not input.is_cuda: + if not input.is_cuda and input.device.type != 'npu': return bias_leakyrelu_ref(input, bias, negative_slope, scale) return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py index 47ced04c6a..1e2a68d1d2 100644 --- a/mmcv/ops/knn.py +++ b/mmcv/ops/knn.py @@ -62,6 +62,24 @@ def forward(ctx, B, npoint, _ = center_xyz.shape N = xyz.shape[1] + if xyz.device.type == 'npu': + dist = center_xyz.new_zeros((B, npoint, N)).float() + ext_module.knn_forward( + xyz, + center_xyz, + torch.Tensor([]).npu(), + dist, + b=B, + n=N, + m=npoint, + nsample=k) + dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True) + zeros_idx = torch.zeros( + xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu() + idx.where(dist2 >= 1e10, zeros_idx) + idx = idx.transpose(2, 1).contiguous() # [B, k, npoint] + return idx.int() + idx = center_xyz.new_zeros((B, npoint, k)).int() dist2 = center_xyz.new_zeros((B, npoint, k)).float() diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 0c169009a5..b6e8c6d40a 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -55,7 +55,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): conv2d_bias = bias if len(bias) > 0 else None sort_index_fp, sort_index_bp = \ ModulatedDeformConv2dFunction._calculate_sort_index( - kernel_w, kernel_h, ctx.deform_groups) + kernel_h, kernel_w, ctx.deform_groups) select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) import torch_npu @@ -64,7 +64,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): weight, offset_all, conv2d_bias, - kernel_size=[kernel_w, kernel_h], + kernel_size=[kernel_h, kernel_w], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1] @@ -87,7 +87,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index fb08ba07c6..0c6adfabc7 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -405,7 +405,7 @@ def nms_rotated(dets: Tensor, flip_mat[-1] = -1 dets_cw = dets * flip_mat else: - dets_cw = dets + dets_cw = dets.clone() multi_label = labels is not None if labels is None: input_labels = scores.new_empty(0, dtype=torch.int) @@ -415,6 +415,8 @@ def nms_rotated(dets: Tensor, order = scores.new_empty(0, dtype=torch.long) if dets.device.type == 'npu': coefficient = 57.29578 # 180 / PI + dets_cw = dets_cw.float() + scores = scores.float() for i in range(dets.size()[0]): dets_cw[i][4] *= coefficient # radians to angle keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 4915e6b573..23c35da4eb 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -47,8 +47,11 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + elif points.device.type == 'npu': + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), diff --git a/mmcv/ops/points_in_polygons.py b/mmcv/ops/points_in_polygons.py index e54b5a896d..8d3bc8dd48 100644 --- a/mmcv/ops/points_in_polygons.py +++ b/mmcv/ops/points_in_polygons.py @@ -19,6 +19,8 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor: polygons (torch.Tensor): It has shape (M, 8), indicating (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of ground truth polygons. + constraints: The number of significant digits for the input-arguments + are between -10 and 10 when running on Ascend device. Returns: torch.Tensor: Return the result with the shape of (B, M), diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py index 5d881bfe63..6902343a03 100644 --- a/mmcv/ops/scatter_points.py +++ b/mmcv/ops/scatter_points.py @@ -36,10 +36,29 @@ def forward(ctx: Any, reduced from input features that share the same voxel coordinates. The second is voxel coordinates with shape [M, ndim]. """ + ctx.device = feats.device.type + if ctx.device == 'npu': + import ads_c + voxel_idx = ads_c.point_to_voxel(coors, None, None) + unique_res = ads_c.unique_voxel(voxel_idx) + num_voxels, uniqued_voxel_idx, prefix_sum, \ + argsort_coor = unique_res + voxel_coors = ads_c.voxel_to_point(uniqued_voxel_idx, None, None) + voxel_feats, \ + compare_mask = ads_c.npu_dynamic_scatter(feats, coors, + prefix_sum, + argsort_coor, + num_voxels, + reduce_type) + ctx.reduce_type = reduce_type + ctx.feats_shape = feats.shape + ctx.save_for_backward(prefix_sum, argsort_coor, compare_mask) + ctx.mark_non_differentiable(voxel_coors) + return voxel_feats, voxel_coors + results = ext_module.dynamic_point_to_voxel_forward( feats, coors, reduce_type) - (voxel_feats, voxel_coors, point2voxel_map, - voxel_points_count) = results + voxel_feats, voxel_coors, point2voxel_map, voxel_points_count = results ctx.reduce_type = reduce_type ctx.save_for_backward(feats, voxel_feats, point2voxel_map, voxel_points_count) @@ -50,6 +69,19 @@ def forward(ctx: Any, def backward(ctx: Any, grad_voxel_feats: torch.Tensor, grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple: + if ctx.device == 'npu': + import ads_c + prefix_sum, argsort_coor, compare_mask = ctx.saved_tensors + grad_point_feats = torch.zeros( + ctx.feats_shape, + dtype=grad_voxel_feats.dtype, + device=grad_voxel_feats.device) + ads_c.npu_dynamic_scatter_grad(grad_point_feats, + grad_voxel_feats.contiguous(), + prefix_sum, argsort_coor, + compare_mask, ctx.reduce_type) + return grad_point_feats, None, None + (feats, voxel_feats, point2voxel_map, voxel_points_count) = ctx.saved_tensors grad_feats = torch.zeros_like(feats) diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py index d41b9789cf..52d504609a 100644 --- a/mmcv/ops/three_nn.py +++ b/mmcv/ops/three_nn.py @@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor, B, N, _ = target.size() m = source.size(1) + if source.device.type == 'npu': + # strict to fp32 + source = source.transpose(2, 1).contiguous() + dtype_ = source.dtype + if dtype_ == torch.float16: + target = target.float() + source = source.float() + dist = target.new_empty(B, N, m) + ext_module.three_nn_forward( + target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m) + dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True) + dist2 = torch.sqrt(dist2) + if dtype_ == torch.float16: + dist2 = dist2.half() + return dist2, idx.int() dist2 = target.new_empty(B, N, 3) idx = target.new_empty(B, N, 3, dtype=torch.int32)