Description
When a Torch model contains strided slicing of an N-D array, the stride argument is ignored and the non-strided version is returned. Because the striding is ignored, both the shape and the contents of the result are wrong.
Pseudocode example:
x = Tensor[1,4]
y = x[:, 0::2] # => shape:(1, 4), should be (1, 2)
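For reference, the expected shape can be checked in plain NumPy, whose basic slicing semantics match PyTorch here (the array contents below are only for illustration):

```python
import numpy as np

# a (1, 4) array, matching the pseudocode above
x = np.arange(4).reshape(1, 4)

# taking every second element of the last axis halves that axis
y = x[:, 0::2]
print(y.shape)  # (1, 2)
print(y)        # [[0 2]]
```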
A small script triggering the bug:

import torch
from tvm import relay

class TriggerBug(torch.nn.Module):
    def __init__(self):
        super(TriggerBug, self).__init__()

    def forward(self, x):
        # slice the last dimension with a stride of 2
        return x[..., 0::2] + x[..., 1::2]

#x_in = torch.randn(4)    # this would work
x_in = torch.randn(1, 4)  # this doesn't work
torch_model = TriggerBug()
traced_model = torch.jit.trace(torch_model, (x_in,))
mod, params = relay.frontend.from_pytorch(traced_model, [('x_in', x_in.shape)])
The output is:
mod, params = relay.frontend.from_pytorch(traced_model, [('x_in', x_in.shape)])
File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/relay/frontend/pytorch.py", line 2788, in from_pytorch
mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/ir/module.py", line 74, in __setitem__
return self._add(var, val)
File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/ir/module.py", line 83, in _add
_ffi_api.Module_Add(self, var, val, update)
File "/Users/name/opt/anaconda3/envs/env/lib/python3.8/site-packages/tvm-0.7.dev1-py3.8-macosx-10.9-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libffi.7.dylib 0x000000010ebd7ead ffi_call_unix64 + 85
[bt] (7) 8 libtvm.dylib 0x000000012801c3c8 TVMFuncCall + 72
[bt] (6) 7 libtvm.dylib 0x000000012754260c std::__1::__function::__func&lt;tvm::$_3, std::__1::allocator&lt;tvm::$_3&gt;, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)&gt;::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 492
[bt] (5) 6 libtvm.dylib 0x0000000127536437 tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool) + 183
[bt] (4) 5 libtvm.dylib 0x0000000127535edf tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function) + 1103
[bt] (3) 4 libtvm.dylib 0x0000000127e773e0 tvm::relay::InferType(tvm::relay::Function const&, tvm::IRModule const&, tvm::GlobalVar const&) + 544
[bt] (2) 3 libtvm.dylib 0x0000000127e76587 tvm::relay::TypeInferencer::Infer(tvm::RelayExpr) + 119
[bt] (1) 2 libtvm.dylib 0x000000012752714c tvm::ErrorReporter::RenderErrors(tvm::IRModule const&, bool) + 5308
[bt] (0) 1 libtvm.dylib 0x000000012735447f dmlc::LogMessageFatal::~LogMessageFatal() + 111
File "/Users/puu/code/python/tvm_fix/tvm/src/ir/error.cc", line 132
TVMError:
Error(s) have occurred. The program has been annotated with them:
In `main`:
#[version = "0.0.5"]
fn (%x_in: Tensor[(1, 4), float32]) {
%0 = strided_slice(%x_in, meta[relay.Constant][0], meta[relay.Constant][1], meta[relay.Constant][2], begin=[0, 0], end=[1, 4], strides=[2]);
%1 = strided_slice(%x_in, meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5], begin=[0, 1], end=[1, 4], strides=[2]);
add(%0, %1) Incompatible broadcast type TensorType([1, 4], float32) and TensorType([1, 3], float32);
}
This suggests that the slicing ignores the stride argument.
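The reported shapes line up exactly with stride-1 slicing. A NumPy sketch (hypothetical, just to reproduce the arithmetic in the error message):

```python
import numpy as np

x = np.zeros((1, 4))

# intended: begin=[0, 0], end=[1, 4], stride 2 on the last axis
good = x[0:1, 0:4:2]
print(good.shape)  # (1, 2)

# with the stride dropped, the two slices in the annotated program become:
nostride0 = x[0:1, 0:4]  # begin=[0, 0], end=[1, 4], stride 1
nostride1 = x[0:1, 1:4]  # begin=[0, 1], end=[1, 4], stride 1
print(nostride0.shape)  # (1, 4) -- matches TensorType([1, 4], float32)
print(nostride1.shape)  # (1, 3) -- matches TensorType([1, 3], float32)
```

Adding a (1, 4) result to a (1, 3) result then fails broadcasting, which is exactly the reported type error.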
If the tensor to be sliced is 1-D, the result is correct, but any tensor with two or more dimensions fails.
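The 1-D case for comparison (again sketched in NumPy, standing in for the same slice semantics):

```python
import numpy as np

x = np.arange(4)  # 1-D tensor of length 4

# with a single axis, the lone stride entry applies to the right axis
print(x[0::2].shape)              # (2,)
print(x[1::2].shape)              # (2,)
print((x[0::2] + x[1::2]).shape)  # (2,) -- the add broadcasts fine
```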
#6316 seems related, but unfortunately it does not fix this issue.