From d20e52fa85b8d56e5ef099b6a57051a877a9b518 Mon Sep 17 00:00:00 2001 From: zhangkeliang Date: Sat, 26 Feb 2022 02:48:09 +0000 Subject: [PATCH 1/2] [Phi] move infershape for mv --- paddle/fluid/operators/mv_op.cc | 36 ++++++++------------------------- paddle/phi/infermeta/binary.cc | 34 +++++++++++++++++++++++++++++++ paddle/phi/infermeta/binary.h | 2 ++ 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/mv_op.cc b/paddle/fluid/operators/mv_op.cc index ab9f10070fc60..d34a1ebf82c2f 100644 --- a/paddle/fluid/operators/mv_op.cc +++ b/paddle/fluid/operators/mv_op.cc @@ -16,8 +16,11 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -42,33 +45,6 @@ class MVOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv"); - OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv"); - OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv"); - - auto dim_x = context->GetInputDim("X"); - auto dim_vec = context->GetInputDim("Vec"); - PADDLE_ENFORCE_EQ( - dim_x.size(), 2, - platform::errors::InvalidArgument( - "The rank of input X should be 2, but is %d", dim_x.size())); - PADDLE_ENFORCE_EQ( - dim_vec.size(), 1, - platform::errors::InvalidArgument( - "The rank of input Vec should be 1, but is %d", dim_vec.size())); - PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0], - platform::errors::InvalidArgument( - "X's second dimension is expected to be equal to " - "Vec's first dimension" - "but recieved X'shape = [%s], Vec's shape = [%s]", - dim_x, dim_vec)); - - 
framework::DDim dim_out = phi::make_ddim({dim_x[0]}); - - context->SetOutputDim("Out", dim_out); - context->ShareLoD("X", /*->*/ "Out"); - } }; template @@ -118,7 +94,11 @@ class MVOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; +DELCARE_INFER_SHAPE_FUNCTOR(mv, MvInferShapeFunctor, + PT_INFER_META(phi::MvInferMeta)); + REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker, ops::MVOpGradMaker, - ops::MVOpGradMaker); + ops::MVOpGradMaker, + MvInferShapeFunctor); REGISTER_OPERATOR(mv_grad, ops::MVOpGrad); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 745ddffabbe33..9ac21c521e220 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -443,4 +443,38 @@ void GatherTreeMeta(const MetaTensor& ids, out->set_dims(ids_dims); } +void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { + // OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv"); + // OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv"); + // OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv"); + + auto dim_x = x.dims(); + auto dim_vec = vec.dims(); + PADDLE_ENFORCE_EQ( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The rank of input X should be 2, but is %d", + dim_x.size())); + PADDLE_ENFORCE_EQ( + dim_vec.size(), + 1, + phi::errors::InvalidArgument( + "The rank of input Vec should be 1, but is %d", dim_vec.size())); + PADDLE_ENFORCE_EQ(dim_x[1], + dim_vec[0], + phi::errors::InvalidArgument( + "X's second dimension is expected to be equal to " + "Vec's first dimension, " + "but received X's shape = [%s], Vec's shape = [%s]", + dim_x, + dim_vec)); + + auto dim_out = phi::make_ddim({dim_x[0]}); + + out->set_dims(dim_out); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 2ec744636988f..e9216c94a943c 
100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -85,4 +85,6 @@ void GatherNdInferMeta(const MetaTensor& x, void GatherTreeMeta(const MetaTensor& ids, const MetaTensor& parents, MetaTensor* out); +void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); + } // namespace phi From 9f5e398bc4d3c47ea105373109bfeb33e46ddcff Mon Sep 17 00:00:00 2001 From: zhangkeliang Date: Sat, 26 Feb 2022 05:13:06 +0000 Subject: [PATCH 2/2] [Phi] delete extra codes for mv --- paddle/phi/infermeta/binary.cc | 4 ---- paddle/phi/infermeta/binary.h | 1 + paddle/phi/ops/compat/mv_sig.cc | 5 ----- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 9ac21c521e220..03128e96a838f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -444,10 +444,6 @@ void GatherTreeMeta(const MetaTensor& ids, } void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { - // OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv"); - // OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv"); - // OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv"); - auto dim_x = x.dims(); auto dim_vec = vec.dims(); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index e9216c94a943c..f397c0def8a0b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -85,6 +85,7 @@ void GatherNdInferMeta(const MetaTensor& x, void GatherTreeMeta(const MetaTensor& ids, const MetaTensor& parents, MetaTensor* out); + void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/ops/compat/mv_sig.cc b/paddle/phi/ops/compat/mv_sig.cc index ab0d31ee31dab..0012f8e1ccb41 100644 --- a/paddle/phi/ops/compat/mv_sig.cc +++ b/paddle/phi/ops/compat/mv_sig.cc @@ -16,10 +16,6 @@ namespace phi { -KernelSignature 
MvOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("mv", {"X", "Vec"}, {}, {"Out"}); -} - KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("mv_grad", {"X", "Vec", GradVarName("Out")}, @@ -29,5 +25,4 @@ KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(mv, phi::MvOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping);