diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 055909ba6f486..4b539d5e122a9 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -182,6 +182,13 @@ Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as )DOC"; +UNUSED constexpr char Expm1Doc[] = R"DOC( +Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base. + +$$out = e^x - 1$$ + +)DOC"; + UNUSED constexpr char ReluDoc[] = R"DOC( Relu Activation Operator. @@ -706,6 +713,7 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc); REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc); REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc); +REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc); REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc); REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc); REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc); @@ -1346,6 +1354,34 @@ REGISTER_OP_CPU_KERNEL( ops::ExpGradFunctor>); /* ========================================================================== */ +/* ========================== expm1 register ============================ */ +REGISTER_OPERATOR( + expm1, ops::ActivationOp, ops::Expm1OpMaker, ops::ActivationOpInferVarType, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::framework::OpDesc>, + ops::ActivationGradOpMaker::FwdDeps(), + paddle::imperative::OpBase>, + std::conditional>(), + ops::ActFwdInplaceInferer, void>::type); +REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad, + ops::ActivationGradOpInplaceInferer); + +REGISTER_OP_CPU_KERNEL(expm1, + ops::ActivationKernel>, + ops::ActivationKernel>, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + expm1_grad, ops::ActivationGradKernel>, + ops::ActivationGradKernel>, + ops::ActivationGradKernel>); +/* ========================================================================== */ + /* ========================== Log register 
==================================*/ REGISTER_OPERATOR( log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType, diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 87e65e8817798..77e6d6294a1ac 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -564,6 +564,30 @@ struct CudaExpGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +template +struct CudaExpm1Functor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + + // expm1(x) = e^x - 1 + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T* args) const { + MPType x = static_cast(args[0]); + return static_cast(expm1(x)); + } +}; + +template +struct CudaExpm1GradFunctor : public BaseActivationFunctor { + // dx = dout * (out + 1) + // Inputs: args[0], the input dout + // args[1], the input out + __device__ __forceinline__ T operator()(const T* args) const { + return args[0] * args[1] + args[0]; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + template struct CudaLogFunctor : public BaseActivationFunctor { using MPType = typename details::MPTypeTrait::Type; @@ -1582,6 +1606,24 @@ REGISTER_OP_CUDA_KERNEL( ops::CudaExpGradFunctor>); /* ========================================================================== */ +/* ========================== expm1 register ============================ */ + +REGISTER_OP_CUDA_KERNEL( + expm1, ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>, + ops::ActivationCudaKernel>); +REGISTER_OP_CUDA_KERNEL( + expm1_grad, ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>, + ops::ActivationGradCudaKernel>); +/* ========================================================================== */ + /* ========================== Log register ==================================*/ REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, 
CudaLogFunctor, CudaLogGradFunctor); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index ccd5bf528ba58..b0731eea84d98 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -341,6 +341,26 @@ struct ExpGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; +// expm1(x) = e^x - 1 +template +struct Expm1Functor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.expm1(); + } +}; + +template +struct Expm1GradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out + dout;  // d/dx(e^x - 1) = e^x = out + 1 + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { diff --git a/paddle/fluid/platform/eigen_ext.h b/paddle/fluid/platform/eigen_ext.h index 4eea87e909d1b..eb2e6691af73e 100644 --- a/paddle/fluid/platform/eigen_ext.h +++ b/paddle/fluid/platform/eigen_ext.h @@ -204,6 +204,12 @@ HOSTDEVICE inline paddle::platform::bfloat16 exp( return paddle::platform::bfloat16(::expf(static_cast(a))); } +template <> +HOSTDEVICE inline paddle::platform::bfloat16 expm1( + const paddle::platform::bfloat16& a) { + return paddle::platform::bfloat16(::expm1f(static_cast(a))); +} + template <> HOSTDEVICE inline paddle::platform::bfloat16 erf( const paddle::platform::bfloat16& a) { @@ -555,6 +561,11 @@ HOSTDEVICE inline float16 exp(const float16& a) { return float16(::expf(static_cast(a))); } +template <> +HOSTDEVICE inline float16 expm1(const float16& a) { + return float16(::expm1f(static_cast(a))); +} + template <> HOSTDEVICE inline float16 erf(const float16& a) { return float16(::erff(static_cast(a))); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ee4dcaa897940..dd303472dd98c 100755 --- 
a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -155,6 +155,7 @@ from .tensor.math import cosh # noqa: F401 from .tensor.math import cumsum # noqa: F401 from .tensor.math import exp # noqa: F401 +from .tensor.math import expm1 # noqa: F401 from .tensor.math import floor # noqa: F401 from .tensor.math import increment # noqa: F401 from .tensor.math import log # noqa: F401 @@ -409,6 +410,7 @@ 'acos', 'logical_xor', 'exp', + 'expm1', 'bernoulli', 'summary', 'sinh', diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 813f671e02070..dcce510f1e8b4 100755 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -37,6 +37,7 @@ __unary_func__ = [ 'exp', + 'expm1', 'atan', 'sqrt', 'rsqrt', @@ -161,6 +162,19 @@ """) +add_sample_code(globals()["expm1"], r""" +Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3]) + out = paddle.expm1(x) + print(out) + # [-0.32967997, -0.18126924, 0.10517092, 0.34985882] + +""") + add_sample_code(globals()["tanh"], r""" Examples: .. 
code-block:: python diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index ef5ac46cede42..98d2493257d61 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -73,6 +73,70 @@ def init_kernel_type(self): pass +class TestExpm1(TestActivation): + def setUp(self): + self.op_type = "expm1" + self.init_dtype() + + np.random.seed(2049) + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.expm1(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpm1API(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float64' + self.shape = [11, 17] + + def setUp(self): + self.init_dtype() + self.x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + self.out_ref = np.expm1(self.x) + + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + X = paddle.fluid.data('X', self.shape, dtype=self.dtype) + out = paddle.expm1(X) + exe = paddle.static.Executor(place) + res = exe.run(feed={'X': self.x}) + for r in res: + self.assertEqual(np.allclose(self.out_ref, r), True) + + for place in self.place: + run(place) + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + X = paddle.to_tensor(self.x) + out = paddle.expm1(X) + self.assertEqual(np.allclose(self.out_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + X = paddle.fluid.data('X', self.shape, dtype='int32') + self.assertRaises(TypeError, paddle.expm1, X) + # 
The input dtype must be float16, float32, float64. + + class TestParameter(object): def test_out_name(self): with fluid.program_guard(fluid.Program()): @@ -2701,6 +2765,7 @@ def test_check_grad(self): create_test_act_fp16_class(TestActivation) +create_test_act_fp16_class(TestExpm1) create_test_act_fp16_class(TestSigmoid) create_test_act_fp16_class(TestSilu) create_test_act_fp16_class(TestLogSigmoid) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c8d80fc9bc68c..34c4fb60c2075 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -103,6 +103,7 @@ from .math import cumsum # noqa: F401 from .math import exp # noqa: F401 from .math import exp_ # noqa: F401 +from .math import expm1 # noqa: F401 from .math import floor # noqa: F401 from .math import floor_ # noqa: F401 from .math import increment # noqa: F401 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 23addcb7e3f4e..562dd85f4bbc5 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -45,6 +45,7 @@ from ..fluid.layers import cosh # noqa: F401 from ..fluid.layers import exp # noqa: F401 from ..fluid.layers import exp_ # noqa: F401 +from ..fluid.layers import expm1 # noqa: F401 from ..fluid.layers import floor # noqa: F401 from ..fluid.layers import floor_ # noqa: F401 from ..fluid.layers import log # noqa: F401