[phi]move softsign from fluid to phi (#44616)
* test_activation_op unittest error; yaml & activation.py in_dygraph_mode incomplete

* fix test_activation_op unittest error, add yaml and dygraph test

* fix code style with pre-commit

* try to fix namespace error of abs in activation_functor.h

* fix namespace error of abs
jiahy0825 committed Jul 28, 2022
1 parent 798a4ea commit 20759c3
Showing 15 changed files with 97 additions and 53 deletions.
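The commit wires softsign through phi's final-state (eager) path end to end: yaml entries in legacy_api.yaml and legacy_backward.yaml, phi CPU and GPU kernels, and an in_dygraph_mode branch in python/paddle/nn/functional/activation.py. As a rough illustration only (not part of the commit), the migrated op can be exercised from Python and checked against the same closed-form forward and backward formulas used by the functors below; a minimal sketch, assuming a Paddle build that already contains this change:

import numpy as np
import paddle
import paddle.nn.functional as F

# Random input with the same shape and range as the unit test below.
x_np = np.random.uniform(-1, 1, [10, 12]).astype("float32")
x = paddle.to_tensor(x_np, stop_gradient=False)

out = F.softsign(x)   # forward: x / (1 + |x|), dispatched to the phi kernel in dygraph mode
out.sum().backward()  # backward: runs softsign_grad

np.testing.assert_allclose(out.numpy(), x_np / (1 + np.abs(x_np)), rtol=1e-5)
np.testing.assert_allclose(x.grad.numpy(), 1 / (1 + np.abs(x_np)) ** 2, rtol=1e-5)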
4 changes: 4 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -1513,6 +1513,10 @@ REGISTER_ACTIVATION_OP(tanh_shrink,
TanhShrinkFunctor,
TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
REGISTER_ACTIVATION_OP(softsign,
Softsign,
SoftsignFunctor,
SoftsignGradFunctor);
REGISTER_ACTIVATION_OP(hard_sigmoid,
HardSigmoid,
HardSigmoidFunctor,
32 changes: 3 additions & 29 deletions paddle/fluid/operators/activation_op.h
@@ -290,6 +290,7 @@ USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Softsign)
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
@@ -493,35 +494,8 @@ inline void ExtractDoubleGradTensorWithInputDOut(
}
}

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

} // namespace operators
} // namespace paddle

#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);
24 changes: 1 addition & 23 deletions paddle/fluid/operators/activation_op.kps
@@ -66,29 +66,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
return x / (one + abs(x));
}
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T temp = one + abs(x);
return dout / (temp * temp);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -174,6 +151,7 @@ USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSoftsign)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2152,6 +2152,17 @@
use_gpudnn : true
backward : softmax_grad

# softsign
- api : softsign
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : softsign
backward : softsign_grad

- api : spectral_norm
args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps)
output : Tensor
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2037,6 +2037,17 @@
func : softmax_grad
use_gpudnn : true

- backward_api : softsign_grad
forward : softsign (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : softsign_grad
inplace : (out_grad -> x_grad)

- backward_api : spectral_norm_grad
forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps)
1 change: 1 addition & 0 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -213,6 +213,7 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Silu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Square);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Softsign);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log2);
1 change: 1 addition & 0 deletions paddle/phi/kernels/activation_kernel.h
@@ -61,6 +61,7 @@ DECLARE_ACTIVATION_KERNEL(Reciprocal)
DECLARE_ACTIVATION_KERNEL(Square)
DECLARE_ACTIVATION_KERNEL(Sqrt)
DECLARE_ACTIVATION_KERNEL(Rsqrt)
DECLARE_ACTIVATION_KERNEL(Softsign)
DECLARE_ACTIVATION_KERNEL(Sigmoid)
DECLARE_ACTIVATION_KERNEL(LogSigmoid)
DECLARE_ACTIVATION_KERNEL(Log)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -136,6 +136,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, ReciprocalGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, RsqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, SoftsignGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, LogSigmoidGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, LogGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, Log2GradFunctor);
@@ -335,6 +336,7 @@ PD_REGISTER_KERNEL(square_double_grad,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -80,6 +80,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, ReciprocalFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Square, SquareFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, RsqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, SoftsignFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
@@ -173,6 +174,7 @@ PD_REGISTER_KERNEL(expm1,
PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
51 changes: 51 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -1364,6 +1364,32 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
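The derivative quoted in the comment above follows directly from the quotient rule (for x != 0, where d|x|/dx = sgn(x) and x * sgn(x) = |x|):

\[
\frac{d}{dx}\,\mathrm{softsign}(x)
  = \frac{d}{dx}\,\frac{x}{1+|x|}
  = \frac{(1+|x|) - x\,\mathrm{sgn}(x)}{(1+|x|)^{2}}
  = \frac{1}{(1+|x|)^{2}},
\]

which is exactly what SoftsignGradFunctor computes: dx = dout * 1 / (1 + |x|)^2.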

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
@@ -3019,6 +3045,31 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
// Using abs directly will cause namespace conflict
return x / (one + (x > -x ? x : -x));
}
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
// Using abs directly will cause namespace conflict
T temp = one + (x > -x ? x : -x);
return dout / (temp * temp);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
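The (x > -x ? x : -x) pattern in the two CUDA functors is just |x| written without calling abs(), which, per the commit message, caused a namespace conflict in device code. Purely to illustrate the identity (hypothetical helper, not part of the commit):

def manual_abs(x):
    # Same identity as the ternary in CudaSoftsignFunctor: equivalent to max(x, -x).
    return x if x > -x else -x

assert manual_abs(-3.5) == 3.5 and manual_abs(2.0) == 2.0 and manual_abs(0.0) == 0.0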

template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -195,6 +195,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, CudaReciprocalGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, CudaSqrtGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, CudaRsqrtGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, CudaSoftsignGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, CudaLogSigmoidGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, CudaLogGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, CudaLog2GradFunctor);
@@ -415,6 +416,7 @@ PD_REGISTER_KERNEL(square_double_grad,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -97,6 +97,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Reciprocal, CudaReciprocalFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Square, CudaSquareFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sqrt, CudaSqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, CudaRsqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Softsign, CudaSoftsignFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
@@ -241,6 +242,7 @@ PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
2 changes: 2 additions & 0 deletions paddle/phi/ops/compat/activation_sig.cc
@@ -62,6 +62,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softsign, "softsign", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LogSigmoid, "logsigmoid", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT
@@ -294,6 +295,7 @@ PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(silu_grad, phi::SiluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softsign_grad, phi::SoftsignGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad, phi::SigmoidGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad_grad,
phi::SigmoidDoubleGradOpArgumentMapping);
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -2795,6 +2795,7 @@ class TestSoftsign(TestActivation):
def setUp(self):
self.op_type = "softsign"
self.init_dtype()
self.python_api = paddle.nn.functional.softsign

np.random.seed(1024)
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
@@ -2805,7 +2806,7 @@ def setUp(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)


class TestSoftsignAPI(unittest.TestCase):
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -1275,6 +1275,8 @@ def softsign(x, name=None):
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
if in_dygraph_mode():
return _C_ops.final_state_softsign(x)
if in_dynamic_mode():
return _C_ops.softsign(x)

