Closed
Labels: needs-triage, type: bug
Description
I have this Relay program:
```python
import numpy as np
import tvm
from tvm import relay, transform, IRModule

# Dimension 1 is dynamic (relay.Any()); the last dimension has size 1.
x_shape = (1, relay.Any(), 100, 17, 1)
x = relay.var('x', shape=x_shape)
# Squeeze away the trailing size-1 dimension using a negative axis.
y = relay.squeeze(x, axis=[-1])
y = relay.sqrt(y)
mod = IRModule.from_expr(y)

with transform.PassContext(opt_level=1):
    vm_exec = relay.vm.compile(mod, target='llvm')
vm = tvm.runtime.vm.VirtualMachine(vm_exec, tvm.cpu())

x_input = np.zeros(shape=(1, 100, 100, 17, 1), dtype='float32')
data = vm.run(x=x_input)
```
Expected behavior
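The program runs correctly and returns the element-wise square root of the squeezed input, with shape (1, 100, 100, 17).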
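For reference, the expected result can be computed with NumPy. This is a sketch that assumes NumPy's `squeeze` and `sqrt` semantics match Relay's for this input; it does not use TVM at all.

```python
import numpy as np

# Same input as the Relay program: shape (1, 100, 100, 17, 1).
x_input = np.zeros(shape=(1, 100, 100, 17, 1), dtype="float32")

# squeeze(axis=-1) drops the trailing size-1 dimension; sqrt is element-wise.
expected = np.sqrt(np.squeeze(x_input, axis=-1))

print(expected.shape)  # (1, 100, 100, 17)
print(expected.ndim)   # 4
```

The expected rank is 4, which matches what the compiled `T_sqrt` kernel asserts in the error below.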
Actual behavior
It raises this error:

```
Check failed: ret == 0 (-1 vs. 0) : Assert fail: 4 == T.tvm_struct_get(arg_T_sqrt, 0, 4, "int32"), arg.T_sqrt.ndim is expected to equal 4
```
Environment
None
Steps to reproduce
Run the program above.

The problem appears to be in the shape function for squeeze:
tvm/python/tvm/relay/op/_transform.py
Lines 914 to 927 in f21a17b
```python
@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
    """
    Shape function for squeeze op.
    """
    axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
    keep_axes = []
    remove_axes = []
    if axis is not None:
        for i in range(inputs[0].shape[0].value):
            if i not in axis:
                keep_axes.append(i)
            else:
                remove_axes.append(i)
```
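Here, if `axis` contains a negative index such as -1, the result is wrong: `i` only ranges over the non-negative indices `0 .. ndim-1`, so `i not in axis` is always true, every dimension ends up in `keep_axes`, and nothing is removed.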
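The behavior can be reproduced in plain Python by mimicking the loop from the excerpt. `split_axes` is a hypothetical helper, not TVM code, and its `normalize` option shows only one possible fix (mapping each negative axis `a` to `a + ndim` before the membership test), not necessarily the fix adopted upstream.

```python
def split_axes(ndim, axis, normalize=False):
    """Mimic the keep/remove loop of squeeze_shape_func (hypothetical helper)."""
    if normalize:
        # Map negative indices to their non-negative equivalents, e.g. -1 -> ndim - 1.
        axis = tuple(a + ndim if a < 0 else a for a in axis)
    keep_axes = []
    remove_axes = []
    for i in range(ndim):
        if i not in axis:
            keep_axes.append(i)
        else:
            remove_axes.append(i)
    return keep_axes, remove_axes


# Input rank 5 and squeeze axis [-1], as in the reproducer.
print(split_axes(5, (-1,)))                  # ([0, 1, 2, 3, 4], [])
print(split_axes(5, (-1,), normalize=True))  # ([0, 1, 2, 3], [4])
```

Without normalization all five dimensions are kept, so the shape function would describe a rank-5 output although squeeze should produce rank 4, which is consistent with the `ndim is expected to equal 4` assertion.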
cc @shingjan