diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 8f7b62a2c9d27..632ea92746210 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1499,6 +1499,12 @@ REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor, REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor); REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor); REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); +REGISTER_ACTIVATION_OP(hard_swish, HardSwish, HardSwishFunctor, + HardSwishGradFunctor); +REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor); +REGISTER_ACTIVATION_OP(round, Round, RoundFunctor, ZeroGradFunctor); +REGISTER_ACTIVATION_OP(floor, Floor, FloorFunctor, ZeroGradFunctor); +REGISTER_ACTIVATION_OP(ceil, Ceil, CeilFunctor, ZeroGradFunctor); /* ========================== sigmoid register ============================= */ @@ -1778,18 +1784,6 @@ REGISTER_OPERATOR( ops::ActFwdInplaceInferer, void>::type); REGISTER_OPERATOR(pow_grad, ops::PowOpGrad, ops::ActivationGradOpInplaceInferer); - -REGISTER_OP_CPU_KERNEL( - pow, ops::PowKernel>, - ops::PowKernel>, - ops::PowKernel>, - ops::PowKernel>); -REGISTER_OP_CPU_KERNEL( - pow_grad, - ops::PowGradKernel>, - ops::PowGradKernel>, - ops::PowGradKernel>, - ops::PowGradKernel>); /* ========================================================================== */ /* ========================== exp register ============================ */ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 7db5675c16b2d..5a72f2086c79a 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -286,10 +286,25 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Log) USE_PHI_FUNCTOR(Log2) USE_PHI_FUNCTOR(Log10) USE_PHI_FUNCTOR(Log1p) +USE_PHI_FUNCTOR(Swish) +USE_PHI_FUNCTOR(HardSwish) +USE_PHI_FUNCTOR(Pow) template using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor; +template +using RoundFunctor = phi::funcs::RoundFunctor; + +template +using FloorFunctor = phi::funcs::FloorFunctor; + +template +using CeilFunctor = phi::funcs::CeilFunctor; + +template +using ZeroGradFunctor = phi::funcs::ZeroGradFunctor; + // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { @@ -391,46 +406,6 @@ struct RsqrtGradFunctor : public BaseActivationFunctor { } }; -// ceil(x) = ceiling(x) -template -struct CeilFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.ceil(); - } -}; - -template -struct ZeroGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = static_cast(0) * out; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kNoDeps; - } -}; - -// floor(x) = flooring(x) -template -struct FloorFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.floor(); - } -}; - -// round(x) = [x] -template -struct RoundFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.round(); - } -}; - // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { @@ -509,51 +484,6 @@ struct Relu6GradFunctor : public BaseActivationFunctor { } }; -// HardSwish = min(max(0, x+3), 6) * x / 6 -template -struct HardSwishFunctor : public 
BaseActivationFunctor { - float threshold; - float scale; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - - template - void operator()(Device d, X x, Out out) const { - out.device(d) = (x + static_cast(offset)) - .cwiseMax(static_cast(0)) - .cwiseMin(static_cast(threshold)) * - x / static_cast(scale); - } -}; - -template -struct HardSwishGradFunctor : public BaseActivationFunctor { - float threshold; - float scale; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto tmp = ((x + static_cast(offset)) < static_cast(threshold)) - .template cast(); - dx.device(d) = - dout * - (((x + static_cast(offset)) > static_cast(0)).template cast() * - (static_cast(2) * x + static_cast(offset)) / - static_cast(scale) * tmp + - static_cast(1) * (static_cast(1) - tmp)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - // For numerical stability, using the following formula instead of softplus(x) = // log(1 + exp(x)) // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta = @@ -776,35 +706,6 @@ struct CELUGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 -template -struct PowFunctor : public BaseActivationFunctor { - float factor; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"factor", &factor}}; - } - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x.pow(static_cast(factor)); - } -}; - -template -struct PowGradFunctor : public BaseActivationFunctor { - float factor; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"factor", &factor}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * static_cast(factor) * - x.pow(static_cast(factor) - static_cast(1)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct LogitFunctor { template @@ -870,39 +771,6 @@ struct STanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct SwishFunctor : public BaseActivationFunctor { - float beta; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - template - void operator()(Device d, X x, Out out) const { - out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); - } -}; - -template -struct SwishGradFunctor : public BaseActivationFunctor { - float beta; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - template - void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { - auto temp1 = static_cast(1) / - (static_cast(1) + (static_cast(-beta) * x).exp()); - auto out = x * temp1; - auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); - dx.device(d) = dout * ((static_cast(beta) * out) + temp2); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct AbsGradGradFunctor : public BaseActivationFunctor { template @@ -1267,110 +1135,6 @@ class RsqrtDoubleGradKernel } }; -template -class PowKernel : public 
framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - - void Compute(const framework::ExecutionContext& context) const override { - const framework::Tensor* X = nullptr; - framework::Tensor* Out = nullptr; - ExtractActivationTensor(context, &X, &Out); - Out->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten( - GET_DATA_SAFELY(X, "Input", "X", "Pow")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Output", "Out", "Pow")); - auto* place = - context.template device_context().eigen_device(); - Functor functor; - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); - } - // get FactorTensor - auto* factor_tensor = context.HasInput("FactorTensor") - ? context.Input("FactorTensor") - : nullptr; - if (factor_tensor) { - auto* factor_data = factor_tensor->data(); - framework::Tensor cpu_factor_tensor; - if (platform::is_gpu_place(factor_tensor->place())) { - framework::TensorCopySync(*factor_tensor, platform::CPUPlace(), - &cpu_factor_tensor); - factor_data = cpu_factor_tensor.data(); - } - auto factor = - std::vector(factor_data, factor_data + factor_tensor->numel()); - PADDLE_ENFORCE_EQ( - factor.size(), 1, - platform::errors::InvalidArgument( - "The shape of factor(tensor) must be [1] rather than %d", - factor.size())); - for (auto& attr : attrs) { - *attr.second = factor[0]; - } - } - functor(*place, x, out); - } -}; - -template -class PowGradKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& context) const override { - const framework::Tensor *X, *Out, *dOut; - framework::Tensor* dX = nullptr; - X = Out = dOut = nullptr; - ExtractActivationGradTensor(context, &X, &Out, &dOut, - &dX); - dX->mutable_data(context.GetPlace()); - auto dout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad")); - auto dx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad")); - auto x = framework::EigenVector::Flatten( - GET_DATA_SAFELY(X, "Input", "X", "PowGrad")); - auto* place = - context.template device_context().eigen_device(); - Functor functor; - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = context.Attr(attr.first); - } - // get FactorTensor - auto* factor_tensor = - context.HasInput("FactorTensor") - ? 
context.Input("FactorTensor") - : nullptr; - if (factor_tensor) { - auto* factor_data = factor_tensor->data(); - framework::Tensor cpu_factor_tensor; - if (platform::is_gpu_place(factor_tensor->place())) { - framework::TensorCopySync(*factor_tensor, platform::CPUPlace(), - &cpu_factor_tensor); - factor_data = cpu_factor_tensor.data(); - } - auto factor = - std::vector(factor_data, factor_data + factor_tensor->numel()); - PADDLE_ENFORCE_EQ( - factor.size(), 1, - platform::errors::InvalidArgument( - "The shape of factor(tensor) must be [1] rather than %d", - factor.size())); - for (auto& attr : attrs) { - *attr.second = factor[0]; - } - } - functor(*place, x, out, dout, dx); - } -}; - template class LogitKernel : public framework::OpKernel { public: @@ -1418,15 +1182,10 @@ class LogitGradKernel : public framework::OpKernel { } // namespace paddle #define FOR_EACH_ACTIVATION_OP(__macro) \ - __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ - __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ - __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ - __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ - __macro(mish, Mish, MishFunctor, MishGradFunctor); \ - __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor); + __macro(mish, Mish, MishFunctor, MishGradFunctor); diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index bb08cee5bcde9..5118302f778d7 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -20,51 +20,6 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -template -struct CudaCeilFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // ceil(x) = ceil(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(ceil(x)); - } -}; - -template -struct CudaFloorFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // floor(x) = floor(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(floor(x)); - } -}; - -template -struct CudaRoundFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // round(x) = round(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(round(x)); - } -}; - -// GradFunctor for ceil, floor and round -template -struct CudaZeroGradFunctor : public BaseActivationFunctor { - __device__ __forceinline__ T operator()(const T x) const { - return static_cast(0.0f); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kNoDeps; - } -}; - template struct CudaReciprocalFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -395,50 +350,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor { } }; -template -struct CudaSwishFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float beta; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - // swish(x) = x / (1 + exp(-beta * x)) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - MPType b = static_cast(beta); - return static_cast(x / (one + exp(-b * x))); - } -}; - -template -struct CudaSwishGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float beta; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"beta", &beta}}; - } - - // dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - MPType b = static_cast(beta); - MPType temp1 = one / (one + exp(-b * x)); - MPType out = x * temp1; - MPType temp2 = b * out; - MPType temp3 = temp1 * (one - temp2); - return static_cast(dout * (temp2 + temp3)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaMishFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -488,58 +399,6 @@ struct CudaMishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaHardSwishFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float threshold; - float scale; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - - // hard_swish(x) = 0, when x <= -offset - // x , when x >= threshold - offset - // x * (x + offset) / scale, otherwise - // threshold = scale = 6, offset = 3 by default - __device__ __forceinline__ T operator()(const T x) const { - T t = static_cast(threshold); - T temp = x + 
static_cast(offset); - T temp_max = temp > zero ? temp : zero; - T temp_min = temp_max < t ? temp_max : t; - return temp_min * x / static_cast(scale); - } -}; - -template -struct CudaHardSwishGradFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - T one = static_cast(1.0f); - T two = static_cast(2.0f); - float threshold; - float scale; - float offset; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; - } - - // dx = 0, when x <= -offset - // dout , when x >= threshold - offset - // dout * (2 * x / scale + offset / scale), otherwise - // threshold = scale = 6, offset = 3 by default - __device__ __forceinline__ T operator()(const T dout, const T x) const { - T o = static_cast(offset); - T s = static_cast(scale); - T temp1 = static_cast(x + o > zero); - T temp2 = static_cast(x + o < static_cast(threshold)); - return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaCELUFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; @@ -684,6 +543,20 @@ USE_PHI_FUNCTOR(CudaLog) USE_PHI_FUNCTOR(CudaLog2) USE_PHI_FUNCTOR(CudaLog10) USE_PHI_FUNCTOR(CudaLog1p) +USE_PHI_FUNCTOR(CudaSwish) +USE_PHI_FUNCTOR(CudaHardSwish) + +template +using CudaRoundFunctor = phi::funcs::CudaRoundFunctor; + +template +using CudaFloorFunctor = phi::funcs::CudaFloorFunctor; + +template +using CudaCeilFunctor = phi::funcs::CudaCeilFunctor; + +template +using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor; template using CudaELUGradNegativeAlphaFunctor = @@ -813,23 +686,6 @@ REGISTER_OP_CUDA_KERNEL( ops::SquareGradGradFunctor>); /* ========================================================================== */ -/* ========================== pow register ============================ */ -REGISTER_OP_CUDA_KERNEL( - pow, ops::PowKernel>, - ops::PowKernel>, - ops::PowKernel>, - ops::PowKernel>, - ops::PowKernel>); -REGISTER_OP_CUDA_KERNEL( - pow_grad, - ops::PowGradKernel>, - ops::PowGradKernel>, - ops::PowGradKernel>, - ops::PowGradKernel>, - ops::PowGradKernel>); -/* ========================================================================== */ - /* ========================== logit register ============================ */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( @@ -889,9 +745,6 @@ REGISTER_OP_CUDA_KERNEL( #define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ CudaSoftShrinkGradFunctor); \ - __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \ - __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \ - __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \ __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \ CudaReciprocalGradFunctor); \ __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \ @@ -903,10 +756,7 @@ REGISTER_OP_CUDA_KERNEL( CudaTanhShrinkGradFunctor); \ __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \ CudaHardShrinkGradFunctor); \ - __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \ - __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \ - __macro(hard_swish, HardSwish, CudaHardSwishFunctor, \ - CudaHardSwishGradFunctor); + __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL) #ifdef PADDLE_WITH_XPU_KP diff --git 
a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 6ad28f348f22f..be6f97ad7c96e 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/infermeta/unary.h" @@ -50,6 +51,11 @@ namespace phi { const DenseTensor& dout, \ DenseTensor* dx); +#define DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(name) \ + template \ + void name##GradKernel( \ + const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx); + #define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(name, attr) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -143,6 +149,22 @@ void LogDoubleGradKernel(const Context& dev_ctx, DenseTensor* dx, DenseTensor* ddout); +template +void HardSwishGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + float threshold, + float scale, + float offset, + DenseTensor* dx); + +template +void PowGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const Scalar& factor, + DenseTensor* dx); + DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos); @@ -166,10 +188,15 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu); DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh); DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid); +DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Round); +DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor); +DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil); + DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, alpha); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold); +DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max); diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 785d1089f06e8..84c46870e0a17 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/infermeta/unary.h" @@ -60,13 +61,32 @@ DECLARE_ACTIVATION_KERNEL(Log) DECLARE_ACTIVATION_KERNEL(Log2) DECLARE_ACTIVATION_KERNEL(Log10) DECLARE_ACTIVATION_KERNEL(Log1p) +DECLARE_ACTIVATION_KERNEL(Round) +DECLARE_ACTIVATION_KERNEL(Floor) +DECLARE_ACTIVATION_KERNEL(Ceil) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) + +template +void HardSwishKernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + float scale, + float offset, + DenseTensor* out); + +template +void PowKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& factor, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index 0776e570e9cd3..be0d02e2a14cf 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -107,6 +107,15 @@ namespace phi { dev_ctx, nullptr, &out, &dout, dx, functor); \ } +#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \ + template \ + void name##GradKernel( \ + const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \ + funcs::functor_class functor; \ + ActivationGradImpl>( \ + dev_ctx, nullptr, nullptr, &dout, dx, functor); \ + } + DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CosGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, TanGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, AcosGradFunctor); @@ -130,6 +139,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, SigmoidGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, ZeroGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, ZeroGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor); + DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, LeakyReluGradFunctor, alpha); @@ -142,6 +155,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, HardShrinkGradFunctor, threshold); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta); DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, BReluGradFunctor, @@ -183,6 +197,23 @@ void EluGradKernel(const Context& dev_ctx, } } +template +void HardSwishGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + float threshold, + float scale, + float offset, + DenseTensor* dx) { + funcs::HardSwishGradFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = threshold; + *(attrs[1].second) = scale; + *(attrs[2].second) = offset; + ActivationGradImpl>( + dev_ctx, &x, nullptr, &dout, dx, functor); +} + } // namespace phi PD_REGISTER_KERNEL( @@ -242,3 +273,17 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) 
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel) + +PD_REGISTER_KERNEL(pow_grad, + CPU, + ALL_LAYOUT, + phi::PowGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index c8709261d2cb0..d55d4cfd0f6bc 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -78,6 +78,9 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor) DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor) DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -86,6 +89,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, @@ -93,6 +97,22 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) +template +void HardSwishKernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + float scale, + float offset, + DenseTensor* out) { + funcs::HardSwishFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = threshold; + *(attrs[1].second) = scale; + *(attrs[2].second) = offset; + ActivationImpl>( + dev_ctx, x, out, functor); +} + } // namespace phi PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} @@ -126,3 +146,10 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) +PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) +PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) +PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) +PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) +PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) +PD_REGISTER_KERNEL( + pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 6e536bd00a4a1..bcadc59126198 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1350,6 +1350,165 @@ struct LogGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +// HardSwish = min(max(0, x+3), 6) * x / 6 +template +struct HardSwishFunctor : public BaseActivationFunctor { + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair 
GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = (x + static_cast(offset)) + .cwiseMax(static_cast(0)) + .cwiseMin(static_cast(threshold)) * + x / static_cast(scale); + } +}; + +template +struct HardSwishGradFunctor : public BaseActivationFunctor { + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto tmp = ((x + static_cast(offset)) < static_cast(threshold)) + .template cast(); + dx.device(d) = + dout * + (((x + static_cast(offset)) > static_cast(0)).template cast() * + (static_cast(2) * x + static_cast(offset)) / + static_cast(scale) * tmp + + static_cast(1) * (static_cast(1) - tmp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct SwishFunctor : public BaseActivationFunctor { + float beta; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); + } +}; + +template +struct SwishGradFunctor : public BaseActivationFunctor { + float beta; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + template + void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { + auto temp1 = static_cast(1) / + (static_cast(1) + (static_cast(-beta) * x).exp()); + auto out = x * temp1; + auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); + dx.device(d) = dout * ((static_cast(beta) * out) + temp2); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 +template +struct PowFunctor : public BaseActivationFunctor { + float factor; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"factor", &factor}}; + } + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.pow(static_cast(factor)); + } +}; + +template +struct PowGradFunctor : public BaseActivationFunctor { + float factor; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"factor", &factor}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(factor) * + x.pow(static_cast(factor) - static_cast(1)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// floor(x) = flooring(x) +template +struct FloorFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.floor(); + } +}; + +// round(x) = [x] +template +struct RoundFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.round(); + } +}; + +// ceil(x) = ceiling(x) +template +struct CeilFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.ceil(); + } +}; + +template +struct ZeroGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = static_cast(0) * out; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return 
ActBwdOpFwdDeps::kNoDeps; + } +}; + #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template struct CudaReluFunctor : public BaseActivationFunctor { @@ -2190,6 +2349,147 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaSwishFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + // swish(x) = x / (1 + exp(-beta * x)) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + return static_cast(x / (one + exp(-b * x))); + } +}; + +template +struct CudaSwishGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + // dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + MPType temp1 = one / (one + exp(-b * x)); + MPType out = x * temp1; + MPType temp2 = b * out; + MPType temp3 = temp1 * (one - temp2); + return static_cast(dout * (temp2 + temp3)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaHardSwishFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + // hard_swish(x) = 0, when x <= -offset + // x , when x >= threshold - offset + // x * (x + offset) / scale, otherwise + // threshold = scale = 6, offset = 3 by default + __device__ __forceinline__ T operator()(const T x) const { + T t = static_cast(threshold); + T temp = x + static_cast(offset); + T temp_max = temp > zero ? temp : zero; + T temp_min = temp_max < t ? 
temp_max : t; + return temp_min * x / static_cast(scale); + } +}; + +template +struct CudaHardSwishGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + T one = static_cast(1.0f); + T two = static_cast(2.0f); + float threshold; + float scale; + float offset; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}}; + } + + // dx = 0, when x <= -offset + // dout , when x >= threshold - offset + // dout * (2 * x / scale + offset / scale), otherwise + // threshold = scale = 6, offset = 3 by default + __device__ __forceinline__ T operator()(const T dout, const T x) const { + T o = static_cast(offset); + T s = static_cast(scale); + T temp1 = static_cast(x + o > zero); + T temp2 = static_cast(x + o < static_cast(threshold)); + return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaCeilFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // ceil(x) = ceil(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(ceil(x)); + } +}; + +template +struct CudaFloorFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // floor(x) = floor(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(floor(x)); + } +}; + +template +struct CudaRoundFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // round(x) = round(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(round(x)); + } +}; + +// GradFunctor for ceil, floor and round +template +struct CudaZeroGradFunctor : public BaseActivationFunctor { + __device__ __forceinline__ T operator()(const T x) const { + return static_cast(0.0f); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kNoDeps; + } +}; + #endif } // namespace funcs diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 3cc41555a898b..3c8b338d86b8c 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -159,10 +159,23 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } +#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \ + template \ + void name##GradKernel( \ + const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \ + funcs::functor_class functor; \ + ActivationGradGPUImpl>( \ + dev_ctx, nullptr, nullptr, &dout, dx, functor); \ + } + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, CudaSigmoidGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, CudaZeroGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, CudaZeroGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, CudaZeroGradFunctor); + DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor); @@ -194,6 +207,9 @@ 
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, CudaHardShrinkGradFunctor, threshold); +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, + CudaSwishGradFunctor, + beta); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, CudaBReluGradFunctor, @@ -227,6 +243,23 @@ void EluGradKernel(const Context& dev_ctx, } } +template +void HardSwishGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + float threshold, + float scale, + float offset, + DenseTensor* dx) { + funcs::CudaHardSwishGradFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = threshold; + *(attrs[1].second) = scale; + *(attrs[2].second) = offset; + ActivationGradGPUImpl>( + dev_ctx, &x, nullptr, &dout, dx, functor); +} + } // namespace phi #ifdef PADDLE_WITH_HIP @@ -315,3 +348,18 @@ PD_REGISTER_KERNEL(log_double_grad, float, double, phi::dtype::float16) {} +PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel) + +PD_REGISTER_KERNEL(pow_grad, + GPU, + ALL_LAYOUT, + phi::PowGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index fb4e2e07b21cb..75003cf342abd 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -97,6 +97,9 @@ DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor) DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor) DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor) +DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -107,6 +110,7 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha) +DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, @@ -114,6 +118,22 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) +template +void HardSwishKernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + float scale, + float offset, + DenseTensor* out) { + funcs::CudaHardSwishFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = threshold; + *(attrs[1].second) = scale; + *(attrs[2].second) = offset; + ActivationGPUImpl>( + dev_ctx, x, out, functor); +} + } // namespace phi #ifdef PADDLE_WITH_HIP @@ -172,3 +192,17 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) +PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) +PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) +PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) +PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) 
+PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) +PD_REGISTER_KERNEL(pow, + GPU, + ALL_LAYOUT, + phi::PowKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h index 7ef8a0887c75c..7924276414e29 100644 --- a/paddle/phi/kernels/impl/activation_grad_impl.h +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -293,4 +293,28 @@ void LogDoubleGradKernel(const Context& dev_ctx, functor(dev_ctx, &x, &ddx, ddout, &dout, dx); } +template +void PowGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const Scalar& factor, + DenseTensor* dx) { + PADDLE_ENFORCE_NOT_NULL( + dx, errors::NotFound("The output DenseTensor dX can not be nullptr")); + if (dx) { + dev_ctx.template Alloc(dx); + } + auto dout_flatten = EigenVector::Flatten( + GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad")); + auto dx_flatten = EigenVector::Flatten( + GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad")); + auto x_flatten = + EigenVector::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad")); + auto* place = dev_ctx.eigen_device(); + phi::funcs::PowGradFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = factor.to(); + functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten); +} + } // namespace phi diff --git a/paddle/phi/kernels/impl/activation_impl.h b/paddle/phi/kernels/impl/activation_impl.h index ca3debd394a1e..c2d160caf7b1d 100644 --- a/paddle/phi/kernels/impl/activation_impl.h +++ b/paddle/phi/kernels/impl/activation_impl.h @@ -47,4 +47,23 @@ void ActivationImpl(const Context& dev_ctx, } } +template +void PowKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& factor, + DenseTensor* out) { + PADDLE_ENFORCE_NOT_NULL(out, + errors::NotFound("Output Out should not be nullptr")); + dev_ctx.template Alloc(out); + auto x_flatten = phi::EigenVector::Flatten( + GET_DATA_SAFELY(&x, "Input", "X", "Activation")); + auto out_flatten = phi::EigenVector::Flatten( + GET_DATA_SAFELY(out, "Output", "Out", "Activation")); + auto* place = dev_ctx.eigen_device(); + phi::funcs::PowFunctor functor; + auto attrs = functor.GetAttrs(); + *(attrs[0].second) = factor.to(); + functor(*place, x_flatten, out_flatten); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index 8b4884e35b608..7919769ec85a9 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ b/paddle/phi/ops/compat/activation_sig.cc @@ -34,6 +34,13 @@ namespace phi { {GradVarName("X")}); \ } +#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \ + KernelSignature func_name##GradOpArgumentMapping( \ + const ArgumentMappingContext& ctx) { \ + return KernelSignature( \ + op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \ + } + #define comma , DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", ); // NOLINT @@ -61,6 +68,11 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", ); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish, + "hard_swish", + "threshold" comma "scale" comma + "offset"); // NOLINT +DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta"); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT @@ -69,6 +81,10 @@ 
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(HardSigmoid,
                                  "hard_sigmoid",
                                  "slope" comma "offset");  // NOLINT
 
+DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Round, "round", );  // NOLINT
+DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Floor, "floor", );  // NOLINT
+DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Ceil, "ceil", );  // NOLINT
+
 KernelSignature ReluDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"});
@@ -135,6 +151,26 @@ KernelSignature LogDoubleGradOpArgumentMapping(
       "log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
 }
 
+KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("FactorTensor")) {
+    return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
+  } else {
+    return KernelSignature("pow", {"X"}, {"factor"}, {"Out"});
+  }
+}
+
+KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("FactorTensor")) {
+    return KernelSignature("pow_grad",
+                           {"X", GradVarName("Out")},
+                           {"FactorTensor"},
+                           {GradVarName("X")});
+  } else {
+    return KernelSignature(
+        "pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")});
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
@@ -197,3 +233,11 @@ PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(log2_grad, phi::Log2GradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(log10_grad, phi::Log10GradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(log1p_grad, phi::Log1pGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
+                           phi::HardSwishGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(round_grad, phi::RoundGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(ceil_grad, phi::CeilGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
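
Aside (not part of the patch): PowOpArgumentMapping above routes either the "factor" attribute or the optional "FactorTensor" input into the new phi PowKernel, which takes a Scalar. That is what lets the old fluid PowKernel's manual FactorTensor handling (GPU-to-CPU copy plus the shape-[1] check) be deleted. The sketch below illustrates the pattern only; ToyScalar and ToyPow are hypothetical names, not the real phi::Scalar API.

// Toy model of a scalar-typed kernel argument: the kernel body is identical
// whether the factor came from an attribute or from a 1-element tensor.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

class ToyScalar {
 public:
  explicit ToyScalar(float v) : value_(v) {}            // from an attribute
  explicit ToyScalar(const std::vector<float>& t) {     // from a factor tensor
    assert(t.size() == 1u && "factor tensor must have shape [1]");
    value_ = t[0];
  }
  template <typename T>
  T to() const { return static_cast<T>(value_); }

 private:
  float value_;
};

// out[i] = x[i]^factor; the kernel never asks where the factor came from.
void ToyPow(const std::vector<float>& x, const ToyScalar& factor,
            std::vector<float>* out) {
  out->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    (*out)[i] = std::pow(x[i], factor.to<float>());
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f}, out;
  ToyPow(x, ToyScalar(2.0f), &out);                      // attribute path
  ToyPow(x, ToyScalar(std::vector<float>{3.f}), &out);   // tensor path
  std::printf("2^3 via tensor factor = %.1f\n", out[1]); // prints 8.0
  return 0;
}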
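
Aside (not part of the patch): the hard_swish functors moved into phi keep the piecewise form documented in the CudaHardSwishFunctor/CudaHardSwishGradFunctor comments (threshold = scale = 6, offset = 3 by default). A minimal standalone check of that closed-form gradient against a central difference; the helper names are illustrative only.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

// hard_swish(x) = min(max(0, x + offset), threshold) * x / scale
float HardSwishRef(float x, float threshold = 6.f, float scale = 6.f,
                   float offset = 3.f) {
  return std::min(std::max(x + offset, 0.f), threshold) * x / scale;
}

// d hard_swish / dx, matching the piecewise form in the grad functor:
//   0                         when x + offset <= 0
//   1                         when x + offset >= threshold
//   (2 * x + offset) / scale  otherwise
float HardSwishGradRef(float x, float threshold = 6.f, float scale = 6.f,
                       float offset = 3.f) {
  if (x + offset <= 0.f) return 0.f;
  if (x + offset >= threshold) return 1.f;
  return (2.f * x + offset) / scale;
}

int main() {
  // Compare the analytic gradient against a central finite difference,
  // staying away from the two kink points x = -offset and x = threshold - offset.
  const float eps = 1e-3f;
  for (float x : {-4.f, -1.f, 0.5f, 2.f, 5.f}) {
    float numeric = (HardSwishRef(x + eps) - HardSwishRef(x - eps)) / (2 * eps);
    float analytic = HardSwishGradRef(x);
    std::printf("x=%5.2f  analytic=%8.5f  numeric=%8.5f\n", x, analytic, numeric);
    assert(std::fabs(numeric - analytic) < 1e-2f);
  }
  return 0;
}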
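
Aside (not part of the patch): the swish gradient in SwishGradFunctor/CudaSwishGradFunctor is computed as temp2 + temp3, i.e. beta * swish(x) + sigmoid(beta * x) * (1 - beta * swish(x)), recomputing the forward value from x (FwdDeps = kDepX). A small standalone verification of that algebra against a numeric derivative, with illustrative names.

#include <cassert>
#include <cmath>
#include <cstdio>

double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// swish(x) = x * sigmoid(beta * x)
double Swish(double x, double beta) { return x * Sigmoid(beta * x); }

// The closed form the grad functor evaluates (temp2 + temp3 in the diff).
double SwishGrad(double x, double beta) {
  double s = Sigmoid(beta * x);   // temp1
  double out = x * s;             // forward value, recomputed from x
  double temp2 = beta * out;
  double temp3 = s * (1.0 - temp2);
  return temp2 + temp3;
}

int main() {
  const double beta = 1.0, eps = 1e-5;
  for (double x : {-3.0, -0.5, 0.0, 0.7, 2.5}) {
    double numeric = (Swish(x + eps, beta) - Swish(x - eps, beta)) / (2 * eps);
    double analytic = SwishGrad(x, beta);
    std::printf("x=%5.2f analytic=%.6f numeric=%.6f\n", x, analytic, numeric);
    assert(std::fabs(numeric - analytic) < 1e-6);
  }
  return 0;
}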