-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay][Op] Add compute, schedule, and tests for expand_dims and squeeze #2133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
52a87d4
c500be5
9104d41
94672f0
b719e50
3655882
aace0d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,51 @@ | ||
| #pylint: disable=invalid-name, unused-argument | ||
| """Backend compiler related feature registration""" | ||
| from __future__ import absolute_import | ||
| import topi | ||
| import topi.cuda | ||
| from tvm import container | ||
| from . import op as _reg | ||
| from .op import schedule_injective, OpPattern | ||
| from .op import (schedule_injective, register_compute, register_schedule, | ||
| register_pattern, OpPattern) | ||
|
|
||
| schedule_broadcast = schedule_injective | ||
|
|
||
| # squeeze | ||
| @register_compute("squeeze") | ||
| def squeeze_compiler(attrs, inputs, output_type, target): | ||
| """Compiler for squeeze dims.""" | ||
| assert len(inputs) == 1 | ||
|
|
||
| if attrs.axis is None: | ||
| axis = None | ||
| elif isinstance(attrs.axis, container.Array): | ||
| axis = tuple(attrs.axis) | ||
| else: | ||
| axis = int(attrs.axis) | ||
|
|
||
| return [topi.squeeze(inputs[0], axis)] | ||
|
|
||
| register_pattern("squeeze", OpPattern.INJECTIVE) | ||
| register_schedule("squeeze", schedule_injective) | ||
|
|
||
| # expand_dims | ||
| @register_compute("expand_dims") | ||
| def expand_dims_compiler(attrs, inputs, output_type, target): | ||
| """Compiler for expand_dims.""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make sure expand_dims in topi support negative index(I think they may already) |
||
| assert len(inputs) == 1 | ||
|
|
||
| new_axis = int(attrs.num_newaxis) | ||
| assert new_axis >= 0 | ||
|
|
||
| # axis should be in range [-data.ndim - 1, data.ndim] | ||
| axis = int(attrs.axis) | ||
| assert axis >= -len(inputs[0].shape) - 1 | ||
| assert axis <= len(inputs[0].shape) | ||
|
|
||
| return [topi.expand_dims(inputs[0], axis, new_axis)] | ||
|
|
||
| _reg.register_schedule("expand_dims", schedule_broadcast) | ||
| _reg.register_pattern("expand_dims", OpPattern.BROADCAST) | ||
|
|
||
| # strided_slice | ||
| _reg.register_schedule("strided_slice", schedule_injective) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,6 +90,22 @@ def check_binary_op(opfunc, ref): | |
| check_binary_op(opfunc, ref) | ||
|
|
||
|
|
||
| def test_expand_dims(): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Invoke this function from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will do |
||
| # based on topi test | ||
| def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): | ||
| x = relay.Var("x", relay.TensorType(dshape, dtype)) | ||
| func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis)) | ||
| for target, ctx in ctx_list(): | ||
| data = np.random.uniform(size=dshape).astype(dtype) | ||
| ref_res = data.reshape(oshape) | ||
| intrp = relay.create_executor("graph", ctx=ctx, target=target) | ||
| op_res = intrp.evaluate(func)(data) | ||
| np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) | ||
|
|
||
| verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2) | ||
| verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1) | ||
|
|
||
|
|
||
| def test_bias_add(): | ||
| xshape=(10, 2, 3, 4) | ||
| bshape=(2,) | ||
|
|
@@ -295,6 +311,7 @@ def test_dense(): | |
| test_binary_op() | ||
| test_expand_dims_infer_type() | ||
| test_concatenate() | ||
| test_expand_dims() | ||
| test_softmax() | ||
| test_log_softmax() | ||
| test_dropout() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
squeeze already support negative index, and None I think so you just have to redirect to topi call
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should rename this to "compute" to match the style in the rest of the system.