diff --git a/paddle/fluid/operators/mv_op.cc b/paddle/fluid/operators/mv_op.cc index ab9f10070fc60..d34a1ebf82c2f 100644 --- a/paddle/fluid/operators/mv_op.cc +++ b/paddle/fluid/operators/mv_op.cc @@ -16,8 +16,11 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -42,33 +45,6 @@ class MVOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv"); - OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv"); - OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv"); - - auto dim_x = context->GetInputDim("X"); - auto dim_vec = context->GetInputDim("Vec"); - PADDLE_ENFORCE_EQ( - dim_x.size(), 2, - platform::errors::InvalidArgument( - "The rank of input X should be 2, but is %d", dim_x.size())); - PADDLE_ENFORCE_EQ( - dim_vec.size(), 1, - platform::errors::InvalidArgument( - "The rank of input Vec should be 1, but is %d", dim_vec.size())); - PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0], - platform::errors::InvalidArgument( - "X's second dimension is expected to be equal to " - "Vec's first dimension" - "but recieved X'shape = [%s], Vec's shape = [%s]", - dim_x, dim_vec)); - - framework::DDim dim_out = phi::make_ddim({dim_x[0]}); - - context->SetOutputDim("Out", dim_out); - context->ShareLoD("X", /*->*/ "Out"); - } }; template @@ -118,7 +94,11 @@ class MVOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; +DELCARE_INFER_SHAPE_FUNCTOR(mv, MvInferShapeFunctor, + PT_INFER_META(phi::MvInferMeta)); + REGISTER_OPERATOR(mv, 
ops::MVOp, ops::MVOpMaker, ops::MVOpGradMaker, - ops::MVOpGradMaker); + ops::MVOpGradMaker, + MvInferShapeFunctor); REGISTER_OPERATOR(mv_grad, ops::MVOpGrad); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 745ddffabbe33..03128e96a838f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -443,4 +443,34 @@ void GatherTreeMeta(const MetaTensor& ids, out->set_dims(ids_dims); } +void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { + auto dim_x = x.dims(); + auto dim_vec = vec.dims(); + PADDLE_ENFORCE_EQ( + dim_x.size(), + 2, + phi::errors::InvalidArgument("The rank of input X should be 2, but is %d", + dim_x.size())); + PADDLE_ENFORCE_EQ( + dim_vec.size(), + 1, + phi::errors::InvalidArgument( + "The rank of input Vec should be 1, but is %d", dim_vec.size())); + PADDLE_ENFORCE_EQ(dim_x[1], + dim_vec[0], + phi::errors::InvalidArgument( + "X's second dimension is expected to be equal to " + "Vec's first dimension, " + "but received X's shape = [%s], Vec's shape = [%s]", + dim_x, + dim_vec)); + + auto dim_out = phi::make_ddim({dim_x[0]}); + + out->set_dims(dim_out); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 2ec744636988f..f397c0def8a0b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -85,4 +85,7 @@ void GatherNdInferMeta(const MetaTensor& x, void GatherTreeMeta(const MetaTensor& ids, const MetaTensor& parents, MetaTensor* out); + +void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/ops/compat/mv_sig.cc b/paddle/phi/ops/compat/mv_sig.cc index ab0d31ee31dab..0012f8e1ccb41 100644 --- a/paddle/phi/ops/compat/mv_sig.cc +++ b/paddle/phi/ops/compat/mv_sig.cc @@ -16,10 +16,6 @@ namespace phi { -KernelSignature MvOpArgumentMapping(const 
ArgumentMappingContext& ctx) { - return KernelSignature("mv", {"X", "Vec"}, {}, {"Out"}); -} - KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("mv_grad", {"X", "Vec", GradVarName("Out")}, @@ -29,5 +25,4 @@ KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(mv, phi::MvOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping);