-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Description
Expected behavior
TVM should successfully compile a model whose operators are supported.
Actual behavior
Compilation can fail when the model contains the recently supported trilu operator and its index arguments mix int32 and int64 types.
The minimal reproducible example in the Steps to reproduce section is derived from an ONNX model exported by PyTorch, which uses int64 shape arguments. TVM's frontend translator mixes these with int32 constants, so compilation fails with an int32/int64 type mismatch in check_op:
tvm/python/tvm/topi/transform.py
Line 1057 in bdcfa01
| check_position = check_op(row_index, col_index - k) |
A quick fix would be to align the integer types of row_index and col_index - k before calling check_op.
Environment
Commit fa17da22c73fb9e95c27e4c28130835b628caf6b on Ubuntu 20.04.
Steps to reproduce
Minimized reproducible example:
import tvm
from tvm import relay
x1 = relay.var("x1", shape=[2, 1], dtype="float32")
x2 = relay.var("x2", shape=(1, 1, 1, 1), dtype="float32")
x3 = relay.var("x3", shape=(), dtype="int64")
v0 = relay.broadcast_to(x1, shape=relay.const([2, 1], dtype="int64"))
v2 = relay.divide(x2, v0)
v3 = relay.trilu(v0, x3)
f = relay.Function([x1, x2, x3], relay.Tuple([v2, v3]))
relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
Log (click to expand):
"""
Traceback (most recent call last):
File "test.py", line 12, in <module>
relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
...
25: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
24: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
23: _ZN3tvm5relay9
22: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
21: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
20: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
19: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
18: _ZZN3tvm5relay11ExprFunc
17: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::TupleNode const*)
16: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
15: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
14: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
13: _ZZN3tvm5relay11ExprFunc
12: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
11: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
10: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&)
9: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
8: tvm::relay::tec::PrimFuncFor(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply)
7: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply)
6: tvm::relay::tec::LowerToTECompute::Lower(tvm::relay::Function const&)
5: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
4: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
3: tvm::NodeFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) const
2: _ZZN3tvm5relay11ExprFunc
1: tvm::relay::tec::LowerToTECompute::VisitExpr_(tvm::relay::CallNode const*)
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
rv = local_pyfunc(*pyargs)
File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 317, in lower_call
best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 207, in select_implementation
outs = impl.compute(attrs, inputs, out_type)
File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/op.py", line 126, in compute
return _OpImplementationCompute(self, attrs, inputs, out_type)
File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
3: TVMFuncCall
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::$_3> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
rv = local_pyfunc(*pyargs)
File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/strategy/generic.py", line 1489, in _compute_trilu
topi_compute(
File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1061, in trilu
return te.compute(data.shape, _apply_trilu, name="trilu")
File "/home/jiawei/dev/tvm-official-release/python/tvm/te/operation.py", line 132, in compute
body = fcompute(*[v.var for v in dim_var])
File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1057, in _apply_trilu
check_position = check_op(row_index, col_index - k)
File "/home/jiawei/dev/tvm-official-release/python/tvm/tir/expr.py", line 881, in __init__
self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore
File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/object.py", line 145, in __init_handle_by_constructor__
handle = __init_by_constructor__(fconstructor, args)
File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
raise get_last_ffi_error()
2: TVMFuncCall
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::tir::LE (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::tir::$_51>(tvm::tir::$_51, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: tvm::tir::LE::LE(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
File "/home/jiawei/dev/tvm-official-release/src/tir/ir/expr.cc", line 459
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int32 vs. int64
"""
Triage
Please refer to the list of label tags linked above to find the relevant tags and add them here in a bullet format (example below).
- needs-triage
cc: @jwfromm