diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 12b460e859ac..c1cbd3416c57 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -102,6 +102,8 @@ def _retrieve_args(self, node): return [self._retrieve_args(x) for x in node] elif isinstance(node, dict): return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + elif node is None: + return relax.op.null_value() else: return node diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 4b0672ccc144..b35af088b530 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6028,5 +6028,27 @@ def forward(self, x): np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) +def test_tensor_none_tuple(): + example_args = (torch.tensor([1.0, 2.0, 3.0]),) + + class TensorNoneModel(Module): + def forward(self, x): + return x + 1, None + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3,), dtype="float32") + ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Object): + with R.dataflow(): + lv: R.Tensor((3,), dtype="float32") = R.add(x, R.const(1.0, "float32")) + gv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Object) = (lv, R.null_value()) + R.output(gv) + return gv + + verify_model(TensorNoneModel(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main()