
[Bug] shape int32-int64 check error in trilu's te.compute #13029

@ganler

Description

Expected behavior

TVM should successfully compile a model whose operators are supported.

Actual behavior

Compilation fails when the model contains the recently added trilu operator and its shape arguments are int64.

The minimal reproducible example in the Steps to reproduce section is derived from an ONNX model exported by PyTorch. Such models use int64 shape arguments, which get mixed with int32 constants in TVM's frontend translator. Compilation then fails with an int32-int64 type mismatch in check_op:

check_position = check_op(row_index, col_index - k)

A quick fix would be to cast row_index and col_index - k to a common integer type before calling check_op.

Environment

TVM commit fa17da22c73fb9e95c27e4c28130835b628caf6b on Ubuntu 20.04.

Steps to reproduce

Minimized reproducible example:

import tvm
from tvm import relay

x1 = relay.var("x1", shape=[2, 1], dtype="float32")
x2 = relay.var("x2", shape=(1, 1, 1, 1), dtype="float32")
x3 = relay.var("x3", shape=(), dtype="int64")
v0 = relay.broadcast_to(x1, shape=relay.const([2, 1], dtype="int64"))
v2 = relay.divide(x2, v0)
v3 = relay.trilu(v0, x3)

f = relay.Function([x1, x2, x3], relay.Tuple([v2, v3]))
relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
Log:
"""
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
 ...
  25: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  24: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  23: _ZN3tvm5relay9
  22: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  21: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  20: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  19: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  18: _ZZN3tvm5relay11ExprFunc
  17: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::TupleNode const*)
  16: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  15: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  14: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  13: _ZZN3tvm5relay11ExprFunc
  12: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  11: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  10: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&)
  9: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
  8: tvm::relay::tec::PrimFuncFor(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply)
  7: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply)
  6: tvm::relay::tec::LowerToTECompute::Lower(tvm::relay::Function const&)
  5: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  4: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  3: tvm::NodeFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) const
  2: _ZZN3tvm5relay11ExprFunc
  1: tvm::relay::tec::LowerToTECompute::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 317, in lower_call
    best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 207, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/op.py", line 126, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::$_3> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/strategy/generic.py", line 1489, in _compute_trilu
    topi_compute(
  File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1061, in trilu
    return te.compute(data.shape, _apply_trilu, name="trilu")
  File "/home/jiawei/dev/tvm-official-release/python/tvm/te/operation.py", line 132, in compute
    body = fcompute(*[v.var for v in dim_var])
  File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1057, in _apply_trilu
    check_position = check_op(row_index, col_index - k)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/tir/expr.py", line 881, in __init__
    self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span)  # type: ignore
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/object.py", line 145, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
  2: TVMFuncCall
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::tir::LE (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::tir::$_51>(tvm::tir::$_51, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::tir::LE::LE(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/jiawei/dev/tvm-official-release/src/tir/ir/expr.cc", line 459
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int32 vs. int64
"""

Triage

Please refer to the list of label tags linked above to find the relevant tags and add them here in a bullet format (example below).

  • needs-triage

cc: @jwfromm

Metadata

Assignees

No one assigned

    Labels

    needs-triage, type: bug
