Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Phi]Move eye, lerp infershape to phi #40105

Merged
merged 1 commit into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions paddle/fluid/operators/eye_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EyeOP should not be null."));
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
PADDLE_ENFORCE_EQ(
num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));
ctx->SetOutputDim("Out", {num_rows, num_columns});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns].
} // namespace paddle

namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor,
PT_INFER_META(phi::EyeInferMeta));

REGISTER_OPERATOR(
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
EyeInferShapeFunctor);
50 changes: 6 additions & 44 deletions paddle/fluid/operators/lerp_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {

class LerpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto w_dims = ctx->GetInputDim("Weight");
framework::DDim out_dims;
out_dims = GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = GetOutputDims(out_dims, w_dims);
}

ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}

private:
framework::DDim GetOutputDims(const framework::DDim& s_dims,
const framework::DDim& l_dims) const {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(), i, l_dims.to_str(), j));
}
}
}
return phi::make_ddim(shapes);
}
};

class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle

DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor,
PT_INFER_META(phi::LerpInferMeta));
REGISTER_OPERATOR(
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer);
paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor);

REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape,
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}

void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out) {
if (num_columns == -1) num_columns = num_rows;
out->set_dims({num_rows, num_columns});
out->set_dtype(dtype);
}
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/nullary.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,

void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);

void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out);

} // namespace phi
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto w_dims = weight.dims();
DDim out_dims;
out_dims = funcs::GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = funcs::GetOutputDims(out_dims, w_dims);
}
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}

} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input,
float beta,
MetaTensor* out);

void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/eye_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out);

} // namespace phi
25 changes: 25 additions & 0 deletions paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
return true;
}

inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(),
i,
l_dims.to_str(),
j));
}
}
}
return phi::make_ddim(shapes);
}

} // namespace funcs
} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/eye_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out) {
auto num = num_columns;
if (num == -1) {
Expand Down