Skip to content

Commit

Permalink
Move psroi_pool OP to phi (#40353)
Browse files Browse the repository at this point in the history
* Move psroi_pool OP to phi

* Replace platform::TensorCopy with phi::Copy
  • Loading branch information
From00 committed Mar 11, 2022
1 parent 89ed57e commit c0e2923
Show file tree
Hide file tree
Showing 14 changed files with 1,079 additions and 854 deletions.
107 changes: 13 additions & 94 deletions paddle/fluid/operators/psroi_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/psroi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -82,75 +82,6 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
platform::errors::InvalidArgument(
"Input(ROIs) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of PSROIPoolOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");

PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"The format of input tensor is NCHW"));
PADDLE_ENFORCE_EQ(
rois_dims.size(), 2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(
rois_dims[1], 4,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

PADDLE_ENFORCE_EQ(
input_dims[1], output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channel of X(%d) "
"should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width));

PADDLE_ENFORCE_GT(pooled_height, 0,
platform::errors::InvalidArgument(
"The pooled output height must be greater than 0"));
PADDLE_ENFORCE_GT(pooled_width, 0,
platform::errors::InvalidArgument(
"The pooled output width must be greater than 0"));
PADDLE_ENFORCE_GT(output_channels, 1,
platform::errors::InvalidArgument(
"The pooled output channels must greater than 1"));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
platform::errors::InvalidArgument(
"The spatial scale must greater than 0."));

auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand All @@ -164,16 +95,6 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"The gradient of Out should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"The gradient of X should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -204,15 +125,13 @@ class PSROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool, PsroiPoolInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool_grad, PsroiPoolGradInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolGradInferMeta));
REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker,
ops::PSROIPoolGradMaker<paddle::framework::OpDesc>,
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
psroi_pool,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
psroi_pool_grad,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>,
PsroiPoolInferShapeFunctor);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp,
PsroiPoolGradInferShapeFunctor);
Loading

0 comments on commit c0e2923

Please sign in to comment.