From 9d2e0ecb1e6c7fabbc6ac743b79f49baf78c2ba3 Mon Sep 17 00:00:00 2001 From: Thomas Young <35565423+HexToString@users.noreply.github.com> Date: Mon, 1 Aug 2022 15:11:09 +0800 Subject: [PATCH] [operator migration] Migrate unstack_op and nms_op (#44424) * update unstack_op * update unstack_op * update unstack_op * fix unstack test * update unstack * update with remote * fix unstack_test.py * temp_save_change_nms_op * add nms test * update nms fix * update unstack_op * temp save change * finish fix nms_op * pass nms test * fix CI * fix ops test * save change * fix code style * fix code style * fix ci and codestyle * fix ci Co-authored-by: ShiningZhang --- .../fluid/operators/detection/CMakeLists.txt | 2 +- paddle/fluid/operators/detection/nms_op.cc | 81 ++---------- paddle/fluid/operators/detection/nms_op.cu | 118 ------------------ paddle/fluid/operators/unstack_op.cc | 55 ++------ paddle/phi/api/yaml/legacy_api.yaml | 25 +++- paddle/phi/api/yaml/legacy_backward.yaml | 10 ++ paddle/phi/infermeta/backward.cc | 43 ++++++- paddle/phi/infermeta/backward.h | 5 +- paddle/phi/infermeta/unary.cc | 13 ++ paddle/phi/infermeta/unary.h | 2 + paddle/phi/kernels/cpu/nms_kernel.cc | 73 +++++++++++ paddle/phi/kernels/gpu/nms_kernel.cu | 102 +++++++++++++++ .../nms_op.h => phi/kernels/nms_kernel.h} | 42 ++++--- python/paddle/fluid/layers/nn.py | 1 + .../fluid/tests/unittests/test_nms_op.py | 6 +- .../fluid/tests/unittests/test_unstack_op.py | 9 +- python/paddle/tensor/manipulation.py | 7 ++ python/paddle/vision/ops.py | 3 + 18 files changed, 334 insertions(+), 263 deletions(-) delete mode 100644 paddle/fluid/operators/detection/nms_op.cu mode change 100644 => 100755 paddle/phi/api/yaml/legacy_api.yaml mode change 100644 => 100755 paddle/phi/api/yaml/legacy_backward.yaml mode change 100644 => 100755 paddle/phi/infermeta/backward.h create mode 100644 paddle/phi/kernels/cpu/nms_kernel.cc create mode 100644 paddle/phi/kernels/gpu/nms_kernel.cu rename paddle/{fluid/operators/detection/nms_op.h => phi/kernels/nms_kernel.h} (56%) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_nms_op.py mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_unstack_op.py mode change 100644 => 100755 python/paddle/vision/ops.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 000a1a4a520fd..c2dcdf58f6e2e 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -81,7 +81,7 @@ detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) -detection_library(nms_op SRCS nms_op.cc nms_op.cu) +detection_library(nms_op SRCS nms_op.cc) if(WITH_GPU OR WITH_ROCM) set(TMPDEPS memory) diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index 3c5feaa656a32..03680538f778e 100644 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detection/nms_op.h" - #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { @@ -65,23 +69,6 @@ class NMSOpMaker : public framework::OpProtoAndCheckerMaker { class NMSOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS"); - OP_INOUT_CHECK( - ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs", "NMS"); - - auto boxes_dim = ctx->GetInputDim("Boxes"); - PADDLE_ENFORCE_EQ(boxes_dim.size(), - 2, - platform::errors::InvalidArgument( - "The Input Boxes must be 2-dimention " - "whose shape must be [N, 4] " - "N is the number of boxes " - "in last dimension in format [x1, x2, y1, y2]. ")); - auto num_boxes = boxes_dim[0]; - - ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes}); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -92,64 +79,20 @@ class NMSOp : public framework::OperatorWithKernel { }; template -static void NMS(const T* boxes_data, - int64_t* output_data, - float threshold, - int64_t num_boxes) { - auto num_masks = CeilDivide(num_boxes, 64); - std::vector masks(num_masks, 0); - - for (int64_t i = 0; i < num_boxes; ++i) { - if (masks[i / 64] & 1ULL << (i % 64)) continue; - T box_1[4]; - for (int k = 0; k < 4; ++k) { - box_1[k] = boxes_data[i * 4 + k]; - } - for (int64_t j = i + 1; j < num_boxes; ++j) { - if (masks[j / 64] & 1ULL << (j % 64)) continue; - T box_2[4]; - for (int k = 0; k < 4; ++k) { - box_2[k] = boxes_data[j * 4 + k]; - } - bool is_overlap = CalculateIoU(box_1, box_2, threshold); - if (is_overlap) { - masks[j / 64] |= 1ULL << (j % 64); - } - } - } - - int64_t output_data_idx = 0; - for (int64_t i = 0; i < num_boxes; ++i) { - if (masks[i / 64] & 1ULL << (i % 64)) continue; - output_data[output_data_idx++] = i; - } - - for (; output_data_idx < num_boxes; ++output_data_idx) { - output_data[output_data_idx] = 0; - } -} - -template -class NMSKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* boxes = context.Input("Boxes"); - Tensor* output = context.Output("KeepBoxesIdxs"); - int64_t* output_data = output->mutable_data(context.GetPlace()); - auto threshold = context.template Attr("iou_threshold"); - NMS(boxes->data(), output_data, threshold, boxes->dims()[0]); - } -}; +class NMSKernel : public framework::OpKernel {}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(nms, + NMSInferMetaFunctor, + PD_INFER_META(phi::NMSInferMeta)); REGISTER_OPERATOR( nms, ops::NMSOp, ops::NMSOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel, ops::NMSKernel); + paddle::framework::EmptyGradOpMaker, + NMSInferMetaFunctor); diff --git a/paddle/fluid/operators/detection/nms_op.cu b/paddle/fluid/operators/detection/nms_op.cu deleted file mode 100644 index 935d13cfd4a47..0000000000000 --- a/paddle/fluid/operators/detection/nms_op.cu +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/operators/detection/nms_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -static const int64_t threadsPerBlock = sizeof(int64_t) * 8; - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -static __global__ void NMS(const T* boxes_data, - float threshold, - int64_t num_boxes, - uint64_t* masks) { - auto raw_start = blockIdx.y; - auto col_start = blockIdx.x; - if (raw_start > col_start) return; - - const int raw_last_storage = - min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock); - const int col_last_storage = - min(num_boxes - col_start * threadsPerBlock, threadsPerBlock); - - if (threadIdx.x < raw_last_storage) { - uint64_t mask = 0; - auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x; - const T* current_box = boxes_data + current_box_idx * 4; - for (int i = 0; i < col_last_storage; ++i) { - const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4; - if (CalculateIoU(current_box, target_box, threshold)) { - mask |= 1ULL << i; - } - } - const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); - masks[current_box_idx * blocks_per_line + col_start] = mask; - } -} - -template -class NMSCudaKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* boxes = context.Input("Boxes"); - Tensor* output = context.Output("KeepBoxesIdxs"); - auto* output_data = output->mutable_data(context.GetPlace()); - - auto threshold = context.template Attr("iou_threshold"); - const int64_t num_boxes = boxes->dims()[0]; - const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); - - dim3 block(threadsPerBlock); - dim3 grid(blocks_per_line, blocks_per_line); - - auto mask_data = - memory::Alloc(context.cuda_device_context(), - num_boxes * blocks_per_line * sizeof(uint64_t)); - uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); - NMS<<>>( - boxes->data(), threshold, num_boxes, mask_dev); - - std::vector mask_host(num_boxes * blocks_per_line); - memory::Copy(platform::CPUPlace(), - mask_host.data(), - context.GetPlace(), - mask_dev, - num_boxes * blocks_per_line * sizeof(uint64_t), - context.cuda_device_context().stream()); - - std::vector remv(blocks_per_line); - - std::vector keep_boxes_idxs(num_boxes); - int64_t* output_host = keep_boxes_idxs.data(); - - int64_t last_box_num = 0; - for (int64_t i = 0; i < num_boxes; ++i) { - auto remv_element_id = i / threadsPerBlock; - auto remv_bit_id = i % threadsPerBlock; - if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { - output_host[last_box_num++] = i; - uint64_t* current_mask = mask_host.data() + i * blocks_per_line; - for (auto j = remv_element_id; j < blocks_per_line; ++j) { - remv[j] |= current_mask[j]; - } - } - } - memory::Copy(context.GetPlace(), - output_data, - platform::CPUPlace(), - output_host, - sizeof(int64_t) * num_boxes, - context.cuda_device_context().stream()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(nms, - ops::NMSCudaKernel, - ops::NMSCudaKernel); diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc index 76fe2ac77d9d8..d1cfbd2b90260 100644 --- a/paddle/fluid/operators/unstack_op.cc +++ b/paddle/fluid/operators/unstack_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/unary.h" namespace paddle { @@ -63,51 +64,6 @@ class UnStackGradOpMaker : public framework::SingleGradOpMaker { class UnStackGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), - 0, - platform::errors::InvalidArgument( - "The Inputs(Y@Grad) of unstack operator are empty.")); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X", - "UnStackGrad"); - auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y")); - for (size_t i = 1; i < input_dims.size(); ++i) { - PADDLE_ENFORCE_EQ( - input_dims[i], - input_dims[0], - platform::errors::InvalidArgument( - "The dimensions of all Inputs(Y@Grad) must be the same," - "but received Inputs(Y@Grad)'s %d-th dimension is %d, " - "Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.", - i, - input_dims[i], - i - 1, - input_dims[0])); - } - - int axis = ctx->Attrs().Get("axis"); - int rank = input_dims[0].size(); - PADDLE_ENFORCE_GE(axis, - -(rank + 1), - platform::errors::InvalidArgument( - "The attribute axis is out of range, it must be " - "inside [-(rank+1), rank+1), where rank = %d", - rank)); - PADDLE_ENFORCE_LT(axis, - rank + 1, - platform::errors::InvalidArgument( - "The attribute axis is out of range, it must be " - "inside [-(rank+1), rank+1), where rank = %d", - rank)); - if (axis < 0) axis += (rank + 1); - - auto vec = phi::vectorize(input_dims[0]); - vec.insert(vec.begin() + axis, input_dims.size()); - ctx->SetOutputDim(framework::GradVarName("X"), phi::make_ddim(vec)); - } }; } // namespace operators @@ -119,12 +75,15 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(unstack, UnStackInferMetaFunctor, PD_INFER_META(phi::UnStackInferMeta)); - +DECLARE_INFER_SHAPE_FUNCTOR(unstack_grad, + UnStackGradInferMetaFunctor, + PD_INFER_META(phi::UnStackGradInferMeta)); REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, ops::UnStackGradOpMaker, ops::UnStackGradOpMaker, UnStackInferMetaFunctor); - -REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp); +REGISTER_OPERATOR(unstack_grad, + ops::UnStackGradOp, + UnStackGradInferMetaFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml old mode 100644 new mode 100755 index bb232f6212c39..7d1342f807cd7 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -889,7 +889,7 @@ func : FrameInferMeta kernel : func : frame - backward : frame_grad + backward : frame_grad - api : frobenius_norm args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) @@ -1700,6 +1700,15 @@ optional : weight backward : nll_loss_grad +- api : nms + args : (Tensor x, float threshold) + output : Tensor(out) + infer_meta : + func : NMSInferMeta + kernel : + func : nms + data_type : x + - api : norm args : (Tensor x, int axis, float epsilon, bool is_test) output : Tensor(out), Tensor(norm) @@ -2258,7 +2267,7 @@ kernel : func : spectralnorm data_type : weight - backward : spectral_norm_grad + backward : spectral_norm_grad - api : split args : (Tensor x, IntArray num_or_sections, Scalar(int) axis) @@ -2566,6 +2575,16 @@ intermediate : xshape backward : unsqueeze_grad +# unstack +- api : unstack + args : (Tensor x, int axis, int num) + output : Tensor[]{num} + infer_meta : + func : UnStackInferMeta + kernel : + func : unstack + backward : unstack_grad + # viterbi_decode - api : viterbi_decode args : (Tensor input, Tensor transition, Tensor length, bool include_bos_eos_tag) @@ -2629,7 +2648,7 @@ kernel: func: broadcast_tensors backward: broadcast_tensors_grad - + # dirichlet - api: dirichlet args: (Tensor alpha) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml old mode 100644 new mode 100755 index b44417050783e..6bdd73dd00452 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2499,6 +2499,16 @@ inplace : (out_grad -> x_grad) backward : unsqueeze_double_grad +- backward_api : unstack_grad + forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) + args : (Tensor[] out_grad, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnStackGradInferMeta + param : [out_grad, axis] + kernel : + func : unstack_grad + - backward_api : warpctc_grad forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad) args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 2c2484da35d0d..bea572ca741a6 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/infermeta/backward.h" - #include "paddle/phi/common/type_traits.h" #include "paddle/phi/kernels/funcs/axis_utils.h" @@ -787,4 +786,46 @@ void StackGradInferMeta(const MetaTensor& out_grad, } } +void UnStackGradInferMeta(const std::vector& out_grad, + int axis, + MetaTensor* x_grad) { + std::vector input_dims(out_grad.size()); + for (size_t i = 0; i < out_grad.size(); ++i) { + input_dims[i] = out_grad[i]->dims(); + } + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ( + input_dims[i], + input_dims[0], + phi::errors::InvalidArgument( + "The dimensions of all Inputs(Y@Grad) must be the same," + "but received Inputs(Y@Grad)'s %d-th dimension is %d, " + "Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.", + i, + input_dims[i], + i - 1, + input_dims[0])); + } + + int rank = input_dims[0].size(); + PADDLE_ENFORCE_GE(axis, + -(rank + 1), + phi::errors::InvalidArgument( + "The attribute axis is out of range, it must be " + "inside [-(rank+1), rank+1), where rank = %d", + rank)); + PADDLE_ENFORCE_LT(axis, + rank + 1, + phi::errors::InvalidArgument( + "The attribute axis is out of range, it must be " + "inside [-(rank+1), rank+1), where rank = %d", + rank)); + if (axis < 0) axis += (rank + 1); + + auto vec = phi::vectorize(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + x_grad->set_dims(phi::make_ddim(vec)); + x_grad->set_dtype(out_grad[0]->dtype()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h old mode 100644 new mode 100755 index add2f8945dd9f..696e5466158aa --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include - #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" @@ -325,4 +324,8 @@ void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad); +void UnStackGradInferMeta(const std::vector& out_grad, + int axis, + MetaTensor* x_grad); + } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9bb8c156d9016..4ba496f4499cd 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1707,6 +1707,19 @@ void NanmedianInferMeta(const MetaTensor& x, out->set_dims(make_ddim(out_dim)); } +void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) { + auto boxes_dim = x.dims(); + PADDLE_ENFORCE_EQ(boxes_dim.size(), + 2, + phi::errors::InvalidArgument( + "The Input Boxes must be 2-dimention " + "whose shape must be [N, 4] " + "N is the number of boxes " + "in last dimension in format [x1, x2, y1, y2]. ")); + auto num_boxes = boxes_dim[0]; + out->set_dims(phi::make_ddim({num_boxes})); +} + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index f17ab48f0fae6..21c052580a829 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -228,6 +228,8 @@ void NanmedianInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* median_index); +void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out); + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc new file mode 100644 index 0000000000000..7e656b14f1fc5 --- /dev/null +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nms_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/diagonal.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +static void NMS(const T* boxes_data, + int64_t* output_data, + float threshold, + int64_t num_boxes) { + auto num_masks = CeilDivide(num_boxes, 64); + std::vector masks(num_masks, 0); + + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + T box_1[4]; + for (int k = 0; k < 4; ++k) { + box_1[k] = boxes_data[i * 4 + k]; + } + for (int64_t j = i + 1; j < num_boxes; ++j) { + if (masks[j / 64] & 1ULL << (j % 64)) continue; + T box_2[4]; + for (int k = 0; k < 4; ++k) { + box_2[k] = boxes_data[j * 4 + k]; + } + bool is_overlap = CalculateIoU(box_1, box_2, threshold); + if (is_overlap) { + masks[j / 64] |= 1ULL << (j % 64); + } + } + } + + int64_t output_data_idx = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + output_data[output_data_idx++] = i; + } + + for (; output_data_idx < num_boxes; ++output_data_idx) { + output_data[output_data_idx] = 0; + } +} + +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output) { + auto output_data = dev_ctx.template Alloc(output); + NMS(boxes.data(), output_data, threshold, boxes.dims()[0]); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu new file mode 100644 index 0000000000000..5a52cb33662fc --- /dev/null +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nms_kernel.h" + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +static const int64_t threadsPerBlock = sizeof(int64_t) * 8; + +namespace phi { + +template +static __global__ void NMS(const T* boxes_data, + float threshold, + int64_t num_boxes, + uint64_t* masks) { + auto raw_start = blockIdx.y; + auto col_start = blockIdx.x; + if (raw_start > col_start) return; + + const int raw_last_storage = + min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock); + const int col_last_storage = + min(num_boxes - col_start * threadsPerBlock, threadsPerBlock); + + if (threadIdx.x < raw_last_storage) { + uint64_t mask = 0; + auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x; + const T* current_box = boxes_data + current_box_idx * 4; + for (int i = 0; i < col_last_storage; ++i) { + const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4; + if (CalculateIoU(current_box, target_box, threshold)) { + mask |= 1ULL << i; + } + } + const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); + masks[current_box_idx * blocks_per_line + col_start] = mask; + } +} + +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output) { + auto* output_data = dev_ctx.template Alloc(output); + const int64_t num_boxes = boxes.dims()[0]; + const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); + dim3 block(threadsPerBlock); + dim3 grid(blocks_per_line, blocks_per_line); + auto mask_data = paddle::memory::Alloc( + dev_ctx, num_boxes * blocks_per_line * sizeof(uint64_t)); + uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); + NMS<<>>( + boxes.data(), threshold, num_boxes, mask_dev); + std::vector mask_host(num_boxes * blocks_per_line); + paddle::memory::Copy(phi::CPUPlace(), + mask_host.data(), + dev_ctx.GetPlace(), + mask_dev, + num_boxes * blocks_per_line * sizeof(uint64_t), + dev_ctx.stream()); + std::vector remv(blocks_per_line); + std::vector keep_boxes_idxs(num_boxes); + int64_t* output_host = keep_boxes_idxs.data(); + int64_t last_box_num = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + auto remv_element_id = i / threadsPerBlock; + auto remv_bit_id = i % threadsPerBlock; + if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { + output_host[last_box_num++] = i; + uint64_t* current_mask = mask_host.data() + i * blocks_per_line; + for (auto j = remv_element_id; j < blocks_per_line; ++j) { + remv[j] |= current_mask[j]; + } + } + } + paddle::memory::Copy(dev_ctx.GetPlace(), + output_data, + phi::CPUPlace(), + output_host, + sizeof(int64_t) * num_boxes, + dev_ctx.stream()); +} +} // namespace phi +PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} diff --git a/paddle/fluid/operators/detection/nms_op.h b/paddle/phi/kernels/nms_kernel.h similarity index 56% rename from paddle/fluid/operators/detection/nms_op.h rename to paddle/phi/kernels/nms_kernel.h index f5cd1c9203784..e8511f4c4a49f 100644 --- a/paddle/fluid/operators/detection/nms_op.h +++ b/paddle/phi/kernels/nms_kernel.h @@ -1,24 +1,23 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" -namespace paddle { -namespace operators { +namespace phi { HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) { return (n + m - 1) / m; @@ -48,5 +47,10 @@ HOSTDEVICE inline bool CalculateIoU(const T* const box_1, return inter_area / union_area > threshold; } -} // namespace operators -} // namespace paddle +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output); + +} // namespace phi diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e68b70107c109..03d21035bd510 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10685,6 +10685,7 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ + if _non_static_mode(): if num == None: num = x.shape[axis] diff --git a/python/paddle/fluid/tests/unittests/test_nms_op.py b/python/paddle/fluid/tests/unittests/test_nms_op.py old mode 100644 new mode 100755 index f3c253d45c0de..a81a46e1140e8 --- a/python/paddle/fluid/tests/unittests/test_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_nms_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np from op_test import OpTest +import paddle def iou(box_a, box_b): @@ -71,22 +72,25 @@ class TestNMSOp(OpTest): def setUp(self): self.op_type = 'nms' + self.python_api = paddle.vision.ops.nms self.dtype = np.float64 self.init_dtype_type() boxes = np.random.rand(32, 4).astype(self.dtype) boxes[:, 2] = boxes[:, 0] + boxes[:, 2] boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + paddle.disable_static() self.inputs = {'Boxes': boxes} self.attrs = {'iou_threshold': 0.5} out_py = nms(boxes, self.attrs['iou_threshold']) self.outputs = {'KeepBoxesIdxs': out_py} + paddle.enable_static() def init_dtype_type(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py old mode 100644 new mode 100755 index 730a74dc54c5a..bb28bdeba79d3 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -15,6 +15,7 @@ from op_test import OpTest import numpy as np import unittest +import paddle class TestUnStackOpBase(OpTest): @@ -37,6 +38,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'unstack' + self.python_api = paddle.unstack self.x = np.random.random(size=self.input_dim).astype(self.dtype) outs = np.split(self.x, self.input_dim[self.axis], self.axis) @@ -44,18 +46,21 @@ def setUp(self): del new_shape[self.axis] y_names = self.get_y_names() tmp = [] + tmp_names = [] for i in range(self.input_dim[self.axis]): tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + tmp_names.append(y_names[i]) + self.python_out_sig = tmp_names self.inputs = {'X': self.x} self.outputs = {'Y': tmp} self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], self.get_y_names()) + self.check_grad(['X'], self.get_y_names(), check_eager=True) class TestStackOp3(TestUnStackOpBase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index c170a6fb04e88..4053cc43b30ee 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -454,6 +454,13 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ + if in_dygraph_mode(): + if num == None: + num = x.shape[axis] + if num == 0: + return [] + return _C_ops.final_state_unstack(x, axis, num) + if _non_static_mode(): if num == None: num = x.shape[axis] diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py old mode 100644 new mode 100755 index aef90bb140d2b..6ab9e7567cc4b --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1579,6 +1579,9 @@ def nms(boxes, """ def _nms(boxes, iou_threshold): + if in_dygraph_mode(): + return _C_ops.final_state_nms(boxes, iou_threshold) + if _non_static_mode(): return _C_ops.nms(boxes, 'iou_threshold', iou_threshold)