45 changes: 44 additions & 1 deletion python/tvm/relay/op/_transform.py
@@ -1,8 +1,51 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
import topi.cuda
from tvm import container
from . import op as _reg
from .op import schedule_injective, OpPattern
from .op import (schedule_injective, register_compute, register_schedule,
                 register_pattern, OpPattern)

schedule_broadcast = schedule_injective

# squeeze
@register_compute("squeeze")
def squeeze_compiler(attrs, inputs, output_type, target):
"""Compiler for squeeze dims."""
Member: squeeze already supports negative indices and None, I think, so you just have to redirect to the topi call.

Member: I think we should rename this to "compute" to match the style in the rest of the system.

    assert len(inputs) == 1

    if attrs.axis is None:
        axis = None
    elif isinstance(attrs.axis, container.Array):
        axis = tuple(attrs.axis)
    else:
        axis = int(attrs.axis)

    return [topi.squeeze(inputs[0], axis)]

register_pattern("squeeze", OpPattern.INJECTIVE)
register_schedule("squeeze", schedule_injective)

# expand_dims
@register_compute("expand_dims")
def expand_dims_compiler(attrs, inputs, output_type, target):
"""Compiler for expand_dims."""
Member: make sure expand_dims in topi supports negative indices (I think it may already).

    assert len(inputs) == 1

    new_axis = int(attrs.num_newaxis)
    assert new_axis >= 0

    # axis should be in range [-data.ndim - 1, data.ndim]
    axis = int(attrs.axis)
    assert axis >= -len(inputs[0].shape) - 1
    assert axis <= len(inputs[0].shape)

    return [topi.expand_dims(inputs[0], axis, new_axis)]

_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_pattern("expand_dims", OpPattern.BROADCAST)
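
To make the asserted range concrete: the new axis indexes into an output of rank len(shape) + 1, so for a rank-n input the valid values run from -n - 1 up to n. A hypothetical helper (not part of this diff or of topi) showing how a negative axis would normalize:

def normalize_expand_axis(axis, ndim):
    # The expanded output has rank ndim + 1, so negative axes wrap
    # around ndim + 1 positions.
    return axis + ndim + 1 if axis < 0 else axis

assert normalize_expand_axis(-3, 2) == 0  # new leading dim on a rank-2 input
assert normalize_expand_axis(2, 2) == 2   # new trailing dim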

# strided_slice
_reg.register_schedule("strided_slice", schedule_injective)
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -90,6 +90,22 @@ def check_binary_op(opfunc, ref):
        check_binary_op(opfunc, ref)


def test_expand_dims():
Member: Invoke this function from __main__.

Contributor (author): Will do.

    # based on topi test
    def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis):
        x = relay.Var("x", relay.TensorType(dshape, dtype))
        func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
        for target, ctx in ctx_list():
            data = np.random.uniform(size=dshape).astype(dtype)
            ref_res = data.reshape(oshape)
            intrp = relay.create_executor("graph", ctx=ctx, target=target)
            op_res = intrp.evaluate(func)(data)
            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

    verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1)
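
A note on the reference result above: expand_dims only inserts size-1 axes, so reshaping the input to the expected output shape yields an identical array. A quick numpy-only check (illustrative, not part of the test file):

import numpy as np
data = np.random.uniform(size=(3, 10)).astype('float32')
expanded = np.expand_dims(np.expand_dims(data, 2), 3)  # shape (3, 10, 1, 1)
assert np.array_equal(expanded, data.reshape((3, 10, 1, 1)))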


def test_bias_add():
    xshape=(10, 2, 3, 4)
    bshape=(2,)
@@ -295,6 +311,7 @@ def test_dense():
    test_binary_op()
    test_expand_dims_infer_type()
    test_concatenate()
    test_expand_dims()
    test_softmax()
    test_log_softmax()
    test_dropout()
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -60,6 +60,22 @@ def test_clip():
    np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)


def test_squeeze():
    def verify_squeeze(shape, dtype, axis):
        x = relay.var("x", relay.TensorType(shape, dtype))
        squeeze = relay.squeeze(x, axis=axis)

        np_axis = tuple(axis) if axis is not None else None

        data = np.random.random_sample(shape).astype(dtype)
        intrp = create_executor()
        op_res = intrp.evaluate(squeeze, { x : relay.const(data) })
        ref_res = np.squeeze(data, axis=np_axis)
        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

    verify_squeeze((1, 3, 2, 5), "float32", None)
    verify_squeeze((1, 3, 1), "float32", [0])
    verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2])
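
For reference, axis=None squeezes every size-1 dimension, which is what the first verify_squeeze case exercises; a numpy-only illustration of the semantics the test compares against (not part of the test file):

import numpy as np
assert np.squeeze(np.zeros((1, 3, 2, 5))).shape == (3, 2, 5)
assert np.squeeze(np.zeros((1, 2, 1, 2, 1)), axis=(0, 2)).shape == (2, 2, 1)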


def test_transpose_infer_type():
@@ -295,6 +311,7 @@ def test_infer_type_prelu():
    test_full_like()
    test_infer_type_leaky_relu()
    test_infer_type_prelu()
    test_squeeze()
    test_squeeze_infer_type()
    test_squeeze_bad_axes_infer_type()
    test_split_infer_type()