diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7c20d1b1a469..3cf07effecaa 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -434,6 +434,9 @@ def create_convert_map( "matmul.default": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), + "mm.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), "max.other": self._binary_op(relax.op.maximum, max), "min.other": self._binary_op(relax.op.minimum, min), "max.default": self._unary_op(relax.op.max), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 2871e3f4cde3..ead341de287a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -5914,6 +5914,32 @@ def main( verify_model(Model(), example_args, {}, Expected) +def test_mm(): + class MatrixMultiply(Module): + def forward(self, a, b): + return torch.mm(a, b) + + example_args = ( + torch.randn(2, 3, dtype=torch.float32), + torch.randn(3, 4, dtype=torch.float32), + ) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + a: R.Tensor((2, 3), dtype="float32"), + b: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32") + gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MatrixMultiply(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main() 1