Closed
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
There appears to be a bug in the PyTorch frontend.
Environment
Ubuntu 22.04.
No CUDA.
PyTorch 2.0.0.
Steps to reproduce
The code below works with TVM commit 1043136 but fails with TVM commit e9cf04e.
After some debugging, the problem seems to come from a change to the convert_params function.
from tvm import relay
import torch

torch.set_grad_enabled(False)


class ExampleModel(torch.nn.Module):
    def __init__(self, num_layer=2):
        super().__init__()
        self.biases = torch.nn.ParameterList([torch.randn(10)] * num_layer)
        self.weights = torch.nn.ParameterList([torch.randn(10, 10)] * num_layer)
        # self.biases = [torch.randn(10)] * num_layer  # this works fine
        # self.weights = [torch.randn(10, 10)] * num_layer  # this works fine

    def forward(self, x):
        for i in range(len(self.weights) - 1):
            x = torch.addmm(self.biases[i], x, self.weights[i])
        return torch.addmm(self.biases[-1], x, self.weights[-1])


x = torch.randn(20, 10)
model = ExampleModel()
script_model = torch.jit.trace(model, x)
input_infos = [("x", x.shape)]
mod, params = relay.frontend.from_pytorch(script_model, input_infos)

Changing this code in convert_params:

var_name = attr_name_sep.join(
    [source_map[_get_users(getattrs[-1])[0]], full_attr.split(attr_name_sep)[-1]]
)

into this makes the example work, but I am not sure whether it is the right fix:

var_name = attr_name_sep.join(
    [source_map[_get_users(getattrs[-1])[0]]] + full_attr.split(attr_name_sep)[-2:]
)

Replacing ParameterList with a plain Python list (the commented-out lines in __init__) also works.
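One possible explanation, consistent with the error below: with ParameterList, the traced attribute paths are presumably of the form biases.0, weights.0, and so on. Keeping only the last path component would then map biases.0 and weights.0 to the same variable name, while keeping the last two components keeps them distinct. The sketch below reproduces only the string logic of the two var_name expressions. It assumes "." as the separator and a single hypothetical prefix (PREFIX) standing in for source_map[...], as would happen if both attributes feed the same aten::addmm node. The real values come from the TorchScript graph and may differ.

```python
# Sketch of the two naming schemes in convert_params (string logic only).
# ASSUMPTIONS: "." as attr_name_sep, a hypothetical shared PREFIX in place of
# source_map[_get_users(getattrs[-1])[0]], and attribute paths like "biases.0".
from typing import Callable, Dict, List

ATTR_NAME_SEP = "."
PREFIX = "addmm_0"  # hypothetical stand-in for the source_map entry


def name_last_component(full_attr: str) -> str:
    """Naming at e9cf04e: keep only the last path component."""
    return ATTR_NAME_SEP.join([PREFIX, full_attr.split(ATTR_NAME_SEP)[-1]])


def name_last_two_components(full_attr: str) -> str:
    """Proposed change: keep the last two path components."""
    return ATTR_NAME_SEP.join([PREFIX] + full_attr.split(ATTR_NAME_SEP)[-2:])


def find_collisions(attrs: List[str], namer: Callable[[str], str]) -> List[str]:
    """Return the generated names that more than one attribute maps to."""
    owners: Dict[str, List[str]] = {}
    for attr in attrs:
        owners.setdefault(namer(attr), []).append(attr)
    return sorted(name for name, srcs in owners.items() if len(srcs) > 1)


attrs = ["biases.0", "biases.1", "weights.0", "weights.1"]
print(find_collisions(attrs, name_last_component))       # ['addmm_0.0', 'addmm_0.1']
print(find_collisions(attrs, name_last_two_components))  # []
```

Under these assumptions, each bias shares a name with the weight at the same index, which would let a [10, 10] weight stand in for a [10] bias. The two-component scheme produces four distinct names.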
The script fails with the following output:
$ python hb_convert_error.py
Traceback (most recent call last):
File "/home/ubuntu/workspace/sandbox/tvm_/frontend/hb_convert_error.py", line 25, in <module>
mod, params = relay.frontend.from_pytorch(script_model, input_infos)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 5258, in from_pytorch
outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 4484, in convert_operators
self.record_output_type(relay_out)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 238, in record_output_type
self.infer_type_with_prelude(output)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 174, in infer_type_with_prelude
body = self.infer_type(val, self.prelude.mod)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type
new_mod = transform.InferType()(new_mod)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/ir/transform.py", line 160, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
8: tvm::transform::Pass::operator()(tvm::IRModule) const
7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: _ZN3tvm7runtime13PackedFuncObj
4: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
3: tvm::DiagnosticContext::Render()
2: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1}>(tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
File "/home/ubuntu/workspace/sandbox/.dep/tvm/src/ir/diagnostic.cc", line 264
TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
Error: Incompatible broadcast type TensorType([20, 10], float32) and TensorType([10, 10], float32)
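The last error line is what one would expect if the bias argument of an addmm received a [10, 10] tensor. torch.addmm(input, mat1, mat2) adds input to the product mat1 @ mat2, which is [20, 10] here, so the bias must broadcast to [20, 10]. The minimal pure-Python sketch below applies the usual NumPy/PyTorch broadcasting rule: align shapes from the right, and require each dimension pair to be equal or to contain a 1. It shows that the intended [10] bias broadcasts while a [10, 10] tensor does not. The function name broadcast_shape is illustrative only and is not a TVM or PyTorch API.

```python
# Illustrative broadcasting check following NumPy/PyTorch rules; not a TVM API.
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Return the broadcast shape of a and b, or raise ValueError."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both are aligned from the right.
    pa = [1] * (ndim - len(a)) + list(a)
    pb = [1] * (ndim - len(b)) + list(b)
    out = []
    for da, db in zip(pa, pb):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"Incompatible broadcast: {tuple(a)} and {tuple(b)}")
    return tuple(out)


product = (20, 10)  # shape of x @ weights[i] in the reproducer
print(broadcast_shape((10,), product))  # intended bias shape: (20, 10)
try:
    broadcast_shape((10, 10), product)  # a weight-shaped tensor in the bias slot
except ValueError as err:
    print(err)
```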
Triage
- frontend:pytorch