Description
If a strided slice is used in a model, the ONNX frontend ignores the stride argument and the result is wrong.
I encountered the problem when trying to compile an ONNX model exported from PyTorch. A similar problem was present in the PyTorch frontend (#6414) and was fixed by #6418.
Possibly related: issue #6316.
Code to reproduce the problem:
import torch
import tvm
from tvm import relay
import onnx

class TriggerBug(torch.nn.Module):
    def __init__(self):
        super(TriggerBug, self).__init__()

    def forward(self, x):
        # Sum of the even-indexed and odd-indexed elements along the last axis
        return x[..., 0::2] + x[..., 1::2]

x_in = torch.randn(1, 4)
torch_model = TriggerBug()
onnx_name = 'strided_slice.onnx'
example_output = torch_model(x_in)

# convert to ONNX
torch.onnx.export(torch_model, (x_in,), onnx_name,
                  verbose=True,
                  example_outputs=example_output,
                  input_names=['x'],
                  output_names=['y'],
                  opset_version=10,
                  enable_onnx_checker=True)
onnx_model = onnx.load(onnx_name)

target = 'llvm'
shape_dict = {'x': x_in.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
with tvm.transform.PassContext(opt_level=1):
    intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target)
The traceback:
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
File "/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/relay/frontend/onnx.py", line 2456, in from_onnx
mod, params = g.from_onnx(graph, opset)
File "/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/relay/frontend/onnx.py", line 2302, in from_onnx
return IRModule.from_expr(func), self._params
File "/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/ir/module.py", line 236, in from_expr
return _ffi_api.Module_FromExpr(expr, funcs, defs)
File "/Users/name/opt/anaconda3/envs/tvm/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.9-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 225, in call
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x0000000122684df8 TVMFuncCall + 72
[bt] (7) 8 libtvm.dylib 0x0000000121b8e452 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::RelayExpr, tvm::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>, tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, void, void>)>::AssignTypedLambdatvm::$_9(tvm::$_9)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::RelayExpr, tvm::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>, tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, void, void>)>::AssignTypedLambdatvm::$_9(tvm::$_9)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 610
[bt] (6) 7 libtvm.dylib 0x0000000121b7f810 tvm::IRModule::FromExpr(tvm::RelayExpr const&, tvm::Map<tvm::GlobalVar, tvm::BaseFunc, void, void> const&, tvm::Map<tvm::GlobalTypeVar, tvm::TypeData, void, void> const&) + 1040
[bt] (5) 6 libtvm.dylib 0x0000000121b7ca47 tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool) + 183
[bt] (4) 5 libtvm.dylib 0x0000000121b7c4ef tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function) + 1103
[bt] (3) 4 libtvm.dylib 0x00000001224dca20 tvm::relay::InferType(tvm::relay::Function const&, tvm::IRModule const&, tvm::GlobalVar const&) + 544
[bt] (2) 3 libtvm.dylib 0x00000001224dbbc7 tvm::relay::TypeInferencer::Infer(tvm::RelayExpr) + 119
[bt] (1) 2 libtvm.dylib 0x0000000121b6d87c tvm::ErrorReporter::RenderErrors(tvm::IRModule const&, bool) + 5308
[bt] (0) 1 libtvm.dylib 0x00000001219917bf dmlc::LogMessageFatal::~LogMessageFatal() + 111
File "/Users/name/code/python/tvm/src/ir/error.cc", line 132
TVMError:
Error(s) have occurred. The program has been annotated with them:

In `main`:
#[version = "0.0.5"]
fn (%x: Tensor[(1, 4), float32]) {
  %0 = strided_slice(%x, begin=[0, 0], end=[2147483647, 9223372036854775807], strides=[1]);
  %1 = strided_slice(%x, begin=[0, 1], end=[2147483647, 9223372036854775807], strides=[1]);
  add(%0, %1) Incompatible broadcast type TensorType([1, 4], float32) and TensorType([1, 3], float32);
}
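To make the shape mismatch in the error message easier to follow, here is a small pure-Python sketch that computes the output shape of a strided slice from begin/end/strides. The helper name `sliced_shape` is hypothetical, and padding missing strides with 1 is an assumption chosen to match the shapes reported above; it is not taken from TVM's implementation.

```python
INT32_MAX = 2147483647
INT64_MAX = 9223372036854775807


def sliced_shape(shape, begin, end, strides):
    """Return the shape obtained by slicing a tensor of `shape` per axis."""
    # Assumption: strides missing for trailing axes default to 1.
    strides = list(strides) + [1] * (len(shape) - len(strides))
    out = []
    for dim, b, e, s in zip(shape, begin, end, strides):
        # slice.indices clamps begin/end to the axis length.
        out.append(len(range(*slice(b, e, s).indices(dim))))
    return tuple(out)


shape = (1, 4)

# Arguments as they appear in the Relay program above (step of 2 lost).
lhs_bad = sliced_shape(shape, [0, 0], [INT32_MAX, INT64_MAX], [1])
rhs_bad = sliced_shape(shape, [0, 1], [INT32_MAX, INT64_MAX], [1])

# Arguments with the step of 2 on the last axis preserved.
lhs_ok = sliced_shape(shape, [0, 0], [INT32_MAX, INT64_MAX], [1, 2])
rhs_ok = sliced_shape(shape, [0, 1], [INT32_MAX, INT64_MAX], [1, 2])

print(lhs_bad, rhs_bad)  # (1, 4) (1, 3)
print(lhs_ok, rhs_ok)    # (1, 2) (1, 2)
```

With the stride dropped, the two operands have shapes (1, 4) and (1, 3), which is exactly the pair that `add` rejects; with the stride kept, both are (1, 2).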
The intermediate ONNX graph is:
graph(%x : Float(1:4, 4:1, requires_grad=0, device=cpu)):
  %1 : Tensor = onnx::Constant[value={1}]()
  %2 : Tensor = onnx::Constant[value={0}]()
  %3 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %4 : Tensor = onnx::Constant[value={2}]()
  %5 : Float(1:4, 2:2, requires_grad=0, device=cpu) = onnx::Slice(%x, %2, %3, %1, %4)
  %6 : Tensor = onnx::Constant[value={1}]()
  %7 : Tensor = onnx::Constant[value={1}]()
  %8 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %9 : Tensor = onnx::Constant[value={2}]()
  %10 : Float(1:4, 2:2, requires_grad=0, device=cpu) = onnx::Slice(%x, %7, %8, %6, %9)
  %y : Float(1:2, 2:1, requires_grad=0, device=cpu) = onnx::Add(%5, %10)
  return (%y)
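Here the stride is correctly present: the last input of each onnx::Slice node (%4 and %9) is the step, with value 2, and the inferred output shapes are (1, 2). So the stride is lost in the ONNX frontend, not in the export.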
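For reference, the opset-10 Slice node takes its inputs in the order (data, starts, ends, axes, steps), and axes may list only a subset of the dimensions. Below is a rough sketch of how those sparse arguments could be expanded into the full-rank begin/end/strides that `strided_slice` expects. The function `expand_slice_args` is hypothetical and written only for illustration; it is not the code in TVM's ONNX frontend.

```python
INT64_MAX = 9223372036854775807


def expand_slice_args(rank, starts, ends, axes, steps):
    """Expand per-axis ONNX Slice inputs to full-rank begin/end/strides."""
    # Axes not mentioned in `axes` are taken in full with stride 1.
    begin = [0] * rank
    end = [INT64_MAX] * rank
    strides = [1] * rank
    for start, stop, axis, step in zip(starts, ends, axes, steps):
        axis = axis % rank  # normalise negative axes
        begin[axis] = start
        end[axis] = stop
        strides[axis] = step
    return begin, end, strides


# First Slice node in the graph above: starts=%2, ends=%3, axes=%1, steps=%4
first = expand_slice_args(2, [0], [INT64_MAX], [1], [2])
# Second Slice node: starts=%7, ends=%8, axes=%6, steps=%9
second = expand_slice_args(2, [1], [INT64_MAX], [1], [2])

print(first)   # ([0, 0], [9223372036854775807, 9223372036854775807], [1, 2])
print(second)  # ([0, 1], [9223372036854775807, 9223372036854775807], [1, 2])
```

Under this mapping both slices would carry strides=[1, 2], whereas the Relay program in the traceback shows strides=[1].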
Versions:
- pytorch: 1.7.0.dev20200908
- TVM: 0.7.dev1 git revision 84fa626
- onnx: 1.7.0
In case you are wondering why I am going via ONNX instead of using the PyTorch frontend directly: compiling my real model from PyTorch does not currently work, but I have verified that the converted ONNX version of the model works. I was hoping that the ONNX frontend could then compile the full model.