
[Bug] [feature][relax.frontend.torch] from_exported_program rejects fused GRU (aten::gru.input) with “Unsupported function types ['gru.input']” #18356

@tinywisdom


Expected behavior

from_exported_program should translate a tiny torch.export-ed model containing nn.GRU into a Relax IRModule. Instead, the TVM Relax PyTorch frontend fails with:

AssertionError: Unsupported function types ['gru.input']

This looks like missing coverage in from_exported_program for the fused RNN op (aten::gru, overload gru.input). Even if GRU is unsupported today, it would help to either (a) add a lowering for GRU, or (b) fail with a more actionable message or documentation pointer that describes the current RNN support scope and possible workarounds.
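For option (a), a lowering has to reproduce the recurrence that the PyTorch nn.GRU documentation defines for each time step t and each layer, using only primitive ops (matrix multiply, add, sigmoid, tanh, elementwise multiply). With the gate weights stacked in the order (r, z, n) inside weight_ih_l{k} and weight_hh_l{k}, and σ the logistic sigmoid:

```math
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

In the repro below (num_layers=2, T=1), the second layer takes the first layer's h_t as its x_t, and h_0 defaults to zeros. This restates the documented PyTorch semantics; it is not a claim about how TVM would structure the lowering.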

Actual behavior

=== Versions ===
Python : 3.10.16 | packaged by conda-forge | (main, Apr  8, 2025, 20:53:32) [GCC 13.3.0]
Torch  : 2.8.0+cu128
TVM   : 0.21.0
================
torch.export: OK
TVM from_exported_program: FAILED as expected
Error: Unsupported function types ['gru.input']
Traceback (most recent call last):
  ...
  File ".../tvm/relax/frontend/torch/base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['gru.input']

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: release v0.21.0
  • Python: 3.10.16
  • LLVM: 17.0.6

Steps to reproduce

# minimal_gru_from_exported_program_repro.py
import torch
import torch.nn as nn
import sys

def print_env():
    print("=== Versions ===")
    print("Python :", sys.version.replace("\n", " "))
    print("Torch  :", torch.__version__)
    try:
        import torchaudio, torchvision
        print("torchaudio:", getattr(torchaudio, '__version__', 'unknown'))
        print("torchvision:", getattr(torchvision, '__version__', 'unknown'))
    except Exception:
        pass
    try:
        import tvm
        print("TVM   :", tvm.__version__)
    except Exception as e:
        print("TVM   : import error ->", e)
    print("================")

class M(nn.Module):
    """
    Tiny model that triggers a fused GRU op (aten::gru/gru.input)
    in the ExportedProgram graph.
    """
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(12, 10)   # map to GRU input_size=10
        self.gru  = nn.GRU(10, 20, num_layers=2)

    def forward(self, x):
        # x: (B=1, 12)
        x = self.proj(x)                # (1, 10)
        x = x.unsqueeze(0)              # (T=1, B=1, 10) → seq len 1
        y, h = self.gru(x)              # y: (1, 1, 20)
        return y

def main():
    torch.manual_seed(0)
    print_env()

    m = M().eval()
    inp = torch.randn(1, 12)

    # Eager sanity
    with torch.inference_mode():
        _ = m(inp)

    # torch.export succeeds
    from torch.export import export as torch_export
    ep = torch_export(m, (inp,))
    print("torch.export: OK")

    # TVM: translate ExportedProgram → Relax
    from tvm.relax.frontend.torch import from_exported_program
    try:
        mod = from_exported_program(ep)
        print("TVM from_exported_program: OK (unexpected)")
    except AssertionError as e:
        print("TVM from_exported_program: FAILED as expected")
        print("Error:", e)  # shows "Unsupported function types ['gru.input']"
        raise
    except Exception:
        print("TVM from_exported_program: FAILED (different exception)")
        raise

if __name__ == "__main__":
    main()

Triage

  • needs-triage
  • bug
