Skip to content

Commit

Permalink
CPU forward calculation replaces Eigen with Lapack;Modify linalg expo…
Browse files Browse the repository at this point in the history
…sure rules (#35916)
  • Loading branch information
Zjq9409 committed Sep 26, 2021
1 parent 1b90f96 commit 7ff226f
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 179 deletions.
17 changes: 8 additions & 9 deletions paddle/fluid/operators/eigh_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,17 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);

REGISTER_OP_CPU_KERNEL(
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext, float,
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float,
eigh_grad, ops::EighGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
17 changes: 8 additions & 9 deletions paddle/fluid/operators/eigh_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,17 @@ limitations under the License. */

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float, float>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double, double>,
ops::EighKernel<paddle::platform::CUDADeviceContext, float,
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double>,
ops::EighKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double,
ops::EighKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float,
eigh_grad, ops::EighGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double,
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
7 changes: 4 additions & 3 deletions paddle/fluid/operators/eigh_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename ValueType, typename T>
template <typename DeviceContext, typename T>
class EighKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
Expand All @@ -31,15 +31,16 @@ class EighKernel : public framework::OpKernel<T> {
auto output_v = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctor<DeviceContext, ValueType, T> functor;
math::MatrixEighFunctor<DeviceContext, T> functor;
functor(ctx, *input, output_w, output_v, is_lower, true);
}
};

template <typename DeviceContext, typename ValueType, typename T>
template <typename DeviceContext, typename T>
class EighGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using ValueType = math::Real<T>;
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
Expand Down
Loading

0 comments on commit 7ff226f

Please sign in to comment.