Phi generate_proposals_v2 #44436

Merged
merged 9 commits on Aug 4, 2022

2 changes: 1 addition & 1 deletion paddle/fluid/operators/detection/CMakeLists.txt
@@ -93,7 +93,7 @@ if(WITH_GPU OR WITH_ROCM)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc
generate_proposals_op.cu DEPS ${TMPDEPS})
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc
generate_proposals_v2_op.cu DEPS ${TMPDEPS})
DEPS ${TMPDEPS})
detection_library(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
distribute_fpn_proposals_op.cu DEPS ${TMPDEPS})
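
Note on the build change above: generate_proposals_v2_op.cu is dropped from the fluid detection library because the CUDA kernel has been migrated to the phi kernel library, which builds and registers it on its own, so the operator target only needs the .cc file that defines the op. A rough sketch of the phi-side GPU registration is shown below; the file location, kernel functor name, and dtype list are assumptions for illustration, not taken from this diff.

// Hypothetical sketch: a phi kernel registers itself with PD_REGISTER_KERNEL,
// so the fluid operator library above no longer compiles the .cu source.
// Assumed location: paddle/phi/kernels/gpu/generate_proposals_v2_kernel.cu
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/generate_proposals_v2_kernel.h"  // assumed header name

PD_REGISTER_KERNEL(generate_proposals_v2,           // kernel name
                   GPU,                             // backend
                   ALL_LAYOUT,                      // data layout
                   phi::GenerateProposalsV2Kernel,  // kernel functor (assumed name)
                   float) {}                        // dtype list is a guess
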
5 changes: 3 additions & 2 deletions paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/kernels/funcs/detection/nms_util.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h"

@@ -251,7 +251,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
return std::make_pair(bbox_sel, scores_filter);
}

Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
Tensor keep_nms =
phi::funcs::NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);

if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
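
The two edits in this file only repoint the caller at the relocated NMS helper: the include moves from paddle/fluid/operators/detection/nms_util.h to paddle/phi/kernels/funcs/detection/nms_util.h, and the call picks up the phi::funcs:: qualifier. As a reference for other call sites, a sketch of the declaration shape this call relies on is below; it is an assumption modeled on the old fluid helper, and the exact types in the phi header may differ.

// Assumed declaration shape of the relocated helper (illustrative only).
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {
namespace funcs {

// Greedy NMS over score-sorted boxes: keeps the indices of boxes whose IoU
// with every previously kept box stays below nms_threshold (eta adapts the
// threshold between rounds) and returns them as a tensor of indices.
template <typename T>
DenseTensor NMS(const phi::DeviceContext& ctx,
                DenseTensor* bbox,
                DenseTensor* scores,
                T nms_threshold,
                float eta,
                bool pixel_offset = true);  // default value assumed

}  // namespace funcs
}  // namespace phi
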
245 changes: 10 additions & 235 deletions paddle/fluid/operators/detection/generate_proposals_v2_op.cc
@@ -17,10 +17,12 @@ limitations under the License. */
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/detection/nms_util.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h"

@@ -34,36 +36,6 @@ class GenerateProposalsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Scores"),
true,
platform::errors::NotFound("Input(Scores) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("BboxDeltas"),
true,
platform::errors::NotFound("Input(BboxDeltas) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("ImShape"),
true,
platform::errors::NotFound("Input(ImShape) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Anchors"),
true,
platform::errors::NotFound("Input(Anchors) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Variances"),
true,
platform::errors::NotFound("Input(Variances) shouldn't be null."));

ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("RpnRois", std::max(ctx->GetLoDLevel("Scores"), 1));
ctx->SetLoDLevel("RpnRoiProbs", std::max(ctx->GetLoDLevel("Scores"), 1));
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -73,206 +45,6 @@ class GenerateProposalsV2Op : public framework::OperatorWithKernel {
}
};

template <typename T>
class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"),
"Input",
"Anchors",
"GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input",
"Variances",
"GenerateProposals");

auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
bool pixel_offset = context.Attr<bool>("pixel_offset");

auto &dev_ctx = context.template device_context<phi::CPUContext>();

auto &scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];

auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];

rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());

Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());

phi::funcs::Transpose<phi::CPUContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);

framework::LoD lod;
lod.resize(1);
auto &lod0 = lod[0];
lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
std::vector<int> tmp_num;

int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);

bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});

std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx,
im_shape_slice,
anchors,
variances,
bbox_deltas_slice,
scores_slice,
pre_nms_top_n,
post_nms_top_n,
nms_thresh,
min_size,
eta,
pixel_offset);
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;

AppendProposals(rpn_rois, 4 * num_proposals, proposals);
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.push_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
for (int i = 0; i < num; i++) {
num_data[i] = tmp_num[i];
}
rpn_rois_num->Resize({num});
}
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}

std::pair<Tensor, Tensor> ProposalForOneImage(
const phi::CPUContext &ctx,
const Tensor &im_shape_slice,
const Tensor &anchors,
const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
int pre_nms_top_n,
int post_nms_top_n,
float nms_thresh,
float min_size,
float eta,
bool pixel_offset = true) const {
auto *scores_data = scores_slice.data<T>();

// Sort index
Tensor index_t;
index_t.Resize({scores_slice.numel()});
int *index = index_t.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};

if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
} else {
std::nth_element(
index, index + pre_nms_top_n, index + scores_slice.numel(), compare);
index_t.Resize({pre_nms_top_n});
}

Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());

phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);

Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
BoxCoder<T>(
ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals, pixel_offset);

ClipTiledBoxes<T>(
ctx, im_shape_slice, proposals, &proposals, false, pixel_offset);

Tensor keep;
FilterBoxes<T>(
ctx, &proposals, min_size, im_shape_slice, false, &keep, pixel_offset);
// Handle the case when there is no keep index left
if (keep.numel() == 0) {
phi::funcs::SetConstant<phi::CPUContext, T> set_zero;
bbox_sel.mutable_data<T>({1, 4}, ctx.GetPlace());
set_zero(ctx, &bbox_sel, static_cast<T>(0));
Tensor scores_filter;
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(bbox_sel, scores_filter);
}

Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_filter);
}

Tensor keep_nms =
NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta, pixel_offset);

if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}

proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);

return std::make_pair(proposals, scores_sel);
}
};

class GenerateProposalsV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -336,16 +108,19 @@ to before and will not effect the result.
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(generate_proposals_v2,
GenerateProposalsV2InferShapeFunctor,
PD_INFER_META(phi::GenerateProposalsV2InferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
generate_proposals_v2,
ops::GenerateProposalsV2Op,
ops::GenerateProposalsV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(generate_proposals_v2,
ops::GenerateProposalsV2Kernel<float>,
ops::GenerateProposalsV2Kernel<double>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GenerateProposalsV2InferShapeFunctor);

REGISTER_OP_VERSION(generate_proposals_v2)
.AddCheckpoint(
R"ROC(Registe generate_proposals_v2 for adding the attribute of pixel_offset)ROC",