Description
Expected behavior
Running the End-to-End Optimize Model tutorial code from the TVM 0.22 documentation should convert the exported ResNet-18 program to a Relax IRModule and print it without errors.
Actual behavior
Traceback (most recent call last):
^^^^^^
File ".../base_fx_graph_translator.py", line 938, in _conv2d
return self._conv2d_impl(
^^^^^^^^^^^^^^^^^^
File ".../base_fx_graph_translator.py", line 925, in _conv2d_impl
assert len(self.shape_of(bias)) == 1
^^^^^^^^^^^^^^^^^^^
File ".../base_fx_graph_translator.py", line 84, in shape_of
if not isinstance(tensor.struct_info, relax.TensorStructInfo):
^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Check failed: (ptr) is false: The struct_info is not populated, check if you have normalized the expr
[16:55:59] /block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
Environment
- OS: macOS Sequoia (Apple M2)
- TVM: 0.22.dev0
- Python: 3.11.14
- PyTorch: 2.8.0
Steps to reproduce
import os
import numpy as np
import torch
from torch.export import export
from torchvision.models.resnet import ResNet18_Weights, resnet18
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"
if not IS_IN_CI:
    # Convert the model to IRModule
    with torch.no_grad():
        exported_program = export(torch_model, example_args)
        mod = from_exported_program(exported_program, keep_params_as_input=True)

    mod, params = relax.frontend.detach_params(mod)
    mod.show()
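For what it's worth, the assert seems to be reached because ResNet-18's convolution layers are created with bias=False (the bias is handled by the following BatchNorm), so the frontend ends up with a null bias for every conv. A smaller, untested sketch that should hit the same code path (the BiasFreeConv module name is mine, not from the tutorial):
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class BiasFreeConv(torch.nn.Module):
    # Hypothetical module for illustration: a single conv without bias,
    # mirroring the bias-free convolutions inside ResNet-18.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)

    def forward(self, x):
        return self.conv(x)

with torch.no_grad():
    ep = export(BiasFreeConv().eval(), (torch.randn(1, 3, 16, 16),))
    mod = from_exported_program(ep, keep_params_as_input=True)
mod.show()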
Note
I updated _conv2d_impl to handle both Python None and relax.null_value() before calling shape_of and reshaping the bias. This prevents the assert from triggering when a Conv2d operation has no bias. While this makes the code work, I am not sure it is the correct approach.
def _conv2d_impl(
    self,
    x: relax.Expr,
    weight: relax.Expr,
    bias: Optional[relax.Expr],
    strides: Optional[Tuple],
    padding: Optional[Tuple],
    dilation: Optional[Tuple],
    groups: Optional[Tuple],
):
    conv2d = self.block_builder.emit(
        relax.op.nn.conv2d(
            x,
            weight,
            strides=strides,
            padding=padding,
            dilation=dilation,
            groups=groups,
            data_layout="NCHW",
            kernel_layout="OIHW",
            out_dtype="float32",
        )
    )

    if bias is None:
        return conv2d
    # Treat a relax.null_value() placeholder the same as "no bias":
    # calling shape_of on it fails because its struct_info is not populated.
    if isinstance(bias, relax.Call) and bias.op == relax.op.null_value().op:
        return conv2d

    # add bias
    assert len(self.shape_of(bias)) == 1
    bias = relax.op.reshape(bias, (1, -1, 1, 1))
    return self.block_builder.emit(relax.op.add(conv2d, bias))
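If it helps review, the new check could also be factored into a small helper so other optional-bias ops can reuse it; this is just a sketch with a name I made up, not existing TVM code:
from tvm import relax

def _is_null_bias(bias) -> bool:
    # Hypothetical helper (my naming): True for Python None or for the
    # relax.null_value() placeholder passed when a conv has no bias.
    if bias is None:
        return True
    return isinstance(bias, relax.Call) and bias.op == relax.op.null_value().op
Alternatively, the caller (_conv2d) could normalize the null placeholder to Python None before invoking _conv2d_impl, so the existing `if bias is None` check would be enough; I am not sure which direction maintainers would prefer.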
Triage
- needs-triage
- bug