From 65f872641755daa4ef4f9aade985e0b5a3eef78b Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Sun, 24 Sep 2017 12:01:09 -0700
Subject: [PATCH] log_softmax added to topi (#483)

---
 topi/python/topi/nn/softmax.py             | 27 ++++++++++++++++++-
 topi/python/topi/testing/__init__.py       |  2 +-
 topi/python/topi/testing/softmax_python.py | 22 ++++++++++++++-
 topi/tests/python/test_topi_softmax.py     | 31 ++++++++++++++++++++++
 4 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index e3b19cff9248..9060a31f532b 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -1,5 +1,5 @@
 # pylint: disable=invalid-name
-"""TVM operator softmax compute."""
+"""TVM operator for softmax and log_softmax compute."""
 from __future__ import absolute_import
 import tvm
 
@@ -26,3 +26,28 @@ def softmax(x):
         (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
     return tvm.compute(
         x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
+
+@tvm.tag_scope(tag='log_softmax_output')
+def log_softmax(x):
+    """Perform log softmax activation on the data
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        2-D input data
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D output with same shape
+    """
+
+    assert len(x.shape) == 2, "only support 2-dim log softmax"
+    m, n = x.shape
+    k = tvm.reduce_axis((0, n), name='k')
+    max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
+    k = tvm.reduce_axis((0, n), name='k')
+    expsum = tvm.compute(
+        (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
+    return tvm.compute(
+        x.shape, lambda i, j: x[i, j] - max_elem[i] - tvm.log(expsum[i]))
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 1a715eb4fdc8..2a1866d2f1ef 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -8,4 +8,4 @@
 from .conv2d_nchw_python import conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
-from .softmax_python import softmax_python
+from .softmax_python import softmax_python, log_softmax_python
diff --git a/topi/python/topi/testing/softmax_python.py b/topi/python/topi/testing/softmax_python.py
index 5dc561a56e2a..0cdbee1bdc07 100644
--- a/topi/python/topi/testing/softmax_python.py
+++ b/topi/python/topi/testing/softmax_python.py
@@ -1,5 +1,5 @@
 # pylint: disable=invalid-name, trailing-whitespace
-"""Softmax operation in python"""
+"""Softmax and log_softmax operation in python"""
 import numpy as np
 
 def softmax_python(a_np):
@@ -21,3 +21,23 @@ def softmax_python(a_np):
     expsum = np.sum(e, axis=1)
     out_np = e / expsum[:, None]
     return out_np
+
+def log_softmax_python(a_np):
+    """Log_softmax operator.
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        2-D input data
+
+    Returns
+    -------
+    output_np : numpy.ndarray
+        2-D output with same shape
+    """
+    assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
+    max_elem = np.amax(a_np, axis=1)
+    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    e = np.exp(a_np - max_elem)
+    expsum = np.sum(e, axis=1)
+    out_np = a_np - max_elem - np.log(expsum[:, None])
+    return out_np
diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py
index b5eb9363dc66..cef5762295e5 100644
--- a/topi/tests/python/test_topi_softmax.py
+++ b/topi/tests/python/test_topi_softmax.py
@@ -36,5 +36,36 @@ def test_softmax():
     verify_softmax(3, 4)
 
 
+def verify_log_softmax(m, n):
+    A = tvm.placeholder((m, n), name='A')
+    B = topi.nn.log_softmax(A)
+    # confirm lower works
+    s = tvm.create_schedule([B.op])
+    tvm.lower(s, [A, B], simple_mode=True)
+
+    s = topi.cuda.schedule_softmax(B)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.log_softmax_python(a_np)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        foo = tvm.build(s, [A, B], device, name="log_softmax")
+        foo(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_log_softmax():
+    verify_log_softmax(32, 10)
+    verify_log_softmax(3, 4)
+
 if __name__ == "__main__":
     test_softmax()
+    test_log_softmax()
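
Note (not part of the patch): both the TOPI compute and the NumPy reference above use the numerically stable identity log_softmax(x)[i, j] = x[i, j] - max_j x[i, j] - log(sum_j exp(x[i, j] - max_j x[i, j])). A minimal standalone NumPy sketch of that identity follows; the name log_softmax_ref is illustrative only and is not defined anywhere in the patch.

import numpy as np

def log_softmax_ref(a):
    """Numerically stable log-softmax over the last axis of a 2-D array."""
    # Subtracting the row-wise max keeps exp() from overflowing for large inputs.
    max_elem = np.amax(a, axis=1, keepdims=True)
    logsumexp = np.log(np.sum(np.exp(a - max_elem), axis=1, keepdims=True))
    return a - max_elem - logsumexp

x = np.random.uniform(size=(3, 4)).astype("float32")
# exp(log_softmax(x)) should match a plain softmax row-normalization.
softmax = np.exp(x - np.amax(x, axis=1, keepdims=True))
softmax /= softmax.sum(axis=1, keepdims=True)
np.testing.assert_allclose(np.exp(log_softmax_ref(x)), softmax, rtol=1e-5)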