Skip to content

Commit

Permalink
phi_multiclass_nms3 (PaddlePaddle#44613)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiboniu committed Jul 29, 2022
1 parent e439d73 commit a991990
Show file tree
Hide file tree
Showing 10 changed files with 910 additions and 82 deletions.
18 changes: 8 additions & 10 deletions paddle/fluid/operators/detection/multiclass_nms_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ limitations under the License. */

#include <glog/logging.h>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -609,12 +611,6 @@ class MultiClassNMS3Op : public MultiClassNMS2Op {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMS2Op(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMS2Op::InferShape(ctx);

ctx->SetOutputDim("NmsRoisNum", {-1});
}
};

class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
Expand All @@ -633,6 +629,10 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
MultiClassNMSShapeFunctor,
PD_INFER_META(phi::MultiClassNMSInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
multiclass_nms,
Expand All @@ -658,7 +658,5 @@ REGISTER_OPERATOR(
ops::MultiClassNMS3Op,
ops::MultiClassNMS3OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(multiclass_nms3,
ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MultiClassNMSShapeFunctor);
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,15 @@
func : multi_dot
backward : multi_dot_grad

- api : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
infer_meta :
func : MultiClassNMSInferMeta
kernel :
func : multiclass_nms3
optional : rois_num

# multinomial
- api : multinomial
args : (Tensor x, int num_samples, bool replacement)
Expand Down
93 changes: 93 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start,
LinspaceRawInferMeta(start, stop, number, out);
}

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config) {
auto box_dims = bboxes.dims();
auto score_dims = scores.dims();
auto score_size = score_dims.size();

if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
score_size == 2 || score_size == 3,
true,
errors::InvalidArgument("The rank of Input(Scores) must be 2 or 3"
". But received rank = %d",
score_size));
PADDLE_ENFORCE_EQ(
box_dims.size(),
3,
errors::InvalidArgument("The rank of Input(BBoxes) must be 3"
". But received rank = %d",
box_dims.size()));
if (score_size == 3) {
PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 ||
box_dims[2] == 16 || box_dims[2] == 24 ||
box_dims[2] == 32,
true,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4 or 8, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax] or "
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
"8 points: [xi, yi] i= 1,2,...,8 or "
"12 points: [xi, yi] i= 1,2,...,12 or "
"16 points: [xi, yi] i= 1,2,...,16"));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[2],
errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1],
score_dims[2]));
} else {
PADDLE_ENFORCE_EQ(box_dims[2],
4,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4. But received dimension = %d",
box_dims[2]));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[1],
errors::InvalidArgument(
"The 2nd dimension of Input"
"(BBoxes) must be equal to the 2nd dimension of Input(Scores). "
"But received box dimension = %d, score dimension = %d",
box_dims[1],
score_dims[1]));
}
}
PADDLE_ENFORCE_NE(out,
nullptr,
errors::InvalidArgument(
"The out in MultiClassNMSInferMeta can't be nullptr."));
PADDLE_ENFORCE_NE(
index,
nullptr,
errors::InvalidArgument(
"The index in MultiClassNMSInferMeta can't be nullptr."));
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.

out->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
out->set_dtype(bboxes.dtype());
index->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
index->set_dtype(DataType::INT32);
nms_rois_num->set_dims(phi::make_ddim({-1}));
nms_rois_num->set_dtype(DataType::INT32);
}

void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ void LinspaceInferMeta(const MetaTensor& start,
DataType dtype,
MetaTensor* out);

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config = MetaConfig());

void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ set(COMMON_KERNEL_DEPS
lod_utils
custom_kernel
string_infermeta
utf8proc)
utf8proc
gpc)

copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})

Expand Down
Loading

0 comments on commit a991990

Please sign in to comment.