Closed
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
In the latest TVM (v0.13.dev0), loading a traced PyTorch model that contains AdaptiveAvgPool2d with a None entry in output_size fails with "Expected Array[PrimExpr], but got Array[index 1: relay.Constant]". This bug is similar to this PR.
What causes this crash? Any comments would be appreciated. Thanks!
Expected behavior
The model loads successfully.
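For reference, PyTorch documents that a None entry in AdaptiveAvgPool2d's output_size keeps the corresponding input dimension, so the reproduction below should yield a (1, 3, 3, 6) output. The sketch computes that shape with plain integers; expected_output_shape is a hypothetical helper, not a TVM or PyTorch API. The error reports index 1, which is where the None sits in output_size, so one plausible (unconfirmed) reading is that the frontend placed a relay.Constant in that slot instead of a plain integer like the one computed here.

```python
from typing import Optional, Sequence, Tuple


def expected_output_shape(
    input_shape: Sequence[int],
    output_size: Tuple[Optional[int], Optional[int]],
) -> Tuple[int, int, int, int]:
    """Output shape of 2-D adaptive average pooling on an NCHW input.

    Hypothetical helper: follows PyTorch's documented AdaptiveAvgPool2d rule
    that a None entry in output_size keeps the matching input dimension.
    """
    if len(input_shape) != 4 or len(output_size) != 2:
        raise ValueError("expected an NCHW input shape and a 2-element output_size")
    n, c, h, w = input_shape
    # Replace each None with the corresponding spatial size (H, then W).
    resolved = [
        dim if size is None else int(size)
        for size, dim in zip(output_size, (h, w))
    ]
    return (n, c, resolved[0], resolved[1])


print(expected_output_shape((1, 3, 5, 6), (3, None)))  # (1, 3, 3, 6)
```

For this static input, writing torch.nn.AdaptiveAvgPool2d((3, 6)) should be numerically equivalent and might sidestep the failing conversion; that is an untested assumption.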
Actual behavior
  File "test.py", line 11, in <module>
    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
  File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 5002, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 4260, in convert_operators
    _get_input_types(op_node, outputs, default_dtype=self.default_dtype),
  File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 1099, in adaptive_avg_pool
    return func(data)
  File "/workplace/software/tvm/tvm_/python/tvm/relay/frontend/pytorch.py", line 1094, in func
    return op(x, output_size=output_size)
  File "/workplace/software/tvm/tvm_/python/tvm/relay/op/nn/nn.py", line 3388, in adaptive_avg_pool2d
    return _make.adaptive_avg_pool2d(data, output_size, layout, out_layout)
  File "/workplace/software/tvm/tvm_/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
4: TVMFuncCall
3: _ZN3tvm7runtime13Pac
2: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
1: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::Array<tvm::PrimExpr, void><tvm::runtime::Array<tvm::PrimExpr, void> >() const
0: _ZN3tvm7runtime6detail
6: TVMFuncCall
5: _ZN3tvm7runtime13Pac
4: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::PrimExpr, void>, tvm::runtime::String, tvm::runtime::String), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
3: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::Array<tvm::PrimExpr, void><tvm::runtime::Array<tvm::PrimExpr, void> >() const
2: tvm::runtime::TVMMovableArgValue_::operator tvm::runtime::Array<tvm::PrimExpr, void><tvm::runtime::Array<tvm::PrimExpr, void>, void>() const
1: tvm::runtime::Array<tvm::PrimExpr, void> tvm::runtime::TVMPODValue_::AsObjectRef<tvm::runtime::Array<tvm::PrimExpr, void> >() const
0: _ZN3tvm7runtime6detail
File "/workplace/software/tvm/tvm_/include/tvm/runtime/packed_func.h", line 777
TVMError: In function relay.op.nn._make.adaptive_avg_pool2d(0: RelayExpr, 1: Array<PrimExpr>, 2: runtime.String, 3: runtime.String) -> RelayExpr: error while converting argument 1: [03:05:14] /workplace/software/tvm/tvm_/include/tvm/runtime/packed_func.h:1866:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (!checked_type.defined()) is false: Expected Array[PrimExpr], but got Array[index 1: relay.Constant]
Environment
TVM v0.13.dev0; other environment details (operating system, etc.) were not provided.
Steps to reproduce
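import torch
from tvm import relay

# Adaptive average pooling with a None entry in output_size:
# the output height is fixed to 3 and the input width (6) is kept.
m = torch.nn.AdaptiveAvgPool2d((3, None))
input_data = [torch.randn([1, 3, 5, 6], dtype=torch.float32)]
trace = torch.jit.trace(m, input_data)

# Convert the traced module to Relay; this call raises the TVMError shown above.
input_shapes = [("input0", torch.Size([1, 3, 5, 6]))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)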
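Once conversion succeeds, a plain-Python reference can help cross-check the numbers TVM produces for this case. The sketch below is neither TVM's nor PyTorch's code: it assumes PyTorch's adaptive-pooling bin boundaries, start = floor(i * in / out) and end = ceil((i + 1) * in / out), and works on a single H x W plane stored as nested lists. adaptive_bins and adaptive_avg_pool2d_plane are hypothetical helper names.

```python
from typing import List, Optional, Tuple


def adaptive_bins(in_size: int, out_size: int) -> List[Tuple[int, int]]:
    """Half-open [start, end) input ranges averaged into each output index.

    Assumes PyTorch's bin rule: start = floor(i * in / out),
    end = ceil((i + 1) * in / out).
    """
    bins = []
    for i in range(out_size):
        start = (i * in_size) // out_size          # floor division
        end = -(-((i + 1) * in_size) // out_size)  # ceiling division
        bins.append((start, end))
    return bins


def adaptive_avg_pool2d_plane(
    plane: List[List[float]],
    output_size: Tuple[Optional[int], Optional[int]],
) -> List[List[float]]:
    """Hypothetical reference: adaptive average pooling of one H x W plane."""
    h, w = len(plane), len(plane[0])
    # A None entry keeps the input size along that axis.
    out_h = h if output_size[0] is None else output_size[0]
    out_w = w if output_size[1] is None else output_size[1]
    result = []
    for r0, r1 in adaptive_bins(h, out_h):
        row = []
        for c0, c1 in adaptive_bins(w, out_w):
            window = [plane[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(window) / len(window))
        result.append(row)
    return result


print(adaptive_bins(5, 3))  # [(0, 2), (1, 4), (3, 5)]
```

With output_size (3, None) on a 5 x 6 plane, the height bins overlap (rows 0-1, 1-3, 3-4) while every column maps one-to-one, which matches the (1, 3, 3, 6) shape expected from the reproduction.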
Triage
- frontend:pytorch