[Bug] [Relay][PyTorch] Can't properly convert torch.nn.ParameterList #16150

@mshr-h

Description

There appears to be a bug in the PyTorch frontend.

Environment

Ubuntu 22.04.
No CUDA.
PyTorch 2.0.0.

Steps to reproduce

The code below works with TVM commit 1043136 but fails with TVM commit e9cf04e.
I did some debugging, and the regression appears to come from a change to the convert_params function.

from tvm import relay
import torch

torch.set_grad_enabled(False)

class ExampleModel(torch.nn.Module):
  def __init__(self, num_layer=2):
    super().__init__()
    self.biases = torch.nn.ParameterList([torch.randn(10)] * num_layer)
    self.weights = torch.nn.ParameterList([torch.randn(10, 10)] * num_layer)
    # self.biases = [torch.randn(10)] * num_layer # this works fine
    # self.weights = [torch.randn(10, 10)] * num_layer # this works fine

  def forward(self, x):
    for i in range(len(self.weights) - 1):
      x = torch.addmm(self.biases[i], x, self.weights[i])
    return torch.addmm(self.biases[-1], x, self.weights[-1])

x = torch.randn(20, 10)
model = ExampleModel()
script_model = torch.jit.trace(model, x)

input_infos = [("x", x.shape)]
mod, params = relay.frontend.from_pytorch(script_model, input_infos)

Changing this

            var_name = attr_name_sep.join(
                [source_map[_get_users(getattrs[-1])[0]], full_attr.split(attr_name_sep)[-1]]
            )

to the following in convert_params makes the example work, but I'm not sure whether it's the right fix:

            var_name = attr_name_sep.join(
                [source_map[_get_users(getattrs[-1])[0]]] + 
                full_attr.split(attr_name_sep)[-2:]
            )
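A minimal plain-Python sketch of what the naming change appears to do (it does not call TVM). It assumes that full_attr is the dotted attribute path of the parameter (e.g. "biases.0", as ParameterList names its entries by index) and that biases[i] and weights[i] share the same first user node (the addmm), so source_map returns the same prefix for both. The prefix string "aten::addmm_0" is hypothetical.

```python
# Plain-Python model of the var_name construction in convert_params (no TVM needed).
# Assumption: full_attr is the dotted attribute path and both getattrs share one user node.
attr_name_sep = "."
prefix = "aten::addmm_0"  # hypothetical source_map entry for the shared addmm node


def old_var_name(prefix, full_attr):
    # Current code: keep only the last path component, e.g. "0".
    return attr_name_sep.join([prefix, full_attr.split(attr_name_sep)[-1]])


def new_var_name(prefix, full_attr):
    # Proposed change: keep the last two components, e.g. "biases.0".
    return attr_name_sep.join([prefix] + full_attr.split(attr_name_sep)[-2:])


attrs = ["biases.0", "weights.0"]
old = [old_var_name(prefix, a) for a in attrs]
new = [new_var_name(prefix, a) for a in attrs]
print(old)  # ['aten::addmm_0.0', 'aten::addmm_0.0']  -> same name for bias and weight
print(new)  # ['aten::addmm_0.biases.0', 'aten::addmm_0.weights.0']
```

If this reading is correct, the current code gives biases[0] and weights[0] the same Relay var name, so one tensor can silently stand in for the other. Keeping two components resolves this repro, though deeper nested paths that share their last two components (for example, a hypothetical "a.biases.0" and "b.biases.0" feeding the same node) could in principle still collide.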

Replacing the ParameterLists with plain Python lists (the commented-out lines in the example) also works.

The error output is as follows:

$ python hb_convert_error.py 
Traceback (most recent call last):
  File "/home/ubuntu/workspace/sandbox/tvm_/frontend/hb_convert_error.py", line 25, in <module>
    mod, params = relay.frontend.from_pytorch(script_model, input_infos)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 5258, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 4484, in convert_operators
    self.record_output_type(relay_out)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 238, in record_output_type
    self.infer_type_with_prelude(output)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 174, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/ir/transform.py", line 160, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  8: tvm::transform::Pass::operator()(tvm::IRModule) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: _ZN3tvm7runtime13PackedFuncObj
  4: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  3: tvm::DiagnosticContext::Render()
  2: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1}>(tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/src/ir/diagnostic.cc", line 264
TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
Error: Incompatible broadcast type TensorType([20, 10], float32) and TensorType([10, 10], float32)
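The final line of the error is consistent with a weight tensor being used where a bias is expected: torch.addmm(bias, x, w) computes bias + x @ w, and x @ w has shape [20, 10] here. A [10] bias broadcasts against that, but a [10, 10] tensor does not. The sketch below is a pure-Python version of the NumPy-style broadcasting rule; broadcast_shape is a hypothetical helper for illustration, not a TVM or PyTorch API.

```python
def broadcast_shape(a, b):
    """Return the NumPy-style broadcast of shapes a and b, or raise ValueError."""
    result = []
    # Compare dimensions right-to-left; a missing leading dimension counts as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"Incompatible broadcast: {a} and {b}")
        # If db is 1 the result takes da; otherwise da equals db or is 1.
        result.append(da if db == 1 else db)
    return tuple(reversed(result))


out_shape = (20, 10)  # shape of x @ w in the repro
print(broadcast_shape(out_shape, (10,)))  # (20, 10): a [10] bias is fine
try:
    broadcast_shape(out_shape, (10, 10))  # a [10, 10] weight in the bias slot
except ValueError as err:
    print(err)
```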

Triage


  • frontend:pytorch

cc @shingjan @chunit-quic @masahi @vvchernov @jikechao

Labels: needs-triage, type: bug