Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expm1 op and test #33066

Merged
merged 1 commit into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as

)DOC";

UNUSED constexpr char Expm1Doc[] = R"DOC(
Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base.

$$out = e^x - 1$$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

Expand Down Expand Up @@ -706,6 +713,7 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
Expand Down Expand Up @@ -1346,6 +1354,34 @@ REGISTER_OP_CPU_KERNEL(
ops::ExpGradFunctor<int64_t>>);
/* ========================================================================== */

/* ========================== expm1 register ============================ */
REGISTER_OPERATOR(
expm1, ops::ActivationOp, ops::Expm1OpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::Expm1GradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(expm1,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::Expm1Functor<plat::float16>>);
REGISTER_OP_CPU_KERNEL(
expm1_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_OPERATOR(
log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType,
Expand Down
42 changes: 42 additions & 0 deletions paddle/fluid/operators/activation_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,30 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;

// expm1(x) = expm1(x)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(expm1(x));
}
};

template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1] + args[0];
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
Expand Down Expand Up @@ -1582,6 +1606,24 @@ REGISTER_OP_CUDA_KERNEL(
ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== expm1 register ============================ */

REGISTER_OP_CUDA_KERNEL(
expm1, ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<float>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
expm1_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);

Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/operators/activation_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,26 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.expm1();
}
};

template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out + dout;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/platform/eigen_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ HOSTDEVICE inline paddle::platform::bfloat16 exp(
return paddle::platform::bfloat16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline paddle::platform::bfloat16 expm1(
const paddle::platform::bfloat16& a) {
return paddle::platform::bfloat16(::expm1f(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline paddle::platform::bfloat16 erf(
const paddle::platform::bfloat16& a) {
Expand Down Expand Up @@ -555,6 +561,11 @@ HOSTDEVICE inline float16 exp(const float16& a) {
return float16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 expm1(const float16& a) {
return float16(::expm1f(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 erf(const float16& a) {
return float16(::erff(static_cast<float>(a)));
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
from .tensor.math import cosh # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import floor # noqa: F401
from .tensor.math import increment # noqa: F401
from .tensor.math import log # noqa: F401
Expand Down Expand Up @@ -409,6 +410,7 @@
'acos',
'logical_xor',
'exp',
'expm1',
'bernoulli',
'summary',
'sinh',
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/fluid/layers/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

__unary_func__ = [
'exp',
'expm1',
'atan',
'sqrt',
'rsqrt',
Expand Down Expand Up @@ -161,6 +162,19 @@

""")

add_sample_code(globals()["expm1"], r"""
Examples:
.. code-block:: python

import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.expm1(x)
print(out)
# [-0.32967997, -0.18126924, 0.10517092, 0.34985882]

""")

add_sample_code(globals()["tanh"], r"""
Examples:
.. code-block:: python
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,70 @@ def init_kernel_type(self):
pass


class TestExpm1(TestActivation):
def setUp(self):
self.op_type = "expm1"
self.init_dtype()

np.random.seed(2049)
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.expm1(x)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}

def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestExpm1API(unittest.TestCase):
def init_dtype(self):
self.dtype = 'float64'
self.shape = [11, 17]

def setUp(self):
self.init_dtype()
self.x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
self.out_ref = np.expm1(self.x)

self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_static_api(self):
paddle.enable_static()

def run(place):
with paddle.static.program_guard(paddle.static.Program()):
X = paddle.fluid.data('X', self.shape, dtype=self.dtype)
out = paddle.expm1(X)
exe = paddle.static.Executor(place)
res = exe.run(feed={'X': self.x})
for r in res:
self.assertEqual(np.allclose(self.out_ref, r), True)

for place in self.place:
run(place)

def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
X = paddle.to_tensor(self.x)
out = paddle.expm1(X)
self.assertEqual(np.allclose(self.out_ref, out.numpy()), True)
paddle.enable_static()

for place in self.place:
run(place)

def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
X = paddle.fluid.data('X', self.shape, dtype='int32')
self.assertRaises(TypeError, paddle.expm1, X)
# The input dtype must be float16, float32, float64.


class TestParameter(object):
def test_out_name(self):
with fluid.program_guard(fluid.Program()):
Expand Down Expand Up @@ -2701,6 +2765,7 @@ def test_check_grad(self):


create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(TestSigmoid)
create_test_act_fp16_class(TestSilu)
create_test_act_fp16_class(TestLogSigmoid)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
from .math import cumsum # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
from .math import expm1 # noqa: F401
from .math import floor # noqa: F401
from .math import floor_ # noqa: F401
from .math import increment # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..fluid.layers import cosh # noqa: F401
from ..fluid.layers import exp # noqa: F401
from ..fluid.layers import exp_ # noqa: F401
from ..fluid.layers import expm1 # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensor/init.py 文件也需要修改

from ..fluid.layers import floor # noqa: F401
from ..fluid.layers import floor_ # noqa: F401
from ..fluid.layers import log # noqa: F401
Expand Down