From 054fc9978dff3c3629de92e61fc5ab960ecb8c37 Mon Sep 17 00:00:00 2001 From: tianshuo78520a <707759223@qq.com> Date: Tue, 29 Mar 2022 14:36:59 +0800 Subject: [PATCH] Revert "[Phi] trans logsumexp op (#40790)" (#41068) This reverts commit 9c0eaadaeadf8d8e74b6a8db9f9b3fb6c700fa5c. --- .../operators/reduce_ops/logsumexp_op.cc | 93 +++++++++- .../operators/reduce_ops/logsumexp_op.cu} | 14 +- .../fluid/operators/reduce_ops/logsumexp_op.h | 170 ++++++++++++++++++ .../reduce_ops/logsumexp_op.part.cu} | 15 +- .../operators/reduce_ops/logsumexp_op_xpu.cc | 2 +- paddle/phi/infermeta/unary.cc | 85 --------- paddle/phi/infermeta/unary.h | 6 - .../phi/kernels/gpu/logsumexp_grad_kernel.cu | 22 --- paddle/phi/kernels/gpu/logsumexp_kernel.cu | 23 --- .../kernels/impl/logsumexp_grad_kernel_impl.h | 91 ---------- .../phi/kernels/impl/logsumexp_kernel_impl.h | 100 ----------- paddle/phi/kernels/logsumexp_grad_kernel.h | 31 ---- paddle/phi/kernels/logsumexp_kernel.h | 29 --- paddle/phi/ops/compat/logsumexp_sig.cc | 29 --- 14 files changed, 270 insertions(+), 440 deletions(-) rename paddle/{phi/kernels/cpu/logsumexp_kernel.cc => fluid/operators/reduce_ops/logsumexp_op.cu} (61%) create mode 100644 paddle/fluid/operators/reduce_ops/logsumexp_op.h rename paddle/{phi/kernels/cpu/logsumexp_grad_kernel.cc => fluid/operators/reduce_ops/logsumexp_op.part.cu} (58%) delete mode 100644 paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu delete mode 100644 paddle/phi/kernels/gpu/logsumexp_kernel.cu delete mode 100644 paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h delete mode 100644 paddle/phi/kernels/impl/logsumexp_kernel_impl.h delete mode 100644 paddle/phi/kernels/logsumexp_grad_kernel.h delete mode 100644 paddle/phi/kernels/logsumexp_kernel.h delete mode 100644 paddle/phi/ops/compat/logsumexp_sig.cc diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc index 0602c73db6bbc..9f0ef19bd6299 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" #include #include #include -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -26,6 +23,80 @@ namespace operators { class LogsumexpOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp"); + auto x_dims = ctx->GetInputDim("X"); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 4, + platform::errors::InvalidArgument( + "The input tensor X's dimensions of logsumexp " + "should be less or equal than 4. But received X's " + "dimensions = %d, X's shape = [%s].", + x_rank, x_dims)); + auto axis = ctx->Attrs().Get>("axis"); + PADDLE_ENFORCE_GT( + axis.size(), 0, + platform::errors::InvalidArgument( + "The size of axis of logsumexp " + "should be greater than 0. 
But received the size of axis " + "of logsumexp is %d.", + axis.size())); + + for (size_t i = 0; i < axis.size(); i++) { + PADDLE_ENFORCE_LT(axis[i], x_rank, + platform::errors::InvalidArgument( + "axis[%d] should be in the " + "range [-D, D), where D is the dimensions of X and " + "D is %d. But received axis[%d] = %d.", + i, x_rank, i, axis[i])); + PADDLE_ENFORCE_GE(axis[i], -x_rank, + platform::errors::InvalidArgument( + "axis[%d] should be in the " + "range [-D, D), where D is the dimensions of X and " + "D is %d. But received axis[%d] = %d.", + i, x_rank, i, axis[i])); + if (axis[i] < 0) { + axis[i] += x_rank; + } + } + + bool keepdim = ctx->Attrs().Get("keepdim"); + bool reduce_all = ctx->Attrs().Get("reduce_all"); + auto dims_vector = vectorize(x_dims); + if (reduce_all) { + if (keepdim) + ctx->SetOutputDim("Out", + phi::make_ddim(std::vector(x_rank, 1))); + else + ctx->SetOutputDim("Out", {1}); + } else { + auto dims_vector = vectorize(x_dims); + if (keepdim) { + for (size_t i = 0; i < axis.size(); ++i) { + dims_vector[axis[i]] = 1; + } + } else { + const int kDelFlag = -1; + for (size_t i = 0; i < axis.size(); ++i) { + dims_vector[axis[i]] = kDelFlag; + } + dims_vector.erase( + std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + if (!keepdim && dims_vector.size() == 0) { + dims_vector.push_back(1); + } + auto out_dims = phi::make_ddim(dims_vector); + ctx->SetOutputDim("Out", out_dims); + if (axis.size() > 0 && axis[0] != 0) { + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } }; class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker { @@ -93,10 +164,16 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(logsumexp, LogsumexpInferShapeFunctor, - PD_INFER_META(phi::LogsumexpInferMeta)); + REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker, ops::LogsumexpGradOpMaker, - ops::LogsumexpGradOpMaker, - LogsumexpInferShapeFunctor); + ops::LogsumexpGradOpMaker); REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp); + +REGISTER_OP_CPU_KERNEL( + logsumexp, ops::LogsumexpKernel, + ops::LogsumexpKernel); +REGISTER_OP_CPU_KERNEL( + logsumexp_grad, + ops::LogsumexpGradKernel, + ops::LogsumexpGradKernel); diff --git a/paddle/phi/kernels/cpu/logsumexp_kernel.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op.cu similarity index 61% rename from paddle/phi/kernels/cpu/logsumexp_kernel.cc rename to paddle/fluid/operators/reduce_ops/logsumexp_op.cu index 06e0b30a9ca65..86a31595ebaab 100644 --- a/paddle/phi/kernels/cpu/logsumexp_kernel.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/logsumexp_kernel.h" +#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" +namespace ops = paddle::operators; -#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" - -PD_REGISTER_KERNEL( - logsumexp, CPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} +REGISTER_OP_CUDA_KERNEL( + logsumexp, ops::LogsumexpKernel, + ops::LogsumexpKernel); diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.h b/paddle/fluid/operators/reduce_ops/logsumexp_op.h new file mode 100644 index 0000000000000..4490f08b2129a --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.h @@ -0,0 +1,170 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" + +namespace paddle { +namespace operators { + +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + paddle::operators::ReduceFunctor( \ + context.template device_context(), *input, output, \ + axis, keepdim); \ + } + +struct LogsumexpFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + auto x_dim = x->dimensions(); + auto t_dim = x_dim; + for (int i = 0; i < static_cast(dim.size()); i++) { + t_dim[dim[i]] = 1; + } + + auto r_dim = x_dim; + for (int i = 0; i < static_cast(r_dim.size()); i++) { + r_dim[i] = 1; + } + for (int i = 0; i < static_cast(dim.size()); i++) { + r_dim[dim[i]] = x_dim[dim[i]]; + } + + auto y_dim = y->dimensions(); + auto x_max = x->maximum(dim); + y->device(place) = + (x_max + + (*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) + .reshape(y_dim); + } +}; + +struct LogsumexpGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); + } +}; + +template +class LogsumexpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + auto axis = context.Attr>("axis"); + auto keepdim = context.Attr("keepdim"); + auto reduce_all = context.Attr("reduce_all"); + + const auto& input_dim_size = input->dims().size(); + // The dims has full dim, set the reduce_all is True + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = EigenVector::Flatten(*input); + auto out = EigenScalar::From(*output); + auto& place = + *context.template device_context().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + LogsumexpFunctor()(place, &x, &out, reduce_dim); + } else { + int ndim = input_dim_size; + int rdim = axis.size(); + // comments for accelerating 
compiling temporarily. + // HANDLE_DIM(6, 5); + // HANDLE_DIM(6, 4); + // HANDLE_DIM(6, 3); + // HANDLE_DIM(6, 2); + // HANDLE_DIM(6, 1); + // HANDLE_DIM(5, 4); + // HANDLE_DIM(5, 3); + // HANDLE_DIM(5, 2); + // HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + } + } +}; + +template +class LogsumexpGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Input("Out"); + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* input_grad = context.Output(framework::GradVarName("X")); + input_grad->mutable_data(context.GetPlace()); + + auto axis = context.Attr>("axis"); + auto reduce_all = context.Attr("reduce_all"); + const auto input_dim_size = context.Input("X")->dims().size(); + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + auto x = EigenVector::Flatten(*input); + auto y = EigenVector::Flatten(*output); + auto dy = EigenVector::Flatten(*output_grad); + auto dx = EigenVector::Flatten(*input_grad); + auto& place = + *context.template device_context().eigen_device(); + auto broadcast_dim = + Eigen::array({{static_cast(input->numel())}}); + LogsumexpGradFunctor()(place, &x, &y, &dx, &dy, broadcast_dim, + broadcast_dim[0]); + } else { + int rank = input->dims().size(); + LogsumexpGradFunctor functor; + switch (rank) { + case 1: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, functor, axis); + break; + case 2: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, functor, axis); + break; + case 3: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, functor, axis); + break; + case 4: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, functor, axis); + break; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu similarity index 58% rename from paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc rename to paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu index e0ef67084b445..81124e4f070a5 100644 --- a/paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/logsumexp_grad_kernel.h" +// .part used to speed up nvcc compile +#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h" +namespace ops = paddle::operators; -PD_REGISTER_KERNEL( - logsumexp_grad, CPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {} +REGISTER_OP_CUDA_KERNEL( + logsumexp_grad, + ops::LogsumexpGradKernel, + ops::LogsumexpGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc index 6fb60fa179157..dcb849de0991b 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc @@ -14,7 +14,7 @@ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" +#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device_context.h" diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index dc1b8685844af..199029a3a094a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -804,91 +804,6 @@ void KthvalueInferMeta(const MetaTensor& x, indices->set_dtype(x.dtype()); } -void LogsumexpInferMeta(const MetaTensor& input, - const std::vector& axis, - bool keepdim, - bool reduce_all, - MetaTensor* out) { - auto x_dims = input.dims(); - auto x_rank = x_dims.size(); - std::vector formated_axis = axis; - PADDLE_ENFORCE_LE(x_rank, - 4, - errors::InvalidArgument( - "The input tensor X's dimensions of logsumexp " - "should be less or equal than 4. But received X's " - "dimensions = %d, X's shape = [%s].", - x_rank, - x_dims)); - PADDLE_ENFORCE_GT( - axis.size(), - 0, - errors::InvalidArgument( - "The size of axis of logsumexp " - "should be greater than 0. But received the size of axis " - "of logsumexp is %d.", - axis.size())); - - for (size_t i = 0; i < axis.size(); i++) { - PADDLE_ENFORCE_LT(axis[i], - x_rank, - errors::InvalidArgument( - "axis[%d] should be in the " - "range [-D, D), where D is the dimensions of X and " - "D is %d. But received axis[%d] = %d.", - i, - x_rank, - i, - axis[i])); - PADDLE_ENFORCE_GE(axis[i], - -x_rank, - errors::InvalidArgument( - "axis[%d] should be in the " - "range [-D, D), where D is the dimensions of X and " - "D is %d. But received axis[%d] = %d.", - i, - x_rank, - i, - axis[i])); - if (axis[i] < 0) { - formated_axis[i] += x_rank; - } - } - - auto dims_vector = vectorize(x_dims); - if (reduce_all) { - if (keepdim) - out->set_dims(phi::make_ddim(std::vector(x_rank, 1))); - else - out->set_dims({1}); - } else { - auto dims_vector = vectorize(x_dims); - if (keepdim) { - for (size_t i = 0; i < formated_axis.size(); ++i) { - dims_vector[formated_axis[i]] = 1; - } - } else { - const int kDelFlag = -1; - for (size_t i = 0; i < formated_axis.size(); ++i) { - dims_vector[formated_axis[i]] = kDelFlag; - } - dims_vector.erase( - std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - } - if (!keepdim && dims_vector.size() == 0) { - dims_vector.push_back(1); - } - auto out_dims = phi::make_ddim(dims_vector); - out->set_dims(out_dims); - if (formated_axis.size() > 0 && formated_axis[0] != 0) { - // Only pass LoD when not reducing on the first dim. 
- out->share_lod(input); - } - } - out->set_dtype(input.dtype()); -} - void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) { auto dims = x.dims(); auto n_dim = dims.size(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index c88e93807ffe7..bae8083ef7191 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -136,12 +136,6 @@ void KthvalueInferMeta(const MetaTensor& x, MetaTensor* indices, MetaConfig = MetaConfig()); -void LogsumexpInferMeta(const MetaTensor& input, - const std::vector& axis, - bool keepdim, - bool reduce_all, - MetaTensor* out); - void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out); void MaxOutInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu deleted file mode 100644 index 490b3e9404561..0000000000000 --- a/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/logsumexp_grad_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h" - -PD_REGISTER_KERNEL( - logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/logsumexp_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_kernel.cu deleted file mode 100644 index 0f07a39ab113a..0000000000000 --- a/paddle/phi/kernels/gpu/logsumexp_kernel.cu +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/logsumexp_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" - -PD_REGISTER_KERNEL( - logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h deleted file mode 100644 index c2583ce8d32df..0000000000000 --- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#include "paddle/phi/kernels/funcs/reduce_grad_functions.h" -#include "paddle/phi/kernels/logsumexp_grad_kernel.h" - -namespace phi { - -struct LogsumexpGradFunctor { - template - void operator()(const Context& place, - X* x, - Y* y, - DX* dx, - DY* dy, - const Dim& dim, - int size) { - dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); - } -}; - -template -void LogsumexpGradKernel(const Context& dev_ctx, - const DenseTensor& in, - const DenseTensor& out, - const DenseTensor& out_grad, - const std::vector& axis, - bool keepdim, - bool reduce_all, - DenseTensor* in_grad) { - dev_ctx.template Alloc(in_grad); - - const auto input_dim_size = in.dims().size(); - reduce_all |= (static_cast(axis.size()) == input_dim_size); - - if (reduce_all) { - auto x = phi::EigenVector::Flatten(in); - auto y = phi::EigenVector::Flatten(out); - auto dy = phi::EigenVector::Flatten(out_grad); - auto dx = phi::EigenVector::Flatten(*in_grad); - auto& place = *dev_ctx.eigen_device(); - auto broadcast_dim = Eigen::array({{static_cast(in.numel())}}); - LogsumexpGradFunctor()( - place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]); - } else { - int rank = in.dims().size(); - LogsumexpGradFunctor functor; - switch (rank) { - case 1: - phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); - break; - case 2: - phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); - break; - case 3: - phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); - break; - case 4: - phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); - break; - } - } -} - -} // namespace phi diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h deleted file mode 100644 index 7a9573ff522b0..0000000000000 --- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include -#include - -#include "paddle/phi/kernels/cpu/reduce.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#include "paddle/phi/kernels/logsumexp_kernel.h" - -namespace phi { - -#define HANDLE_DIM(NDIM, RDIM) \ - if (ndim == NDIM && rdim == RDIM) { \ - ReduceFunctor( \ - dev_ctx, x, out, axis, keepdim); \ - } - -struct LogsumexpFunctor { - template - void operator()(const Context& place, X* x, Y* y, const Dim& dim) { - auto x_dim = x->dimensions(); - auto t_dim = x_dim; - for (int i = 0; i < static_cast(dim.size()); i++) { - t_dim[dim[i]] = 1; - } - - auto r_dim = x_dim; - for (int i = 0; i < static_cast(r_dim.size()); i++) { - r_dim[i] = 1; - } - for (int i = 0; i < static_cast(dim.size()); i++) { - r_dim[dim[i]] = x_dim[dim[i]]; - } - - auto y_dim = y->dimensions(); - auto x_max = x->maximum(dim); - y->device(place) = - (x_max + - (*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) - .reshape(y_dim); - } -}; - -template -void LogsumexpKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - bool keepdim, - bool reduce_all, - DenseTensor* out) { - dev_ctx.template Alloc(out); - - const auto& input_dim_size = x.dims().size(); - // The dims has full dim, set the reduce_all is True - reduce_all |= (static_cast(axis.size()) == input_dim_size); - - if (reduce_all) { - // Flatten and reduce 1-D tensor - auto input = phi::EigenVector::Flatten(x); - auto output = phi::EigenScalar::From(*out); - auto& place = *dev_ctx.eigen_device(); - auto reduce_dim = Eigen::array({{0}}); - LogsumexpFunctor()(place, &input, &output, reduce_dim); - } else { - int ndim = input_dim_size; - int rdim = axis.size(); - // comments for accelerating compiling temporarily. - // HANDLE_DIM(6, 5); - // HANDLE_DIM(6, 4); - // HANDLE_DIM(6, 3); - // HANDLE_DIM(6, 2); - // HANDLE_DIM(6, 1); - // HANDLE_DIM(5, 4); - // HANDLE_DIM(5, 3); - // HANDLE_DIM(5, 2); - // HANDLE_DIM(5, 1); - HANDLE_DIM(4, 3); - HANDLE_DIM(4, 2); - HANDLE_DIM(4, 1); - HANDLE_DIM(3, 2); - HANDLE_DIM(3, 1); - HANDLE_DIM(2, 1); - } -} - -} // namespace phi diff --git a/paddle/phi/kernels/logsumexp_grad_kernel.h b/paddle/phi/kernels/logsumexp_grad_kernel.h deleted file mode 100644 index d68c447aa65cb..0000000000000 --- a/paddle/phi/kernels/logsumexp_grad_kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void LogsumexpGradKernel(const Context& ctx, - const DenseTensor& in, - const DenseTensor& out, - const DenseTensor& out_grad, - const std::vector& axis, - bool keepdim, - bool reduce_all, - DenseTensor* in_grad); - -} // namespace phi diff --git a/paddle/phi/kernels/logsumexp_kernel.h b/paddle/phi/kernels/logsumexp_kernel.h deleted file mode 100644 index ba1b18230fa52..0000000000000 --- a/paddle/phi/kernels/logsumexp_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void LogsumexpKernel(const Context& ctx, - const DenseTensor& x, - const std::vector& axis, - bool keepdim, - bool reduce_all, - DenseTensor* out); - -} // namespace phi diff --git a/paddle/phi/ops/compat/logsumexp_sig.cc b/paddle/phi/ops/compat/logsumexp_sig.cc deleted file mode 100644 index ca7345dbe7049..0000000000000 --- a/paddle/phi/ops/compat/logsumexp_sig.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature LogsumexpGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("logsumexp_grad", - {"X", "Out", GradVarName("Out")}, - {"axis", "keepdim", "reduce_all"}, - {GradVarName("X")}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(logsumexp_grad, phi::LogsumexpGradOpArgumentMapping);
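
---
Note on the math this revert moves around: the deleted phi kernels and the restored fluid LogsumexpFunctor compute the same numerically stable form, y = max(x) + log(sum(exp(x - max(x)))), and LogsumexpGradFunctor computes dx = dy * exp(x - y), i.e. dy times the softmax of x. A minimal standalone C++ sketch of that math is below; the names LogSumExp and LogSumExpGrad are illustrative only and are not part of this patch or of Paddle's API.

// Standalone sketch of the math implemented by LogsumexpFunctor /
// LogsumexpGradFunctor above (illustrative names, not Paddle code).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log-sum-exp: shifting by the max keeps exp() from
// overflowing, matching x_max + log(sum(exp(x - x_max))) in the functor.
double LogSumExp(const std::vector<double>& x) {
  double x_max = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - x_max);
  return x_max + std::log(sum);
}

// Gradient of y = logsumexp(x): dx_i = dy * exp(x_i - y), as in
// LogsumexpGradFunctor's dy->broadcast(dim) * (*x - y->broadcast(dim)).exp().
std::vector<double> LogSumExpGrad(const std::vector<double>& x, double y,
                                  double dy) {
  std::vector<double> dx(x.size());
  for (size_t i = 0; i < x.size(); ++i) dx[i] = dy * std::exp(x[i] - y);
  return dx;
}

int main() {
  std::vector<double> x = {1000.0, 1000.0};  // naive exp() would overflow
  double y = LogSumExp(x);                   // 1000 + log(2)
  std::vector<double> dx = LogSumExpGrad(x, y, /*dy=*/1.0);  // {0.5, 0.5}
  std::printf("y = %f, dx = {%f, %f}\n", y, dx[0], dx[1]);
  return 0;
}

The max-shift is why both the kernel and the gradient are safe for large inputs: exp(x - x_max) and exp(x - y) are always evaluated on non-positive (or near-zero) arguments, so the revert preserves the same numerical behavior on both the fluid and phi sides.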