From a2b82de2b5de9d3c52c98f018edee752f1360680 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sat, 8 Mar 2025 21:38:46 -0500
Subject: [PATCH 1/5] Support expand_as in from_exported_program, with unit test

---
 .../torch/base_fx_graph_translator.py         |   8 ++
 .../torch/exported_program_translator.py      |   1 +
 .../tvm/relax/frontend/torch/fx_translator.py |   1 +
 .../relax/test_from_exported_to_cuda.py       | 111 ++++++++++++++++++
 4 files changed, 121 insertions(+)
 create mode 100644 tests/python/relax/test_from_exported_to_cuda.py

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6ff..92d5875a8987 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -846,6 +846,14 @@ def _expand(self, node: fx.Node) -> relax.Var:
             else:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+    
+    def _expand_as(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        # args[0] is the 'self' tensor
+        # args[1] is the 'other' tensor
+        data = args[0]
+        other_shape = self.shape_of(args[1])  # the shape of 'other'
+        return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
 
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c8d9d12505c6..7c4bdf664312 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -268,6 +268,7 @@ def create_convert_map(
             "concat.default": self._cat,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
+            "expand_as.default": self._expand_as,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
             "select.int": self._select,

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ef98d3c02501..406b6b8c9dc2 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -732,6 +732,7 @@ def create_convert_map(
             "contiguous": lambda node: self.env[node.args[0]],
             "cumsum": self._cumsum,
             "expand": self._expand,
+            "expand_as": self._expand_as,
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
new file mode 100644
index 000000000000..3c04965de534
--- /dev/null
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+import tvm.testing
+import numpy as np
+import torch
+from torch import nn
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
+from torch.nn import Softmax, Upsample
+
+
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
+    """
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
+    as PyTorch when run on CUDA.
+    """
+    raw_data_for_tvm = raw_data.copy() # In case the data is modified
+    torch_data = torch.from_numpy(raw_data)
+    example_args = (torch_data,)
+
+    with torch.no_grad():
+        exported_program = export(torch_module, example_args)
+        mod_from_torch = from_exported_program(
+            exported_program, keep_params_as_input=True
+        )
+
+    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
+    target = tvm.target.Target.from_device(tvm.cuda())
+
+    ex = relax.build(tvm_mod, target=target,
+                     relax_pipeline=relax.get_default_pipeline(target))
+    dev = tvm.device("cuda", 0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
+    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
+    gpu_out = vm["main"](gpu_data, *gpu_params)
+
+    pytorch_out = torch_module(torch_data).detach().numpy()
+    actual = gpu_out[0].numpy()
+    desired = pytorch_out
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5,
+                               atol=1e-5)
+
+
+def test_tensor_expand_as():
+
+    class ExpandAs0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((1, 1, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 4, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs2(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 1, 10))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs3(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 3, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
+
+    torch_module0 = ExpandAs0().eval()
+    torch_module1 = ExpandAs1().eval()
+    torch_module2 = ExpandAs2().eval()
+    torch_module3 = ExpandAs3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3)
+
+if __name__ == "__main__":
+    tvm.testing.main()

From 7ae94d47d6cec979dc7eca9ae6b3db3941788d81 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 9 Mar 2025 17:36:09 -0400
Subject: [PATCH 2/5] black formatter

---
 .../torch/base_fx_graph_translator.py         |  8 ++--
 .../relax/test_from_exported_to_cuda.py       | 37 +++++++++----------
 2 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 92d5875a8987..044f16d6b762 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -37,9 +37,9 @@ def __init__(self) -> None:
         self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.block_builder: relax.BlockBuilder = None
-        self.convert_map: Dict[
-            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
-        ] = self.create_convert_map()
+        self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = (
+            self.create_convert_map()
+        )
 
     ########## Utilities ##########
 
@@ -846,7 +846,7 @@ def _expand(self, node: fx.Node) -> relax.Var:
         else:
             broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
-    
+
     def _expand_as(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         # args[0] is the 'self' tensor

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 3c04965de534..fa643c45fea9 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -3,7 +3,7 @@
 # distributed with this work for additional information
 # regarding copyright ownership.  The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
+# "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
 #
 #   http://www.apache.org/licenses/LICENSE-2.0
@@ -28,25 +28,22 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
     """
-    This util ensures that a torch module can successfully be exported to TVM
-    using torch.export and that the resulting IR program gives the same result
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
     as PyTorch when run on CUDA.
""" - raw_data_for_tvm = raw_data.copy() # In case the data is modified + raw_data_for_tvm = raw_data.copy() # In case the data is modified torch_data = torch.from_numpy(raw_data) example_args = (torch_data,) with torch.no_grad(): exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program( - exported_program, keep_params_as_input=True - ) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) target = tvm.target.Target.from_device(tvm.cuda()) - ex = relax.build(tvm_mod, target=target, - relax_pipeline=relax.get_default_pipeline(target)) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target)) dev = tvm.device("cuda", 0) vm = relax.VirtualMachine(ex, dev) @@ -57,17 +54,16 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): pytorch_out = torch_module(torch_data).detach().numpy() actual = gpu_out[0].numpy() desired = pytorch_out - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, - atol=1e-5) + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) def test_tensor_expand_as(): - + class ExpandAs0(torch.nn.Module): def __init__(self): super().__init__() self.template = torch.ones((1, 1, 1, 1)) - + def forward(self, x): return self.template.expand_as(x) @@ -75,30 +71,30 @@ class ExpandAs1(torch.nn.Module): def __init__(self): super().__init__() self.template = torch.ones((2, 1, 4, 1)) - + def forward(self, x): return self.template.expand_as(x) - + class ExpandAs2(torch.nn.Module): def __init__(self): super().__init__() self.template = torch.ones((2, 1, 1, 10)) - + def forward(self, x): return self.template.expand_as(x) - + class ExpandAs3(torch.nn.Module): def __init__(self): super().__init__() self.template = torch.ones((2, 3, 1, 1)) - + def forward(self, x): return self.template.expand_as(x) - + raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32) torch_module0 = ExpandAs0().eval() - torch_module1 = ExpandAs1().eval() + torch_module1 = ExpandAs1().eval() torch_module2 = ExpandAs2().eval() torch_module3 = ExpandAs3().eval() @@ -107,5 +103,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3) + if __name__ == "__main__": tvm.testing.main() From 0b0233f14fb8387d3b81d6ede2cdf936cd6dc69c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 03:14:13 -0400 Subject: [PATCH 3/5] cuda target in new test --- .../relax/test_from_exported_to_cuda.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index fa643c45fea9..4bd17b591e2e 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,18 +15,17 @@ # specific language governing permissions and limitations # under the License. 
-import tvm
-from tvm import relax
-import tvm.testing
 import numpy as np
 import torch
-from torch import nn
 from torch.export import export
+
+import tvm
+import tvm.testing
+from tvm import relax
 from tvm.relax.frontend.torch import from_exported_program
-from torch.nn import Softmax, Upsample
 
 
-def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
     """
     This util ensures that a torch module can successfully be exported to TVM
     using torch.export and that the resulting IR program gives the same result
@@ -41,10 +40,11 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
         mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
 
     tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
-    target = tvm.target.Target.from_device(tvm.cuda())
 
-    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target))
-    dev = tvm.device("cuda", 0)
+    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
+    # TODO try pipeline below?
+    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
+    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
     vm = relax.VirtualMachine(ex, dev)
 
     gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
@@ -57,7 +57,8 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
-def test_tensor_expand_as():
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_expand_as(target, dev):
 
     class ExpandAs0(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -98,10 +99,10 @@ def forward(self, x):
     torch_module2 = ExpandAs2().eval()
     torch_module3 = ExpandAs3().eval()
 
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
 
 
 if __name__ == "__main__":

From 85a6ed81dc948b4ee3c34f96fd8ab3faad4a3438 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Mon, 10 Mar 2025 03:40:31 -0400
Subject: [PATCH 4/5] ran Black Python formatter with version 22

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 +++---
 tests/python/relax/test_from_exported_to_cuda.py            | 1 -
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 044f16d6b762..495157750ae1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -37,9 +37,9 @@ def __init__(self) -> None:
         self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.block_builder: relax.BlockBuilder = None
-        self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = (
-            self.create_convert_map()
-        )
+        self.convert_map: Dict[
+            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
+        ] = self.create_convert_map()
 
     ########## Utilities ##########

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 4bd17b591e2e..b7bd20a0ce5e 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -59,7 +59,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
 
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_expand_as(target, dev):
-
     class ExpandAs0(torch.nn.Module):
         def __init__(self):
             super().__init__()

From 1fa0475341b25e8efef44f8022226dd2547a13be Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Mon, 10 Mar 2025 03:53:39 -0400
Subject: [PATCH 5/5] remove TODO

---
 tests/python/relax/test_from_exported_to_cuda.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index b7bd20a0ce5e..247d5725fc11 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -42,8 +42,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
 
     tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
     relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
-    # TODO try pipeline below?
-    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
     ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
     vm = relax.VirtualMachine(ex, dev)
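
For reference, a minimal standalone sketch (illustrative only, not part of any patch above) of the equivalence the new _expand_as converter relies on: PyTorch's Tensor.expand_as(other) produces the same values as broadcasting the tensor to other's shape, which is why the converter lowers it to relax.op.broadcast_to(data, shape_of(other)).

    import torch

    # 'template' plays the role of args[0] (the 'self' tensor) in _expand_as
    template = torch.ones(2, 1, 1, 10)
    # 'other' plays the role of args[1]; only its shape is used
    other = torch.randn(2, 3, 4, 10)

    out = template.expand_as(other)
    assert out.shape == other.shape
    # expand_as is pure broadcasting: identical to broadcast_to(other.shape)
    assert torch.equal(out, template.broadcast_to(other.shape))

This is also why the unit tests only vary the singleton dimensions of the template: expand_as (like broadcast_to) can only expand size-1 dimensions, and every non-1 dimension must already match the target shape.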