
[Bug] Tutorial code fails on macOS #18373

@garfield0xff


Expected behavior

The code from the End-to-End Optimize tutorial in the TVM 0.22 documentation should run without errors.

Output

Traceback (most recent call last):
  ...
  File ".../base_fx_graph_translator.py", line 938, in _conv2d
    return self._conv2d_impl(
           ^^^^^^^^^^^^^^^^^^
  File ".../base_fx_graph_translator.py", line 925, in _conv2d_impl
    assert len(self.shape_of(bias)) == 1
               ^^^^^^^^^^^^^^^^^^^
  File ".../base_fx_graph_translator.py", line 84, in shape_of
    if not isinstance(tensor.struct_info, relax.TensorStructInfo):
                      ^^^^^^^^^^^^^^^^^^

tvm.error.InternalError: Check failed: (ptr) is false: The struct_info is not populated, check if you have normalized the expr
[16:55:59] /block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!

Environment

  • OS: macOS Sequoia (Apple M2)
  • TVM: 0.22.dev0
  • Python: 3.11.14
  • PyTorch: 2.8.0

Steps to reproduce

import os
import numpy as np
import torch
from torch.export import export
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

from tvm import relax
from tvm.relax.frontend.torch import from_exported_program

# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)

# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"

if not IS_IN_CI:
    # Convert the model to IRModule
    with torch.no_grad():
        exported_program = export(torch_model, example_args)
        mod = from_exported_program(exported_program, keep_params_as_input=True)

    mod, params = relax.frontend.detach_params(mod)
    mod.show()

Note

As a workaround, I updated _conv2d_impl to return early when bias is either Python None or a relax.null_value() call, before shape_of is called and the bias is reshaped. This prevents the assertion from firing when a Conv2d operation has no bias. With this change the code runs, but I'm not sure this is the correct fix.

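    def _conv2d_impl(
        self,
        x: relax.Expr,
        weight: relax.Expr,
        bias: Optional[relax.Expr],
        strides: Optional[Tuple],
        padding: Optional[Tuple],
        dilation: Optional[Tuple],
        groups: int,
    ):
        conv2d = self.block_builder.emit(
            relax.op.nn.conv2d(
                x,
                weight,
                strides=strides,
                padding=padding,
                dilation=dilation,
                groups=groups,
                data_layout="NCHW",
                kernel_layout="OIHW",
                out_dtype="float32",
            )
        )

        # No bias passed as Python None: return the bare convolution.
        if bias is None:
            return conv2d
        # Added in this workaround: the bias can also arrive as a
        # relax.null_value() call, and calling shape_of() on it appears to
        # raise the "struct_info is not populated" error shown above.
        if isinstance(bias, relax.Call) and bias.op == relax.op.null_value().op:
            return conv2d
        # A real bias must be 1-D (one value per output channel); reshape it to
        # (1, C, 1, 1) so it broadcasts over the NCHW convolution output.
        assert len(self.shape_of(bias)) == 1
        bias = relax.op.reshape(bias, (1, -1, 1, 1))
        return self.block_builder.emit(relax.op.add(conv2d, bias))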
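
The bias branch of the patched function reshapes a 1-D bias of length C to (1, C, 1, 1) so that it broadcasts across the NCHW convolution output, and skips the addition entirely when there is no bias. The NumPy sketch below mirrors that logic on plain arrays. add_conv_bias is a hypothetical helper for illustration, not a TVM API, and it adds a per-channel length check that the translator itself does not perform.

```python
import numpy as np


def add_conv_bias(conv_out, bias=None):
    """Mirror the bias branch of the patched _conv2d_impl on NCHW arrays.

    Hypothetical helper for illustration only; it is not part of TVM.
    """
    # No bias: return the convolution output unchanged (the early return).
    if bias is None:
        return conv_out
    # The translator only asserts the bias is 1-D; this sketch also checks
    # that it has exactly one value per output channel.
    assert bias.ndim == 1 and bias.shape[0] == conv_out.shape[1]
    # Reshape (C,) -> (1, C, 1, 1) so it broadcasts over N, H and W.
    return conv_out + bias.reshape(1, -1, 1, 1)


conv_out = np.zeros((2, 3, 4, 4), dtype=np.float32)
bias = np.array([1.0, 2.0, 3.0], dtype=np.float32)

out = add_conv_bias(conv_out, bias)
print(out.shape)                            # (2, 3, 4, 4)
print(add_conv_bias(conv_out) is conv_out)  # True
```

Every spatial position in channel c receives bias[c], which is why the bias has to be 1-D before the reshape.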

Triage

  • needs-triage
  • bug
