Skip to content

Commit

Permalink
migrate eigh to phi
Browse files Browse the repository at this point in the history
  • Loading branch information
Zjq9409 committed Mar 7, 2022
1 parent da3de72 commit 5a9bd8d
Show file tree
Hide file tree
Showing 15 changed files with 756 additions and 160 deletions.
63 changes: 10 additions & 53 deletions paddle/fluid/operators/eigh_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -22,42 +25,9 @@ using framework::Tensor;
class EighOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
"Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
"Eigh");

auto input_dim = ctx->GetInputDim("X");
auto rank = input_dim.size();

PADDLE_ENFORCE_GE(rank, 2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2], input_dim[rank - 1],
platform::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2], input_dim[rank - 1]));

std::vector<int64_t> values_dim;

for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}

ctx->SetOutputDim("Eigenvalues", phi::make_ddim(values_dim));
ctx->SetOutputDim("Eigenvectors", input_dim);
}
};

class EignOpMaker : public framework::OpProtoAndCheckerMaker {
class EighOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
Expand Down Expand Up @@ -140,24 +110,11 @@ class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigh, EighInferShapeFunctor,
PD_INFER_META(phi::EighInferMeta));

REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
REGISTER_OPERATOR(eigh, ops::EighOp, ops::EighOpMaker,
ops::EighGradOpMaker<paddle::framework::OpDesc>,
ops::EighGradOpMaker<paddle::imperative::OpBase>);
ops::EighGradOpMaker<paddle::imperative::OpBase>,
EighInferShapeFunctor);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);

REGISTER_OP_CPU_KERNEL(
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
eigh_grad, ops::EighGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
32 changes: 0 additions & 32 deletions paddle/fluid/operators/eigh_op.cu

This file was deleted.

74 changes: 0 additions & 74 deletions paddle/fluid/operators/eigh_op.h

This file was deleted.

32 changes: 32 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,38 @@ void TransposeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_v,
MetaTensor* out_w) {
auto input_dim = x.dims();
auto rank = input_dim.size();

PADDLE_ENFORCE_GE(rank,
2,
phi::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2],
input_dim[rank - 1],
phi::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2],
input_dim[rank - 1]));

std::vector<int64_t> values_dim;

for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
out_v->set_dims(phi::make_ddim(values_dim));
out_w->set_dims(input_dim);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,9 @@ void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);

void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
MetaTensor* out_v);

} // namespace phi
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
# NOTE: Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel eigh_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)

# auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS})
Expand Down
28 changes: 28 additions & 0 deletions paddle/phi/kernels/cpu/eigh_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/eigh_grad_kernel.h"
#include "paddle/phi/kernels/impl/eigh_grad_kernel_impl.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"

PD_REGISTER_KERNEL(eigh_grad,
CPU,
ALL_LAYOUT,
phi::EighGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
43 changes: 43 additions & 0 deletions paddle/phi/kernels/cpu/eigh_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"

namespace phi {

template <typename T, typename Context>
void EighKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
DenseTensor* out_v,
DenseTensor* out_w) {
bool is_lower = (uplo == "L");
phi::funcs::MatrixEighFunctor<Context, T> functor;
functor(dev_ctx, x, out_v, out_w, is_lower, true);
}

} // namespace phi

PD_REGISTER_KERNEL(eigh,
CPU,
ALL_LAYOUT,
phi::EighKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
30 changes: 30 additions & 0 deletions paddle/phi/kernels/eigh_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void EighGardKernel(const Context& dev_ctx,
const DenseTensor& out_v,
const DenseTensor& out_w,
const DenseTensor& dout_v,
const DenseTensor& dout_w,
DenseTensor* dx);

} // namespace phi
29 changes: 29 additions & 0 deletions paddle/phi/kernels/eigh_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void EighKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
DenseTensor* out_v,
DenseTensor* out_w);

} // namespace phi
Loading

0 comments on commit 5a9bd8d

Please sign in to comment.