Closed
Labels: needs-triage, type: bug
Description
Expected behavior
Translating a small model that contains nn.GRU and was exported with torch.export fails in the TVM Relax PyTorch frontend (from_exported_program) with:
AssertionError: Unsupported function types ['gru.input']
This looks like missing coverage in from_exported_program for the fused RNN op aten::gru (overload gru.input). Even if GRU is out of scope today, it would help to either (a) add a lowering for GRU, or (b) fail with a more actionable message or a documentation pointer that describes the current RNN support and possible workarounds.
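Option (b) could look roughly like the following. This is a hypothetical helper, not TVM code: format_unsupported_ops and FUSED_RNN_OPS are names made up for illustration. The only assumption about the frontend is that it has a list of unconverted op names in the '<op>.<overload>' form shown in the error above. The suggested workaround text further assumes, without verification, that the frontend converts linear, sigmoid and tanh.

```python
# Hypothetical helper, not part of TVM: sketches option (b) by turning the list
# of unconverted op names into a message that names the RNN limitation.
FUSED_RNN_OPS = {"gru", "lstm", "rnn_tanh", "rnn_relu"}  # aten fused RNN base names


def format_unsupported_ops(missing):
    """Return an error message for op names such as 'gru.input'."""
    lines = [f"Unsupported function types {sorted(missing)}"]
    # Names have the form '<op>.<overload>', so 'gru.input' -> 'gru'.
    rnn = sorted(name for name in missing if name.split(".")[0] in FUSED_RNN_OPS)
    if rnn:
        lines.append(
            f"Fused RNN ops {rnn} have no converter in this frontend yet. "
            "A possible workaround is to unroll the recurrence into linear, "
            "sigmoid and tanh ops in the model before calling torch.export."
        )
    return "\n".join(lines)


print(format_unsupported_ops(["gru.input"]))
```

The first line keeps the existing message unchanged, so anything matching on "Unsupported function types" keeps working. The RNN hint is only appended when a fused RNN op is actually among the missing names.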
Actual behavior
```
=== Versions ===
Python : 3.10.16 | packaged by conda-forge | (main, Apr 8, 2025, 20:53:32) [GCC 13.3.0]
Torch : 2.8.0+cu128
TVM : 0.21.0
================
torch.export: OK
TVM from_exported_program: FAILED as expected
Error: Unsupported function types ['gru.input']
Traceback (most recent call last):
  ...
  File ".../tvm/relax/frontend/torch/base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['gru.input']
```
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
Steps to reproduce
```python
# minimal_gru_from_exported_program_repro.py
import torch
import torch.nn as nn
import sys


def print_env():
    print("=== Versions ===")
    print("Python :", sys.version.replace("\n", " "))
    print("Torch :", torch.__version__)
    try:
        import torchaudio, torchvision
        print("torchaudio:", getattr(torchaudio, '__version__', 'unknown'))
        print("torchvision:", getattr(torchvision, '__version__', 'unknown'))
    except Exception:
        pass
    try:
        import tvm
        print("TVM :", tvm.__version__)
    except Exception as e:
        print("TVM : import error ->", e)
    print("================")


class M(nn.Module):
    """
    Tiny model that triggers a fused GRU op (aten::gru/gru.input)
    in the ExportedProgram graph.
    """
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(12, 10)  # map to GRU input_size=10
        self.gru = nn.GRU(10, 20, num_layers=2)

    def forward(self, x):
        # x: (B=1, 12)
        x = self.proj(x)    # (1, 10)
        x = x.unsqueeze(0)  # (T=1, B=1, 10) → seq len 1
        y, h = self.gru(x)  # y: (1, 1, 20)
        return y


def main():
    torch.manual_seed(0)
    print_env()
    m = M().eval()
    inp = torch.randn(1, 12)

    # Eager sanity
    with torch.inference_mode():
        _ = m(inp)

    # torch.export succeeds
    from torch.export import export as torch_export
    ep = torch_export(m, (inp,))
    print("torch.export: OK")

    # TVM: translate ExportedProgram → Relax
    from tvm.relax.frontend.torch import from_exported_program
    try:
        mod = from_exported_program(ep)
        print("TVM from_exported_program: OK (unexpected)")
    except AssertionError as e:
        print("TVM from_exported_program: FAILED as expected")
        print("Error:", e)  # shows "Unsupported function types ['gru.input']"
        raise
    except Exception:
        print("TVM from_exported_program: FAILED (different exception)")
        raise


if __name__ == "__main__":
    main()
```
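Until GRU is lowered, one possible workaround is to replace the fused op with its per-step equations from the nn.GRU documentation (reset gate r, update gate z, candidate n). This is an illustration under stated assumptions, not a fix in TVM. UnrolledGRU is a hypothetical wrapper that reuses an existing nn.GRU's weights. It only handles the configuration used in the repro: unidirectional, seq-first, bias=True, eval mode, and a zero initial state. It unrolls over a sequence length that is fixed at export time. It has not been checked that every resulting op (linear, sigmoid, tanh, chunk, stack, new_zeros, elementwise add/sub/mul) has a converter in TVM 0.21.0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnrolledGRU(nn.Module):
    """Hypothetical workaround: evaluates an nn.GRU's equations with elementary
    ops so torch.export does not emit the fused aten::gru op.

    Assumes batch_first=False, bidirectional=False, bias=True, eval mode
    (no inter-layer dropout), and a zero initial hidden state.
    """

    def __init__(self, gru: nn.GRU):
        super().__init__()
        if gru.batch_first or gru.bidirectional or not gru.bias:
            raise ValueError("only seq-first, unidirectional GRUs with bias are handled")
        self.gru = gru  # reuse the trained parameters; the fused forward is never called

    def forward(self, x):
        seq_len, batch, _ = x.shape  # seq_len is unrolled, so it is fixed at export time
        hidden = self.gru.hidden_size
        layer_input = x
        final_states = []
        for k in range(self.gru.num_layers):
            w_ih = getattr(self.gru, f"weight_ih_l{k}")  # (3*hidden, in), gate order r, z, n
            w_hh = getattr(self.gru, f"weight_hh_l{k}")  # (3*hidden, hidden)
            b_ih = getattr(self.gru, f"bias_ih_l{k}")
            b_hh = getattr(self.gru, f"bias_hh_l{k}")
            h = x.new_zeros(batch, hidden)
            outputs = []
            for t in range(seq_len):
                i_r, i_z, i_n = F.linear(layer_input[t], w_ih, b_ih).chunk(3, dim=1)
                h_r, h_z, h_n = F.linear(h, w_hh, b_hh).chunk(3, dim=1)
                r = torch.sigmoid(i_r + h_r)   # reset gate
                z = torch.sigmoid(i_z + h_z)   # update gate
                n = torch.tanh(i_n + r * h_n)  # candidate state
                h = n + z * (h - n)            # equals (1 - z) * n + z * h
                outputs.append(h)
            layer_input = torch.stack(outputs, dim=0)  # (seq_len, batch, hidden)
            final_states.append(h)
        # Same return structure as nn.GRU: (output, h_n)
        return layer_input, torch.stack(final_states, dim=0)
```

In the repro, setting self.gru = UnrolledGRU(nn.GRU(10, 20, num_layers=2)) keeps the call y, h = self.gru(x) unchanged. Parameter names gain an extra gru. prefix, so state_dict keys from an existing checkpoint would need remapping. Writing the update as n + z * (h - n) is just a design choice to avoid a reverse-subtract op; it is algebraically the same as the documented form.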