Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move psroi_pool OP to phi #40353

Merged
merged 2 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 13 additions & 94 deletions paddle/fluid/operators/psroi_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/psroi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -82,75 +82,6 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
platform::errors::InvalidArgument(
"Input(ROIs) of PSROIPoolOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of PSROIPoolOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");

PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"The format of input tensor is NCHW"));
PADDLE_ENFORCE_EQ(
rois_dims.size(), 2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(
rois_dims[1], 4,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

PADDLE_ENFORCE_EQ(
input_dims[1], output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channel of X(%d) "
"should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width));

PADDLE_ENFORCE_GT(pooled_height, 0,
platform::errors::InvalidArgument(
"The pooled output height must be greater than 0"));
PADDLE_ENFORCE_GT(pooled_width, 0,
platform::errors::InvalidArgument(
"The pooled output width must be greater than 0"));
PADDLE_ENFORCE_GT(output_channels, 1,
platform::errors::InvalidArgument(
"The pooled output channels must greater than 1"));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
platform::errors::InvalidArgument(
"The spatial scale must greater than 0."));

auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand All @@ -164,16 +95,6 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"The gradient of Out should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"The gradient of X should not be null."));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -204,15 +125,13 @@ class PSROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool, PsroiPoolInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(psroi_pool_grad, PsroiPoolGradInferShapeFunctor,
PD_INFER_META(phi::PsroiPoolGradInferMeta));
REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker,
ops::PSROIPoolGradMaker<paddle::framework::OpDesc>,
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
psroi_pool,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
psroi_pool_grad,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPSROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
ops::PSROIPoolGradMaker<paddle::imperative::OpBase>,
PsroiPoolInferShapeFunctor);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp,
PsroiPoolGradInferShapeFunctor);
Loading