Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion topi/python/topi/nn/softmax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=invalid-name
"""TVM operator softmax compute."""
"""TVM operator for softmax and log_softmax compute."""
from __future__ import absolute_import
import tvm

Expand All @@ -26,3 +26,28 @@ def softmax(x):
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(
x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])

@tvm.tag_scope(tag='log_softmax_output')
def log_softmax(x):
"""Perform log softmax activation on the data

Parameters
----------
data : tvm.Tensor
2-D input data

Returns
-------
output : tvm.Tensor
2-D output with same shape
"""

assert len(x.shape) == 2, "only support 2-dim log softmax"
m, n = x.shape
k = tvm.reduce_axis((0, n), name='k')
max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
k = tvm.reduce_axis((0, n), name='k')
expsum = tvm.compute(
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(
x.shape, lambda i, j: x[i, j] - max_elem[i] - tvm.log(expsum[i]))
2 changes: 1 addition & 1 deletion topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python
from .softmax_python import softmax_python, log_softmax_python
22 changes: 21 additions & 1 deletion topi/python/topi/testing/softmax_python.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=invalid-name, trailing-whitespace
"""Softmax operation in python"""
"""Softmax and log_softmax operation in python"""
import numpy as np

def softmax_python(a_np):
Expand All @@ -21,3 +21,23 @@ def softmax_python(a_np):
expsum = np.sum(e, axis=1)
out_np = e / expsum[:, None]
return out_np

def log_softmax_python(a_np):
"""Log_softmax operator.
Parameters
----------
a_np : numpy.ndarray
2-D input data

Returns
-------
output_np : numpy.ndarray
2-D output with same shape
"""
assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
max_elem = np.amax(a_np, axis=1)
max_elem = max_elem.reshape(max_elem.shape[0], 1)
e = np.exp(a_np-max_elem)
expsum = np.sum(e, axis=1)
out_np = a_np - max_elem - np.log(expsum[:, None])
return out_np
31 changes: 31 additions & 0 deletions topi/tests/python/test_topi_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,36 @@ def test_softmax():
verify_softmax(3, 4)


def verify_log_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.log_softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)

s = topi.cuda.schedule_softmax(B)

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np)

def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="log_softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal']:
check_device(device)

def test_log_softmax():
verify_log_softmax(32, 10)
verify_log_softmax(3, 4)

if __name__ == "__main__":
test_softmax()
test_log_softmax()