diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 4205f2253a652..c835cf8ea1480 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1485,6 +1485,13 @@ REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor); REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor); REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, ThresholdedReluGradFunctor); +REGISTER_ACTIVATION_OP(hard_shrink, HardShrink, HardShrinkFunctor, + HardShrinkGradFunctor); +REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor, + SoftShrinkGradFunctor); +REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor, + TanhShrinkGradFunctor); +REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor); /* ========================== sigmoid register ============================= */ @@ -1626,22 +1633,6 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_OP_CPU_KERNEL(elu, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - elu_grad, ops::ELUGradKernel, - ops::ELUGradKernel); -REGISTER_OP_CPU_KERNEL( - elu_grad_grad, ops::ELUDoubleGradKernel>, - ops::ELUDoubleGradKernel>, - ops::ELUDoubleGradKernel>); - /* ========================================================================== */ /* ======================== logit register ============================ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index b076db01c22c6..4f197b95b2174 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -279,6 +279,15 @@ USE_PHI_FUNCTOR(BRelu) USE_PHI_FUNCTOR(ThresholdedRelu) USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) +USE_PHI_FUNCTOR(HardShrink) +USE_PHI_FUNCTOR(SoftShrink) +USE_PHI_FUNCTOR(TanhShrink) +USE_PHI_FUNCTOR(Silu) +USE_PHI_FUNCTOR(ELU) +USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU) + +template +using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor; template struct SigmoidGradFunctor : public BaseActivationFunctor { @@ -392,31 +401,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor { } }; -// silu(x) = x / (1 + exp(-x)) -template -struct SiluFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out) const { - auto temp = static_cast(1) / (static_cast(1) + (-x).exp()); - out.device(d) = x * temp; - } -}; - -// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})) -template -struct SiluGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp1 = static_cast(1) + (-x).exp(); // 1+e^(-x) - auto temp2 = x * (-x).exp(); // x*e^(-x) - dx.device(d) = dout * ((static_cast(1) / temp1) * - (static_cast(1) + (temp2 / temp1))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - // Originally: logsigmoid(x) = -log (1 + exp(-x)) // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ @@ -512,99 +496,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor; template using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor; -// tanhshrink(x) = x - tanh(x) -// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) -template -struct TanhShrinkFunctor : public BaseActivationFunctor { - template - void operator()(Device d, 
X x, Out out) const { - out.device(d) = x - x.tanh(); - } -}; - -template -struct TanhShrinkGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - dx.device(d) = dout * (x.tanh() * x.tanh()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// tanhshrink(x) = x - tanh(x) -// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) -template -struct HardShrinkFunctor : public BaseActivationFunctor { - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - template - void operator()(Device d, X x, Out out) const { - auto temp1 = x < static_cast(threshold * -1.f); - auto temp2 = x > static_cast(threshold); - out.device(d) = x * (temp1 || temp2).template cast(); - } -}; - -template -struct HardShrinkGradFunctor : public BaseActivationFunctor { - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp1 = x < static_cast(threshold * -1.f); - auto temp2 = x > static_cast(threshold); - dx.device(d) = dout * (temp1 || temp2).template cast(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 -// otherwise -template -struct SoftShrinkFunctor : public BaseActivationFunctor { - float lambda; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"lambda", &lambda}}; - } - - template - void operator()(Device d, X x, Out out) const { - auto lambdaT = static_cast(lambda); - auto temp1 = (x > lambdaT).template cast(); - auto temp2 = (x < -lambdaT).template cast(); - out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); - } -}; - -template -struct SoftShrinkGradFunctor : public BaseActivationFunctor { - float lambda; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"lambda", &lambda}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto lambdaT = static_cast(lambda); - auto temp1 = (x > lambdaT).template cast(); - auto temp2 = (x < -lambdaT).template cast(); - dx.device(d) = dout * (temp1 + temp2).template cast(); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { @@ -1036,59 +927,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { } }; -template -struct ELUFunctor : public BaseActivationFunctor { - float alpha; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - template - void operator()(Device d, X x, Out out) const { - out.device(d) = - (x < static_cast(0)) - .select(static_cast(alpha) * (x.exp() - static_cast(1)), x); - } -}; - -template -struct ELUGradFunctor : public BaseActivationFunctor { - float alpha; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - // case 1: alpha >= 0 - // dx = dout, if out > 0 - // dx = dout * (out + alpha), if out <= 0 - dx.device(d) = (out > static_cast(0)) - .select(dout, dout * (out + static_cast(alpha))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct ELUGradNegativeAlphaFunctor : public 
BaseActivationFunctor { - float alpha; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - // case 2: alpha < 0 - // dx = dout, if x > 0 - // dx = dout * (out + alpha), if x <=0 - dx.device(d) = (x > static_cast(0)) - .select(dout, dout * static_cast(alpha) * x.exp()); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template class ELUGradKernel : public framework::OpKernel { public: @@ -1354,44 +1192,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct ELUGradGradFunctor : public BaseActivationFunctor { - float alpha; - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - template - void operator()(const Device& dev, const framework::Tensor* X, - const framework::Tensor* ddX, framework::Tensor* ddOut, - const framework::Tensor* dOut, framework::Tensor* dX) const { - auto* d = dev.eigen_device(); - auto ddx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad")); - auto x = framework::EigenVector::Flatten( - GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad")); - - if (dX) { - auto dx = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad")); - auto dout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad")); - dx.device(*d) = ddx * dout * static_cast(alpha) * x.exp() * - (x <= static_cast(0)).template cast(); - } - - if (ddOut) { - auto ddout = framework::EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad")); - ddout.device(*d) = ddx * - ((x > static_cast(0)).template cast() + - static_cast(alpha) * x.exp() * - (x <= static_cast(0)).template cast()) - .template cast(); - } - } - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CELUGradGradFunctor : public BaseActivationFunctor { float alpha; @@ -2151,26 +1951,22 @@ struct LogGradGradFunctor : public BaseActivationFunctor { } // namespace operators } // namespace paddle -#define FOR_EACH_ACTIVATION_OP(__macro) \ - __macro(silu, Silu, SiluFunctor, SiluGradFunctor); \ - __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ - __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ - __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ - __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ - __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ - __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ - __macro(log2, Log2, Log2Functor, Log2GradFunctor); \ - __macro(log10, Log10, Log10Functor, Log10GradFunctor); \ - __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ - __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ - __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ - __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ - __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \ - __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ - HardSigmoidGradFunctor); \ - __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ 
- __macro(mish, Mish, MishFunctor, MishGradFunctor); \ +#define FOR_EACH_ACTIVATION_OP(__macro) \ + __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ + __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ + __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ + __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ + __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \ + __macro(log2, Log2, Log2Functor, Log2GradFunctor); \ + __macro(log10, Log10, Log10Functor, Log10GradFunctor); \ + __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ + __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ + __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ + __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ + __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ + HardSigmoidGradFunctor); \ + __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ + __macro(mish, Mish, MishFunctor, MishGradFunctor); \ __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor); diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index 256f20db08445..22613cbe2a2b2 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -44,35 +44,6 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor { } }; -template -struct CudaSiluFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // silu(x) = x / (1 + exp(-x)) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(x / (one + exp(-x))); - } -}; - -template -struct CudaSiluGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - - // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - MPType temp = one / (one + exp(-x)); - return static_cast(dout * (temp * (one + x * (one - temp)))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaLogSigmoidFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -110,43 +81,6 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaSoftShrinkFunctor : public BaseActivationFunctor { - float lambda; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"lambda", &lambda}}; - } - - // softshrink(x) = x - lambda, if x > lambda; - // x + lambda, if x < -lambda; - // 0, otherwise. 
- __device__ __forceinline__ T operator()(const T x) const { - T l = static_cast(lambda); - T temp1 = static_cast(x > l); - T temp2 = static_cast(x < -l); - return temp1 * (x - l) + temp2 * (x + l); - } -}; - -template -struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float lambda; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"lambda", &lambda}}; - } - - // dx = dout, if x > lambda or x < -lambda else 0 - __device__ __forceinline__ T operator()(const T dout, const T x) const { - T l = static_cast(lambda); - return (x >= -l && x <= l) ? zero : dout; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaCeilFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -615,66 +549,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor { } }; -template -struct CudaTanhShrinkFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // tanhshrink(x) = x - tanh(x) - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - return static_cast(x - tanh(x)); - } -}; - -template -struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - - // dx = dout * tanh(x)^2 - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType x = static_cast(arg_x); - return static_cast(dout * tanh(x) * tanh(x)); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -struct CudaHardShrinkFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x - __device__ __forceinline__ T operator()(const T x) const { - T t = static_cast(threshold); - return (x > -t && x < t) ? zero : x; - } -}; - -template -struct CudaHardShrinkGradFunctor : public BaseActivationFunctor { - T zero = static_cast(0.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // dx = (x > -threshold && x < threshold) ? 0 : dout - __device__ __forceinline__ T operator()(const T dout, const T x) const { - T t = static_cast(threshold); - return (x > -t && x < t) ? zero : dout; - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - template struct CudaHardSigmoidFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); @@ -863,110 +737,6 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; -template -struct CudaELUFunctor : public BaseActivationFunctor { - using CT = typename details::MPTypeTrait::Type; - CT zero = static_cast(0.0f); - CT one = static_cast(1.0f); - float alpha; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - // elu(x) = x, if x > 0 - // elu(x) = alpha * (e^x - 1), if x <= 0 - __device__ __forceinline__ T operator()(const T arg_x) const { - CT x = static_cast(arg_x); - CT temp = static_cast(alpha) * (exp(x) - one); - CT res = x > zero ? 
x : temp; - return static_cast(res); - } -}; - -template -struct CudaELUGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType zero = static_cast(0.0f); - float alpha; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - // case 1: alpha >= 0 - // dx = dout, if out > 0 - // dx = dout * (out + alpha), if out <= 0 - __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const { - MPType dout = static_cast(arg_dout); - MPType out = static_cast(arg_out); - MPType a = static_cast(alpha); - MPType out_pos = static_cast(out > zero); - MPType out_neg = static_cast(out <= zero); - return static_cast(dout * (out_pos + out_neg * (out + a))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType zero = static_cast(0.0f); - float alpha; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"alpha", &alpha}}; - } - - // case 2: alpha < 0 - // dx = dout, if x > 0 - // dx = dout * (out + alpha), if x <=0 - __device__ __forceinline__ T operator()(const T arg_dout, const T arg_out, - const T arg_x) const { - MPType dout = static_cast(arg_dout); - MPType out = static_cast(arg_out); - MPType x = static_cast(arg_x); - MPType a = static_cast(alpha); - MPType x_pos = static_cast(x > zero); - MPType x_neg = static_cast(x <= zero); - return static_cast(dout * (x_pos + x_neg * (out + a))); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } -}; - -template -class ELUGradCudaKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* d_out = ctx.Input(framework::GradVarName("Out")); - auto* out = ctx.Input("Out"); - auto* x = ctx.Input("X"); - auto* d_x = ctx.Output(framework::GradVarName("X")); - d_x->mutable_data(ctx.GetPlace()); - const float alpha = ctx.Attr("alpha"); - - auto& dev_ctx = ctx.device_context(); - std::vector ins = {d_out, out}; - std::vector outs = {d_x}; - if (alpha > 0) { - CudaELUGradFunctor functor; - functor.alpha = alpha; - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } else { - CudaELUGradNegativeAlphaFunctor functor; - functor.alpha = alpha; - ins.push_back(x); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } - } -}; - template struct CudaCELUFunctor : public BaseActivationFunctor { using CT = typename details::MPTypeTrait::Type; @@ -1099,6 +869,15 @@ USE_PHI_FUNCTOR(CudaTanh) USE_PHI_FUNCTOR(CudaBRelu) USE_PHI_FUNCTOR(CudaLeakyRelu) USE_PHI_FUNCTOR(CudaThresholdedRelu) +USE_PHI_FUNCTOR(CudaHardShrink) +USE_PHI_FUNCTOR(CudaSoftShrink) +USE_PHI_FUNCTOR(CudaTanhShrink) +USE_PHI_FUNCTOR(CudaSilu) +USE_PHI_FUNCTOR(CudaELU) + +template +using CudaELUGradNegativeAlphaFunctor = + phi::funcs::CudaELUGradNegativeAlphaFunctor; } // namespace operators } // namespace paddle @@ -1158,26 +937,6 @@ namespace plat = paddle::platform; ops::ActivationGradCudaKernel>); -/* ======================== elu register ============================ */ -REGISTER_OP_CUDA_KERNEL( - elu, ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>, - ops::ActivationCudaKernel>); -REGISTER_OP_CUDA_KERNEL( - elu_grad, ops::ELUGradCudaKernel, - ops::ELUGradCudaKernel, - ops::ELUGradCudaKernel); - -REGISTER_OP_CUDA_KERNEL( - elu_grad_grad, 
ops::ELUDoubleGradKernel>, - ops::ELUDoubleGradKernel>, - ops::ELUDoubleGradKernel>); /* ========================================================================== */ /* ======================== celu register ============================ */ @@ -1359,7 +1118,6 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ #define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \ - __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \ __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \ CudaLogSigmoidGradFunctor); \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index a5b737b28c23b..e0dfca756e147 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -26,6 +26,23 @@ namespace phi { const DenseTensor& dout, \ DenseTensor* dx); +#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(name, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx); + +#define DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(name, attr1, attr2) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr1, \ + float attr2, \ + DenseTensor* dx); + #define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -33,6 +50,14 @@ namespace phi { const DenseTensor& dout, \ DenseTensor* dx); +#define DECLARE_ACTIVATION_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut(name, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx); + template void ReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& out, @@ -58,21 +83,6 @@ void TanhTripleGradKernel(const Context& dev_ctx, DenseTensor* d_dout, DenseTensor* d_ddx); -template -void BReluGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& dout, - float t_min, - float t_max, - DenseTensor* dx); - -template -void LeakyReluGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& dout, - float alpha, - DenseTensor* dx); - template void LeakyReluDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -81,11 +91,21 @@ void LeakyReluDoubleGradKernel(const Context& dev_ctx, DenseTensor* ddout); template -void ThresholdedReluGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& dout, - float threshold, - DenseTensor* dx); +void EluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + float alpha, + DenseTensor* dx); + +template +void EluDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const DenseTensor& ddx, + float alpha, + DenseTensor* dx, + DenseTensor* ddout); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan); @@ -98,7 +118,17 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(TanhShrink); +DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Silu); + DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu); DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh); 
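// [Illustrative sketch, not part of the patch] As a reading aid: the
// one-attribute DepX declaration macro defined above turns an entry such as
// DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(HardShrink, threshold) (added
// just below) into a plain kernel declaration roughly like the following.
// The <typename T, typename Context> parameter list is assumed from the usual
// phi convention, since the angle brackets were stripped when this diff was
// extracted.
template <typename T, typename Context>
void HardShrinkGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
                          float threshold,
                          DenseTensor* dx);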
+DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, alpha) + DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu, threshold) + DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(SoftShrink, lambda) + DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(HardShrink, threshold) + + DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, t_min, t_max) + } // namespace phi diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index 885dccad8e377..0762ce43ff8f0 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -24,6 +24,21 @@ namespace phi { void name##Kernel( \ const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +#define DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(name, attr) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr, \ + DenseTensor* out); + +#define DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(name, attr1, attr2) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr1, \ + float attr2, \ + DenseTensor* out); + DECLARE_ACTIVATION_KERNEL(Cos) DECLARE_ACTIVATION_KERNEL(Tan) DECLARE_ACTIVATION_KERNEL(Acos) @@ -37,24 +52,15 @@ DECLARE_ACTIVATION_KERNEL(Acosh) DECLARE_ACTIVATION_KERNEL(Atanh) DECLARE_ACTIVATION_KERNEL(Relu) DECLARE_ACTIVATION_KERNEL(Tanh) +DECLARE_ACTIVATION_KERNEL(TanhShrink) +DECLARE_ACTIVATION_KERNEL(Silu) + +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) +DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha) -template -void BReluKernel(const Context& dev_ctx, - const DenseTensor& x, - float t_min, - float t_max, - DenseTensor* out); - -template -void LeakyReluKernel(const Context& dev_ctx, - const DenseTensor& x, - float alpha, - DenseTensor* out); - -template -void ThresholdedReluKernel(const Context& dev_ctx, - const DenseTensor& x, - float threshold, - DenseTensor* out); +DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max) } // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index f9af50f6832a1..11b396a84d0de 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -21,101 +21,140 @@ limitations under the License. 
*/ namespace phi { -#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ +#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ const DenseTensor& x, \ const DenseTensor& dout, \ DenseTensor* dx) { \ - functor_class functor; \ - ActivationGradImpl>( \ + funcs::functor_class functor; \ + ActivationGradImpl>( \ dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \ - name, functor_class, attr) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& dout, \ - float attr, \ - DenseTensor* dx) { \ - functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr; \ - ActivationGradImpl>( \ - dev_ctx, &x, nullptr, &dout, dx, functor); \ +#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \ - name, functor_class, attr1, attr2) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& dout, \ - float attr1, \ - float attr2, \ - DenseTensor* dx) { \ - functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr1; \ - *(attrs[1].second) = attr2; \ - ActivationGradImpl>( \ - dev_ctx, &x, nullptr, &dout, dx, functor); \ +#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \ + name, functor_class, attr1, attr2) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr1, \ + float attr2, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationGradImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ +#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ const DenseTensor& out, \ const DenseTensor& dout, \ DenseTensor* dx) { \ - functor_class functor; \ - ActivationGradImpl>( \ + funcs::functor_class functor; \ + ActivationGradImpl>( \ dev_ctx, nullptr, &out, &dout, dx, functor); \ } -#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \ - name, functor_class, attr) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& out, \ - const DenseTensor& dout, \ - float attr, \ - DenseTensor* dx) { \ - functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr; \ - ActivationGradImpl>( \ - dev_ctx, nullptr, &out, &dout, dx, functor); \ +#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ } 
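// [Illustrative sketch, not part of the patch] For one concrete op, the DEPX
// one-attribute macro above is expected to expand
// DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
//                                                SoftShrinkGradFunctor, lambda)
// into roughly the following; the template argument lists are reconstructed by
// hand because the extraction dropped the angle brackets.
template <typename T, typename Context>
void SoftShrinkGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& dout,
                          float lambda,
                          DenseTensor* dx) {
  funcs::SoftShrinkGradFunctor<T> functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = lambda;  // forward the op attribute into the functor
  ActivationGradImpl<T, Context, funcs::SoftShrinkGradFunctor<T>>(
      dev_ctx, &x, nullptr, &dout, dx, functor);
}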
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::TanGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor); - -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor); -DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, funcs::TanhGradFunctor); - -DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, - funcs::LeakyReluGradFunctor, +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CosGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, TanGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, AcosGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, SinGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, AsinGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, AtanGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, SinhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CoshGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, AsinhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, AcoshGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor); + +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor); +DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor); + +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, + LeakyReluGradFunctor, alpha); -DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( - ThresholdedRelu, funcs::ThresholdedReluGradFunctor, threshold); - -DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, - funcs::BReluGradFunctor, +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, + ThresholdedReluGradFunctor, + threshold); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, + SoftShrinkGradFunctor, + lambda); +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, + HardShrinkGradFunctor, + threshold); + +DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, + BReluGradFunctor, t_min, t_max); +template +void EluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + float alpha, + DenseTensor* dx) { + dev_ctx.template Alloc(dx); + + auto x_flatten = + EigenVector::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "elu_grad")); + auto out_flatten = EigenVector::Flatten( + GET_DATA_SAFELY(&out, "Input", "Out", "elu_grad")); + auto dout_flatten = EigenVector::Flatten( + GET_DATA_SAFELY(&dout, "Input", "dOut", "elu_grad")); + auto dx_flatten = + EigenVector::Flatten(GET_DATA_SAFELY(dx, "Output", "dX", "elu_grad")); + auto* place = dev_ctx.eigen_device(); + + if (alpha > 0) { + funcs::ELUGradFunctor functor; + functor.alpha = alpha; + functor(*place, x_flatten, out_flatten, dout_flatten, dx_flatten); + } else { + funcs::ELUGradNegativeAlphaFunctor functor; + functor.alpha = alpha; + functor(*place, 
x_flatten, out_flatten, dout_flatten, dx_flatten); + } +} + } // namespace phi PD_REGISTER_KERNEL( @@ -144,6 +183,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, ThresholdedReluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad, ReluDoubleGradKernel) @@ -151,6 +195,7 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad, LeakyReluDoubleGradKernel) +PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel) PD_REGISTER_KERNEL(tanh_triple_grad, CPU, diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 0d13429c8f651..59ce18a11cc5e 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -19,78 +19,93 @@ limitations under the License. */ namespace phi { -#define DEFINE_CPU_ACTIVATION_KERNEL(name, functor_class) \ - template \ - void name##Kernel( \ - const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ - functor_class functor; \ - ActivationImpl(dev_ctx, x, out, functor); \ +#define DEFINE_CPU_ACTIVATION_KERNEL(name, functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + funcs::functor_class functor; \ + ActivationImpl>( \ + dev_ctx, x, out, functor); \ } -#define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ - template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - float attr, \ - DenseTensor* out) { \ - functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr; \ - ActivationImpl>(dev_ctx, x, out, functor); \ +#define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr, \ + DenseTensor* out) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationImpl>( \ + dev_ctx, x, out, functor); \ } -#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \ - name, functor_class, attr1, attr2) \ - template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - float attr1, \ - float attr2, \ - DenseTensor* out) { \ - functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr1; \ - *(attrs[1].second) = attr2; \ - ActivationImpl>(dev_ctx, x, out, functor); \ +#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \ + name, functor_class, attr1, attr2) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr1, \ + float attr2, \ + DenseTensor* out) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationImpl>( \ + dev_ctx, x, out, functor); \ } -DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Cos, funcs::CosFunctor) 
-DEFINE_CPU_ACTIVATION_KERNEL(Tan, funcs::TanFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Asin, funcs::AsinFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Atan, funcs::AtanFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Acos, funcs::AcosFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Sinh, funcs::SinhFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Cosh, funcs::CoshFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Asinh, funcs::AsinhFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Tanh, funcs::TanhFunctor) -DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha) +DEFINE_CPU_ACTIVATION_KERNEL(Sin, SinFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Cos, CosFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Tan, TanFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Asin, AsinFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Atan, AtanFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Acos, AcosFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Sinh, SinhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Cosh, CoshFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Asinh, AsinhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Acosh, AcoshFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Atanh, AtanhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor) + +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, - funcs::ThresholdedReluFunctor, + ThresholdedReluFunctor, threshold) -DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) +DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha) + +DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max) } // namespace phi PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ - PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func##Kernel, float, double) {} + PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {} -PD_REGISTER_ACTIVATION_KERNEL(sin, Sin) -PD_REGISTER_ACTIVATION_KERNEL(cos, Cos) -PD_REGISTER_ACTIVATION_KERNEL(tan, Tan) -PD_REGISTER_ACTIVATION_KERNEL(acos, Acos) -PD_REGISTER_ACTIVATION_KERNEL(asin, Asin) -PD_REGISTER_ACTIVATION_KERNEL(atan, Atan) -PD_REGISTER_ACTIVATION_KERNEL(sinh, Sinh) -PD_REGISTER_ACTIVATION_KERNEL(cosh, Cosh) -PD_REGISTER_ACTIVATION_KERNEL(asinh, Asinh) -PD_REGISTER_ACTIVATION_KERNEL(acosh, Acosh) -PD_REGISTER_ACTIVATION_KERNEL(atanh, Atanh) -PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh) -PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu) -PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu) -PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu) +PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) +PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) +PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel) +PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel) +PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel) +PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel) +PD_REGISTER_ACTIVATION_KERNEL(sinh, SinhKernel) +PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel) +PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel) +PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel) +PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) 
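// [Illustrative sketch, not part of the patch] With the registration macro now
// taking the full kernel symbol (phi::func rather than phi::func##Kernel), an
// entry such as PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel) should expand
// to approximately:
//   PD_REGISTER_KERNEL(silu, CPU, ALL_LAYOUT, phi::SiluKernel, float, double) {}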
+PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) +PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) +PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel) +PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel) +PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel) +PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) +PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index c8fb54bb102d3..9c37427d87751 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -29,11 +29,13 @@ #include #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" namespace phi { namespace funcs { @@ -776,6 +778,236 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +// tanhshrink(x) = x - tanh(x) +// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct TanhShrinkFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x - x.tanh(); + } +}; + +template +struct TanhShrinkGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (x.tanh() * x.tanh()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// tanhshrink(x) = x - tanh(x) +// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct HardShrinkFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Out out) const { + auto temp1 = x < static_cast(threshold * -1.f); + auto temp2 = x > static_cast(threshold); + out.device(d) = x * (temp1 || temp2).template cast(); + } +}; + +template +struct HardShrinkGradFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto temp1 = x < static_cast(threshold * -1.f); + auto temp2 = x > static_cast(threshold); + dx.device(d) = dout * (temp1 || temp2).template cast(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 +// otherwise +template +struct SoftShrinkFunctor : public BaseActivationFunctor { + float lambda; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + template + void operator()(Device d, X x, Out out) const { + auto lambdaT = static_cast(lambda); + auto temp1 = (x > lambdaT).template cast(); + auto temp2 = (x < -lambdaT).template cast(); + out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); + } +}; + +template +struct SoftShrinkGradFunctor : public BaseActivationFunctor { + float lambda; 
+ typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + auto lambdaT = static_cast(lambda); + auto temp1 = (x > lambdaT).template cast(); + auto temp2 = (x < -lambdaT).template cast(); + dx.device(d) = dout * (temp1 + temp2).template cast(); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct ELUFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + template + void operator()(Device d, X x, Out out) const { + out.device(d) = + (x < static_cast(0)) + .select(static_cast(alpha) * (x.exp() - static_cast(1)), x); + } +}; + +template +struct ELUGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + // case 1: alpha >= 0 + // dx = dout, if out > 0 + // dx = dout * (out + alpha), if out <= 0 + dx.device(d) = (out > static_cast(0)) + .select(dout, dout * (out + static_cast(alpha))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + // case 2: alpha < 0 + // dx = dout, if x > 0 + // dx = dout * (out + alpha), if x <=0 + dx.device(d) = (x > static_cast(0)) + .select(dout, dout * static_cast(alpha) * x.exp()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct ELUGradGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(const Device& dev, + const DenseTensor* X, + const DenseTensor* ddX, + DenseTensor* ddOut, + const DenseTensor* dOut, + DenseTensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad")); + auto x = EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad")); + + if (dX) { + auto dx = EigenVector::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad")); + auto dout = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad")); + dx.device(*d) = ddx * dout * static_cast(alpha) * x.exp() * + (x <= static_cast(0)).template cast(); + } + + if (ddOut) { + auto ddout = EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad")); + ddout.device(*d) = ddx * + ((x > static_cast(0)).template cast() + + static_cast(alpha) * x.exp() * + (x <= static_cast(0)).template cast()) + .template cast(); + } + } + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +// silu(x) = x / (1 + exp(-x)) +template +struct SiluFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + auto temp = static_cast(1) / (static_cast(1) + (-x).exp()); + out.device(d) = x * temp; + } +}; + +// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})) +template +struct SiluGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + 
auto temp1 = static_cast(1) + (-x).exp(); // 1+e^(-x) + auto temp2 = x * (-x).exp(); // x*e^(-x) + dx.device(d) = dout * ((static_cast(1) / temp1) * + (static_cast(1) + (temp2 / temp1))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) template struct CudaReluFunctor : public BaseActivationFunctor { @@ -1214,6 +1446,209 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; + +template +struct CudaSoftShrinkFunctor : public BaseActivationFunctor { + float lambda; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + // softshrink(x) = x - lambda, if x > lambda; + // x + lambda, if x < -lambda; + // 0, otherwise. + __device__ __forceinline__ T operator()(const T x) const { + T l = static_cast(lambda); + T temp1 = static_cast(x > l); + T temp2 = static_cast(x < -l); + return temp1 * (x - l) + temp2 * (x + l); + } +}; + +template +struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float lambda; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + // dx = dout, if x > lambda or x < -lambda else 0 + __device__ __forceinline__ T operator()(const T dout, const T x) const { + T l = static_cast(lambda); + return (x >= -l && x <= l) ? zero : dout; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaTanhShrinkFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // tanhshrink(x) = x - tanh(x) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(x - tanh(x)); + } +}; + +template +struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // dx = dout * tanh(x)^2 + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + return static_cast(dout * tanh(x) * tanh(x)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaHardShrinkFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x + __device__ __forceinline__ T operator()(const T x) const { + T t = static_cast(threshold); + return (x > -t && x < t) ? zero : x; + } +}; + +template +struct CudaHardShrinkGradFunctor : public BaseActivationFunctor { + T zero = static_cast(0.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // dx = (x > -threshold && x < threshold) ? 0 : dout + __device__ __forceinline__ T operator()(const T dout, const T x) const { + T t = static_cast(threshold); + return (x > -t && x < t) ? 
zero : dout; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaELUFunctor : public BaseActivationFunctor { + using CT = typename phi::dtype::MPTypeTrait::Type; + CT zero = static_cast(0.0f); + CT one = static_cast(1.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // elu(x) = x, if x > 0 + // elu(x) = alpha * (e^x - 1), if x <= 0 + __device__ __forceinline__ T operator()(const T arg_x) const { + CT x = static_cast(arg_x); + CT temp = static_cast(alpha) * (exp(x) - one); + CT res = x > zero ? x : temp; + return static_cast(res); + } +}; + +template +struct CudaELUGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // case 1: alpha >= 0 + // dx = dout, if out > 0 + // dx = dout * (out + alpha), if out <= 0 + __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const { + MPType dout = static_cast(arg_dout); + MPType out = static_cast(arg_out); + MPType a = static_cast(alpha); + MPType out_pos = static_cast(out > zero); + MPType out_neg = static_cast(out <= zero); + return static_cast(dout * (out_pos + out_neg * (out + a))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType zero = static_cast(0.0f); + float alpha; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + // case 2: alpha < 0 + // dx = dout, if x > 0 + // dx = dout * (out + alpha), if x <=0 + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_out, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType out = static_cast(arg_out); + MPType x = static_cast(arg_x); + MPType a = static_cast(alpha); + MPType x_pos = static_cast(x > zero); + MPType x_neg = static_cast(x <= zero); + return static_cast(dout * (x_pos + x_neg * (out + a))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + +template +struct CudaSiluFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // silu(x) = x / (1 + exp(-x)) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + return static_cast(x / (one + exp(-x))); + } +}; + +template +struct CudaSiluGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + + // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType temp = one / (one + exp(-x)); + return static_cast(dout * (temp * (one + x * (one - temp)))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + #endif } // namespace funcs diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 00792b8ab6070..b12fc6975b37d 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -73,7 +73,7 
@@ void ActivationGradGPUImpl(const Context& dev_ctx, } } -#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ +#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ const DenseTensor& x, \ @@ -84,7 +84,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \ +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \ name, functor_class, attr) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -99,7 +99,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \ +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \ name, functor_class, attr1, attr2) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -116,7 +116,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, &x, nullptr, &dout, dx, functor); \ } -#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ +#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ template \ void name##GradKernel(const Context& dev_ctx, \ const DenseTensor& out, \ @@ -127,7 +127,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } -#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \ +#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \ name, functor_class, attr) \ template \ void name##GradKernel(const Context& dev_ctx, \ @@ -142,32 +142,62 @@ void ActivationGradGPUImpl(const Context& dev_ctx, dev_ctx, nullptr, &out, &dout, dx, functor); \ } -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, CudaReluGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, CudaTanhGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, CudaCosGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, CudaTanGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, CudaAcosGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, CudaSinGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, CudaAsinGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, CudaAtanGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, CudaSinhGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor); -DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor); - -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, CudaSinGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, CudaAsinGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, CudaAtanGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, CudaSinhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CudaCoshGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, CudaAsinhGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor); +DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, 
@@ -142,32 +142,62 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, nullptr, &out, &dout, dx, functor);      \
   }

-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, CudaReluGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, CudaTanhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, CudaCosGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, CudaTanGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, CudaAcosGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, CudaSinGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, CudaAsinGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, CudaAtanGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, CudaSinhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor);
-
-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu,
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, CudaSinGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, CudaAsinGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, CudaAtanGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, CudaSinhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CudaCoshGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, CudaAsinhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
+
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                                CudaLeakyReluGradFunctor,
                                                alpha);
-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu,
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
                                                CudaThresholdedReluGradFunctor,
                                                threshold);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
+                                               CudaSoftShrinkGradFunctor,
+                                               lambda);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
+                                               CudaHardShrinkGradFunctor,
+                                               threshold);

-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu,
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
                                                CudaBReluGradFunctor,
                                                t_min,
                                                t_max);

+template <typename T, typename Context>
+void EluGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& dout,
+                   float alpha,
+                   DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+  std::vector<const DenseTensor*> ins = {&dout, &out};
+  std::vector<DenseTensor*> outs = {dx};
+  if (alpha > 0) {
+    funcs::CudaELUGradFunctor<T> functor;
+    functor.alpha = alpha;
+    funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  } else {
+    funcs::CudaELUGradNegativeAlphaFunctor<T> functor;
+    functor.alpha = alpha;
+    ins.push_back(&x);
+    funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  }
+}
+
 }  // namespace phi

 #ifdef PADDLE_WITH_HIP
@@ -234,3 +264,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
                                    LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                    ThresholdedReluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
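EluGradKernel switches functors on the sign of alpha: CudaELUGradFunctor only reads out (kDepOut), which suffices because out > 0 exactly when x > 0 for alpha >= 0; for alpha < 0 that equivalence breaks, so CudaELUGradNegativeAlphaFunctor additionally reads x. A standalone sketch (plain C++, not part of the patch) makes the alpha < 0 case concrete:

// Why the negative-alpha path needs the forward input x: for alpha < 0,
// elu(x) is positive exactly where x is negative, so the sign of `out`
// no longer identifies the branch taken in the forward pass.
#include <cmath>
#include <cstdio>

double elu(double x, double alpha) { return x > 0 ? x : alpha * (std::exp(x) - 1.0); }

int main() {
  const double alpha = -1.0;
  for (double x : {-2.0, -0.5, 0.5, 2.0}) {
    std::printf("alpha=%.1f  x=%+.1f  elu(x)=%+.4f\n", alpha, x, elu(x, alpha));
  }
  // e.g. x=-2 gives elu(x) ~ +0.8647, so the gradient functor must branch on x.
  return 0;
}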
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 3c340a89f5746..cd9330ead8429 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -38,12 +38,13 @@ void ActivationGPUImpl(const Context& dev_ctx,
   funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
 }

-#define DEFINE_GPU_ACTIVATION_KERNEL(name, functor_class)               \
-  template <typename T, typename Context>                               \
-  void name##Kernel(                                                    \
-      const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
-    functor_class<T> functor;                                           \
-    ActivationGPUImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
+#define DEFINE_GPU_ACTIVATION_KERNEL(name, functor_class)               \
+  template <typename T, typename Context>                               \
+  void name##Kernel(                                                    \
+      const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
+    funcs::functor_class<T> functor;                                    \
+    ActivationGPUImpl<T, Context, funcs::functor_class<T>>(             \
+        dev_ctx, x, out, functor);                                      \
   }

 #define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
@@ -75,24 +76,31 @@ void ActivationGPUImpl(const Context& dev_ctx,
         dev_ctx, x, out, functor);                                      \
   }

-DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Tan, funcs::CudaTanFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Acos, funcs::CudaAcosFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Sin, funcs::CudaSinFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Asin, funcs::CudaAsinFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Atan, funcs::CudaAtanFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Sinh, funcs::CudaSinhFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Cosh, funcs::CudaCoshFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Asinh, funcs::CudaAsinhFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Tanh, funcs::CudaTanhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Cos, CudaCosFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Tan, CudaTanFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Acos, CudaAcosFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Sin, CudaSinFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Asin, CudaAsinFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Atan, CudaAtanFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Sinh, CudaSinhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Cosh, CudaCoshFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Asinh, CudaAsinhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Acosh, CudaAcoshFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Atanh, CudaAtanhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)

 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      CudaThresholdedReluFunctor,
                                      threshold)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
+                                     CudaHardShrinkFunctor,
+                                     threshold)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)

 DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)

@@ -142,3 +150,8 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index a48a6226f23f8..a95f49c0e7cfd 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -202,4 +202,24 @@ void TanhTripleGradKernel(const Context& dev_ctx,
                           d_ddx);  // output
 }

+template <typename T, typename Context>
+void EluDoubleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         float alpha,
+                         DenseTensor* dx,
+                         DenseTensor* ddout) {
+  if (dx) {
+    dx->Resize(x.dims());
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (ddout) {
+    dev_ctx.template Alloc<T>(ddout);
+  }
+  funcs::ELUGradGradFunctor<T> functor;
+  functor.alpha = alpha;
+  functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
+}
+
 }  // namespace phi
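EluDoubleGradKernel only allocates the outputs and defers the math to funcs::ELUGradGradFunctor. For reference, the second-order identities it relies on follow from ordinary calculus for an elementwise y = elu(x); the standalone check below (plain C++, not part of the patch) verifies elu''(x) = alpha*e^x for x <= 0 and 0 otherwise against a finite difference of elu'(x).

// ddout = ddx * elu'(x); the second-order contribution to dx is
// dout * ddx * elu''(x). This sketch checks elu'' numerically.
#include <cmath>
#include <cstdio>

double elu_grad(double x, double alpha) { return x > 0 ? 1.0 : alpha * std::exp(x); }
double elu_grad2(double x, double alpha) { return x > 0 ? 0.0 : alpha * std::exp(x); }

int main() {
  const double alpha = 1.2, h = 1e-6;
  for (double x : {-1.5, -0.2, 0.7}) {
    double fd = (elu_grad(x + h, alpha) - elu_grad(x - h, alpha)) / (2 * h);
    std::printf("x=%+.2f  elu'' analytic=%.6f  fd=%.6f\n", x, elu_grad2(x, alpha), fd);
  }
  return 0;
}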
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index cbfca5b17ae99..890dbadf17c81 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -16,45 +16,49 @@ limitations under the License. */

 namespace phi {

-#define DefineActGradDepXOpArgMap(func_name, op_name, attrs) \
-  KernelSignature func_name##GradOpArgumentMapping(          \
-      const ArgumentMappingContext& ctx) {                   \
-    return KernelSignature(op_name "_grad",                  \
-                           {"X", GradVarName("Out")},        \
-                           {attrs},                          \
-                           {GradVarName("X")});              \
+#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
+  KernelSignature func_name##GradOpArgumentMapping(               \
+      const ArgumentMappingContext& ctx) {                        \
+    return KernelSignature(op_name "_grad",                       \
+                           {"X", GradVarName("Out")},             \
+                           {attrs},                               \
+                           {GradVarName("X")});                   \
   }

-#define DefineActGradDepOutOpArgMap(func_name, op_name, attrs) \
-  KernelSignature func_name##GradOpArgumentMapping(            \
-      const ArgumentMappingContext& ctx) {                     \
-    return KernelSignature(op_name "_grad",                    \
-                           {"Out", GradVarName("Out")},        \
-                           {attrs},                            \
-                           {GradVarName("X")});                \
+#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
+  KernelSignature func_name##GradOpArgumentMapping(                 \
+      const ArgumentMappingContext& ctx) {                          \
+    return KernelSignature(op_name "_grad",                         \
+                           {"Out", GradVarName("Out")},             \
+                           {attrs},                                 \
+                           {GradVarName("X")});                     \
   }

 #define comma ,

-DefineActGradDepXOpArgMap(Cos, "cos", );      // NOLINT
-DefineActGradDepXOpArgMap(Tan, "tan", );      // NOLINT
-DefineActGradDepXOpArgMap(Acos, "acos", );    // NOLINT
-DefineActGradDepXOpArgMap(Sin, "sin", );      // NOLINT
-DefineActGradDepXOpArgMap(Asin, "asin", );    // NOLINT
-DefineActGradDepXOpArgMap(Atan, "atan", );    // NOLINT
-DefineActGradDepXOpArgMap(Sinh, "sinh", );    // NOLINT
-DefineActGradDepXOpArgMap(Cosh, "cosh", );    // NOLINT
-DefineActGradDepXOpArgMap(Asinh, "asinh", );  // NOLINT
-DefineActGradDepXOpArgMap(Acosh, "acosh", );  // NOLINT
-DefineActGradDepXOpArgMap(Atanh, "atanh", );  // NOLINT
-DefineActGradDepXOpArgMap(BRelu, "brelu", "t_min" comma "t_max");  // NOLINT
-DefineActGradDepXOpArgMap(LeakyRelu, "leaky_relu", "alpha");       // NOLINT
-DefineActGradDepXOpArgMap(ThresholdedRelu,
-                          "thresholded_relu",
-                          "threshold");  // NOLINT
-
-DefineActGradDepOutOpArgMap(Relu, "relu", );  // NOLINT
-DefineActGradDepOutOpArgMap(Tanh, "tanh", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Tan, "tan", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acos, "acos", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sin, "sin", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asin, "asin", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atan, "atan", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sinh, "sinh", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cosh, "cosh", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asinh, "asinh", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acosh, "acosh", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atanh, "atanh", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(BRelu, "brelu", "t_min" comma "t_max");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LeakyRelu, "leaky_relu", "alpha");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(ThresholdedRelu,
+                               "thresholded_relu",
+                               "threshold");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(SoftShrink, "soft_shrink", "lambda");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", );               // NOLINT
+
+DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", );  // NOLINT
+DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", );  // NOLINT

 KernelSignature ReluDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
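For the new ops, each argument-mapping macro invocation produces one KernelSignature function that maps the old fluid operator's inputs, attributes, and outputs onto the phi kernel. The snippet below is only the macro body with the arguments substituted by hand (it reads like, but is not lifted from, the generated code, and relies on the phi declarations already included in this file rather than compiling on its own), shown for DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(SoftShrink, "soft_shrink", "lambda"):

// Illustrative expansion only; KernelSignature, GradVarName and
// ArgumentMappingContext come from phi headers in this file.
KernelSignature SoftShrinkGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("soft_shrink_grad",
                         {"X", GradVarName("Out")},  // fluid op inputs
                         {"lambda"},                 // forwarded attribute
                         {GradVarName("X")});        // fluid op output
}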
return KernelSignature("leaky_relu", {"X"}, {"alpha"}, {"Out"}); } +KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("elu", {"X"}, {"alpha"}, {"Out"}); +} + +KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("elu_grad", + {"X", "Out", GradVarName("Out")}, + {"alpha"}, + {GradVarName("X")}); +} + +KernelSignature EluDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "elu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); PD_REGISTER_BASE_KERNEL_NAME(tanh_grad_grad, tanh_double_grad); PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad); +PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink); +PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad); +PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad); PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping); @@ -118,3 +142,13 @@ PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad, phi::LeakyReluDoubleGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad, phi::ThresholdedReluGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(softshrink_grad, + phi::SoftShrinkGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(hard_shrink_grad, + phi::HardShrinkGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(tanh_shrink_grad, + phi::TanhShrinkGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(silu_grad, phi::SiluGradOpArgumentMapping);