Commit 9d2e0ec

[operator migration] Migrate unstack_op and nms_op (#44424)
* update unstack_op

* update unstack_op

* update unstack_op

* fix unstack test

* update unstack

* update with remote

* fix unstack_test.py

* temp_save_change_nms_op

* add nms test

* update nms fix

* update unstack_op

* temp save change

* finish fix nms_op

* pass nms test

* fix CI

* fix ops test

* save change

* fix code style

* fix code style

* fix ci and codestyle

* fix ci

Co-authored-by: ShiningZhang <zhang_liang1991@126.com>
HexToString and ShiningZhang committed Aug 1, 2022
1 parent 74e46a9 commit 9d2e0ec
Showing 18 changed files with 334 additions and 263 deletions.
paddle/fluid/operators/detection/CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@
 detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc
                   sigmoid_focal_loss_op.cu)
 detection_library(retinanet_detection_output_op SRCS
                   retinanet_detection_output_op.cc)
-detection_library(nms_op SRCS nms_op.cc nms_op.cu)
+detection_library(nms_op SRCS nms_op.cc)

 if(WITH_GPU OR WITH_ROCM)
   set(TMPDEPS memory)
paddle/fluid/operators/detection/nms_op.cc (81 changes: 12 additions & 69 deletions)
@@ -12,10 +12,14 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/detection/nms_op.h"
+#include <vector>

+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"

 namespace paddle {
 namespace operators {

@@ -65,23 +69,6 @@ class NMSOpMaker : public framework::OpProtoAndCheckerMaker {
 class NMSOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs", "NMS");
-
-    auto boxes_dim = ctx->GetInputDim("Boxes");
-    PADDLE_ENFORCE_EQ(boxes_dim.size(),
-                      2,
-                      platform::errors::InvalidArgument(
-                          "The Input Boxes must be 2-dimensional "
-                          "with shape [N, 4], where N is the number of "
-                          "boxes and the last dimension is in the "
-                          "format [x1, x2, y1, y2]."));
-    auto num_boxes = boxes_dim[0];
-
-    ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes});
-  }

  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -92,64 +79,20 @@ class NMSOp : public framework::OperatorWithKernel {
 };

 template <typename T>
-static void NMS(const T* boxes_data,
-                int64_t* output_data,
-                float threshold,
-                int64_t num_boxes) {
-  auto num_masks = CeilDivide(num_boxes, 64);
-  std::vector<uint64_t> masks(num_masks, 0);
-
-  for (int64_t i = 0; i < num_boxes; ++i) {
-    if (masks[i / 64] & 1ULL << (i % 64)) continue;
-    T box_1[4];
-    for (int k = 0; k < 4; ++k) {
-      box_1[k] = boxes_data[i * 4 + k];
-    }
-    for (int64_t j = i + 1; j < num_boxes; ++j) {
-      if (masks[j / 64] & 1ULL << (j % 64)) continue;
-      T box_2[4];
-      for (int k = 0; k < 4; ++k) {
-        box_2[k] = boxes_data[j * 4 + k];
-      }
-      bool is_overlap = CalculateIoU<T>(box_1, box_2, threshold);
-      if (is_overlap) {
-        masks[j / 64] |= 1ULL << (j % 64);
-      }
-    }
-  }
-
-  int64_t output_data_idx = 0;
-  for (int64_t i = 0; i < num_boxes; ++i) {
-    if (masks[i / 64] & 1ULL << (i % 64)) continue;
-    output_data[output_data_idx++] = i;
-  }
-
-  for (; output_data_idx < num_boxes; ++output_data_idx) {
-    output_data[output_data_idx] = 0;
-  }
-}
-
-template <typename T>
-class NMSKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* boxes = context.Input<Tensor>("Boxes");
-    Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
-    int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
-    auto threshold = context.template Attr<float>("iou_threshold");
-    NMS<T>(boxes->data<T>(), output_data, threshold, boxes->dims()[0]);
-  }
-};
+class NMSKernel : public framework::OpKernel<T> {};

 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(nms,
+                            NMSInferMetaFunctor,
+                            PD_INFER_META(phi::NMSInferMeta));
+
 REGISTER_OPERATOR(
     nms,
     ops::NMSOp,
     ops::NMSOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    NMSInferMetaFunctor);
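
Note: the shape checks deleted from NMSOp::InferShape above now live in phi and are wired back in through NMSInferMetaFunctor. A minimal sketch of what phi::NMSInferMeta plausibly contains, reconstructed from the removed fluid checks (the actual code in paddle/phi/infermeta/unary.h may differ):

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/meta_tensor.h"

namespace phi {

// Hypothetical reconstruction, mirroring the checks removed from
// NMSOp::InferShape; not the verbatim upstream code.
void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) {
  // threshold does not affect shapes; only the box tensor does.
  auto boxes_dim = x.dims();
  PADDLE_ENFORCE_EQ(boxes_dim.size(),
                    2,
                    errors::InvalidArgument(
                        "The Input Boxes must be 2-dimensional "
                        "with shape [N, 4]."));
  // One output index slot per input box; the kernel zero-pads the
  // slots of suppressed boxes.
  out->set_dims(make_ddim({boxes_dim[0]}));
  out->set_dtype(DataType::INT64);
}

}  // namespace phi
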
paddle/fluid/operators/detection/nms_op.cu (118 changes: 0 additions & 118 deletions)

This file was deleted.
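
The deleted CUDA source was migrated into the phi kernel library rather than dropped. A hedged sketch of the assumed phi-side declaration (the header path and exact signature are assumptions, not shown in this diff):

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Assumed phi kernel interface for nms after the migration, e.g. in
// paddle/phi/kernels/nms_kernel.h; not confirmed by this diff.
template <typename T, typename Context>
void NMSKernel(const Context& dev_ctx,
               const DenseTensor& boxes,
               float threshold,
               DenseTensor* output);

}  // namespace phi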

paddle/fluid/operators/unstack_op.cc (55 changes: 7 additions & 48 deletions)
@@ -20,6 +20,7 @@
 limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"

 namespace paddle {
@@ -63,51 +64,6 @@ class UnStackGradOpMaker : public framework::SingleGradOpMaker<T> {
 class UnStackGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(),
-                      0,
-                      platform::errors::InvalidArgument(
-                          "The Inputs(Y@Grad) of unstack operator are empty."));
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output",
-                   "X",
-                   "UnStackGrad");
-    auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
-    for (size_t i = 1; i < input_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(
-          input_dims[i],
-          input_dims[0],
-          platform::errors::InvalidArgument(
-              "The dimensions of all Inputs(Y@Grad) must be the same, "
-              "but received Inputs(Y@Grad)'s %d-th dimension is %d, "
-              "while the 0-th to %d-th dimensions are %d.",
-              i,
-              input_dims[i],
-              i - 1,
-              input_dims[0]));
-    }
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    int rank = input_dims[0].size();
-    PADDLE_ENFORCE_GE(axis,
-                      -(rank + 1),
-                      platform::errors::InvalidArgument(
-                          "The attribute axis is out of range, it must be "
-                          "inside [-(rank+1), rank+1), where rank = %d",
-                          rank));
-    PADDLE_ENFORCE_LT(axis,
-                      rank + 1,
-                      platform::errors::InvalidArgument(
-                          "The attribute axis is out of range, it must be "
-                          "inside [-(rank+1), rank+1), where rank = %d",
-                          rank));
-    if (axis < 0) axis += (rank + 1);
-
-    auto vec = phi::vectorize<int>(input_dims[0]);
-    vec.insert(vec.begin() + axis, input_dims.size());
-    ctx->SetOutputDim(framework::GradVarName("X"), phi::make_ddim(vec));
-  }
 };

 }  // namespace operators
@@ -119,12 +75,15 @@ namespace ops = paddle::operators;
 DECLARE_INFER_SHAPE_FUNCTOR(unstack,
                             UnStackInferMetaFunctor,
                             PD_INFER_META(phi::UnStackInferMeta));

+DECLARE_INFER_SHAPE_FUNCTOR(unstack_grad,
+                            UnStackGradInferMetaFunctor,
+                            PD_INFER_META(phi::UnStackGradInferMeta));
 REGISTER_OPERATOR(unstack,
                   ops::UnStackOp,
                   ops::UnStackOpMaker,
                   ops::UnStackGradOpMaker<paddle::framework::OpDesc>,
                   ops::UnStackGradOpMaker<paddle::imperative::OpBase>,
                   UnStackInferMetaFunctor);

-REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp);
+REGISTER_OPERATOR(unstack_grad,
+                  ops::UnStackGradOp,
+                  UnStackGradInferMetaFunctor);
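
The grad-shape logic removed above is now supplied by phi::UnStackGradInferMeta, hence the new backward.h include. A minimal sketch reconstructed from the deleted checks, assuming the signature used with PD_INFER_META; the actual code in paddle/phi/infermeta/backward.h may differ:

#include <vector>

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/meta_tensor.h"

namespace phi {

// Hypothetical reconstruction, mirroring the checks removed from
// UnStackGradOp::InferShape above: all Y@Grad inputs share one shape,
// and X@Grad restores the stacked shape.
void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
                          int axis,
                          MetaTensor* x_grad) {
  int rank = out_grad[0]->dims().size();
  // axis must lie in [-(rank + 1), rank + 1); negative values wrap.
  if (axis < 0) axis += rank + 1;
  auto vec = vectorize<int>(out_grad[0]->dims());
  // Re-insert the number of unstacked tensors at the chosen axis.
  vec.insert(vec.begin() + axis, static_cast<int>(out_grad.size()));
  x_grad->set_dims(make_ddim(vec));
  x_grad->set_dtype(out_grad[0]->dtype());
}

}  // namespace phi
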
paddle/phi/api/yaml/legacy_api.yaml (25 changes: 22 additions & 3 deletions; file mode 100644 → 100755)
@@ -889,7 +889,7 @@
     func : FrameInferMeta
   kernel :
     func : frame
-  backward : frame_grad 
+  backward : frame_grad

 - api : frobenius_norm
   args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
@@ -1700,6 +1700,15 @@
   optional : weight
   backward : nll_loss_grad

+- api : nms
+  args : (Tensor x, float threshold)
+  output : Tensor(out)
+  infer_meta :
+    func : NMSInferMeta
+  kernel :
+    func : nms
+    data_type : x
+
 - api : norm
   args : (Tensor x, int axis, float epsilon, bool is_test)
   output : Tensor(out), Tensor(norm)
@@ -2258,7 +2267,7 @@
   kernel :
     func : spectralnorm
     data_type : weight
-  backward : spectral_norm_grad 
+  backward : spectral_norm_grad

 - api : split
   args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
@@ -2566,6 +2575,16 @@
   intermediate : xshape
   backward : unsqueeze_grad

+# unstack
+- api : unstack
+  args : (Tensor x, int axis, int num)
+  output : Tensor[]{num}
+  infer_meta :
+    func : UnStackInferMeta
+  kernel :
+    func : unstack
+  backward : unstack_grad
+
 # viterbi_decode
 - api : viterbi_decode
   args : (Tensor input, Tensor transition, Tensor length, bool include_bos_eos_tag)
@@ -2629,7 +2648,7 @@
   kernel:
     func: broadcast_tensors
   backward: broadcast_tensors_grad
-
+
 # dirichlet
 - api: dirichlet
   args: (Tensor alpha)
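
The new unstack entry above binds the Python-facing API to a phi kernel through its infer_meta and kernel blocks. The declaration that `func : unstack` presumably resolves to looks along these lines (an assumption, not part of this diff):

#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Assumed phi kernel interface for unstack: splits x into `num`
// tensors along `axis`, matching `output : Tensor[]{num}` above.
template <typename T, typename Context>
void UnStackKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   int axis,
                   int num,
                   std::vector<DenseTensor*> outs);

}  // namespace phi
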
paddle/phi/api/yaml/legacy_backward.yaml (10 changes: 10 additions & 0 deletions; file mode 100644 → 100755)
@@ -2499,6 +2499,16 @@
   inplace : (out_grad -> x_grad)
   backward : unsqueeze_double_grad

+- backward_api : unstack_grad
+  forward : unstack (Tensor x, int axis, int num) -> Tensor[](out)
+  args : (Tensor[] out_grad, int axis)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnStackGradInferMeta
+    param : [out_grad, axis]
+  kernel :
+    func : unstack_grad
+
 - backward_api : warpctc_grad
   forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad)
   args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
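
Since unstack's gradient is just a stack of the output gradients along the same axis, the kernel behind `func : unstack_grad` presumably has a declaration like this (an assumption, not part of this diff):

#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Assumed phi kernel interface for unstack_grad: stacks the incoming
// gradients along `axis` to rebuild x's gradient.
template <typename T, typename Context>
void UnStackGradKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& out_grad,
                       int axis,
                       DenseTensor* x_grad);

}  // namespace phi
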
(The remaining 12 changed files are not shown.)
