From 5f94690a1a244ad1b44155d3c3004eac8fa78916 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sat, 8 Mar 2025 17:21:03 -0500
Subject: [PATCH 1/5] Implement vector norm + unit test

---
 .../torch/base_fx_graph_translator.py         | 43 +++++++
 .../torch/exported_program_translator.py      |  2 +
 .../relax/test_from_exported_to_cuda.py       | 93 +++++++++++++++++++
 3 files changed, 138 insertions(+)
 create mode 100644 tests/python/relax/test_from_exported_to_cuda.py

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6ff..0dda9b562b47 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -305,6 +305,49 @@ def call_binary_op(op, lhs, rhs):
                 return intrinsic_op(lhs, rhs)
 
         return convert
+
+
+    ########## Linear Algebra ##########
+
+    def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
+
+        args = self.retrieve_args(node)
+
+        data = args[0]
+        # Default ord=2 if not supplied
+        ord_val = args[1] if len(args) > 1 else 2.0
+        dim = args[2] if len(args) > 2 else None
+        keepdim = args[3] if len(args) > 3 else False
+
+        # If ord_val is a Python float/int, wrap it in a Relax const
+        # so that it matches data's dtype.
+        dtype = data.struct_info.dtype
+        ord_expr = (
+            ord_val
+            if isinstance(ord_val, relax.Expr)
+            else relax.const(float(ord_val), dtype)
+        )
+        # Reciprocal
+        reci_expr = (
+            relax.op.divide(
+                relax.const(1.0, dtype),
+                ord_expr
+            )
+            if isinstance(ord_val, relax.Expr)
+            else relax.const(1.0 / float(ord_val), dtype)
+        )
+
+        # abs(data)
+        abs_data = self.block_builder.emit(relax.op.abs(data))
+        # abs_data^ord
+        abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
+        # sum over dim
+        reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
+        # (sum(...))^(1/ord)
+        norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
+
+        return norm_val
+
     ########## Neural Network ##########
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c8d9d12505c6..1e5f7d6006de 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,8 @@ def create_convert_map(
             "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
             "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
+            # linear algebra
+            "linalg_vector_norm.default": self._linalg_vector_norm,
             # neural network
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
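The converter above decomposes the vector norm into primitives Relax already has (abs, power, sum): norm = (sum(|x|^p))^(1/p). A minimal NumPy/PyTorch cross-check of that identity, illustrative only and not part of the patch; it assumes a finite, nonzero ord, since the power-based formula does not cover ord=inf or ord=0:

    import numpy as np
    import torch

    x = np.random.randn(2, 3, 4).astype(np.float32)
    p, dim = 3.0, -1

    # Same chain the converter emits: abs -> power -> sum -> power.
    decomposed = np.sum(np.abs(x) ** p, axis=dim) ** (1.0 / p)
    reference = torch.linalg.vector_norm(torch.from_numpy(x), ord=p, dim=dim).numpy()
    np.testing.assert_allclose(decomposed, reference, rtol=1e-5, atol=1e-5)
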
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
new file mode 100644
index 000000000000..269b2c5b9a77
--- /dev/null
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+import tvm.testing
+import numpy as np
+import torch
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
+from torch.nn import Softmax, Upsample
+
+
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
+    """
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
+    as PyTorch when run on CUDA.
+    """
+    torch_data = torch.from_numpy(raw_data)
+    example_args = (torch_data,)
+
+    with torch.no_grad():
+        exported_program = export(torch_module, example_args)
+        mod_from_torch = from_exported_program(
+            exported_program, keep_params_as_input=True
+        )
+
+    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
+    target = tvm.target.Target.from_device(tvm.cuda())
+
+    ex = relax.build(tvm_mod, target=target,
+                     relax_pipeline=relax.get_default_pipeline(target))
+    dev = tvm.device("cuda", 0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    gpu_data = tvm.nd.array(raw_data, dev)
+    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
+    gpu_out = vm["main"](gpu_data, *gpu_params)
+
+    pytorch_out = torch_module(torch_data).detach().numpy()
+    actual = gpu_out[0].numpy()
+    desired = pytorch_out
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5,
+                               atol=1e-5)
+
+def test_linalg_vector_norm():
+
+    class VectorNorm0(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm1(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    class VectorNorm2(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm3(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    raw_data = np.random.randn(2,3,4,10).astype(np.float32)
+
+    torch_module0 = VectorNorm0().eval()
+    torch_module1 = VectorNorm1().eval()
+    torch_module2 = VectorNorm2().eval()
+    torch_module3 = VectorNorm3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
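The "linalg_vector_norm.default" key registered in patch 1 mirrors the ATen overload name that torch.export produces. A quick way to see which key a given torch call needs, sketched under the assumption that only torch is installed (the module M below is a hypothetical example, not part of the patch):

    import torch
    from torch.export import export


    class M(torch.nn.Module):  # hypothetical module, for illustration only
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=2, dim=-1)


    # Each call_function node in the exported graph is named after its ATen
    # overload, e.g. torch.ops.aten.linalg_vector_norm.default, which is where
    # the convert_map key "linalg_vector_norm.default" comes from.
    print(export(M(), (torch.randn(2, 3),)).graph)
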
From bc66dccd68bd0395ebae856b1a3db11c5a4bbcb2 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 9 Mar 2025 17:41:49 -0400
Subject: [PATCH 2/5] Black formatter

---
 .../torch/base_fx_graph_translator.py         |  21 +--
 .../relax/test_from_exported_to_cuda.py       |  27 ++--
 .../test_frontend_from_exported_program.py    | 136 ++++++++----------
 3 files changed, 82 insertions(+), 102 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0dda9b562b47..f62a96326958 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -37,9 +37,9 @@ def __init__(self) -> None:
         self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.block_builder: relax.BlockBuilder = None
-        self.convert_map: Dict[
-            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
-        ] = self.create_convert_map()
+        self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = (
+            self.create_convert_map()
+        )
 
     ########## Utilities ##########
 
@@ -305,8 +305,7 @@ def call_binary_op(op, lhs, rhs):
                 return intrinsic_op(lhs, rhs)
 
         return convert
-
-
+
     ########## Linear Algebra ##########
 
     def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
@@ -320,19 +319,14 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
         keepdim = args[3] if len(args) > 3 else False
 
         # If ord_val is a Python float/int, wrap it in a Relax const
-        # so that it matches data's dtype. 
+        # so that it matches data's dtype.
         dtype = data.struct_info.dtype
         ord_expr = (
-            ord_val
-            if isinstance(ord_val, relax.Expr)
-            else relax.const(float(ord_val), dtype)
+            ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
         )
         # Reciprocal
         reci_expr = (
-            relax.op.divide(
-                relax.const(1.0, dtype),
-                ord_expr
-            )
+            relax.op.divide(relax.const(1.0, dtype), ord_expr)
             if isinstance(ord_val, relax.Expr)
             else relax.const(1.0 / float(ord_val), dtype)
         )
@@ -348,7 +342,6 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
 
         return norm_val
 
-
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 269b2c5b9a77..fee6d755254f 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -27,8 +27,8 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
     """
-    This util ensures that a torch module can successfully be exported to TVM 
-    using torch.export and that the resulting IR program gives the same result 
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
     as PyTorch when run on CUDA.
""" torch_data = torch.from_numpy(raw_data) @@ -36,15 +36,12 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): with torch.no_grad(): exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program( - exported_program, keep_params_as_input=True - ) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) target = tvm.target.Target.from_device(tvm.cuda()) - ex = relax.build(tvm_mod, target=target, - relax_pipeline=relax.get_default_pipeline(target)) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target)) dev = tvm.device("cuda", 0) vm = relax.VirtualMachine(ex, dev) @@ -55,11 +52,11 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module): pytorch_out = torch_module(torch_data).detach().numpy() actual = gpu_out[0].numpy() desired = pytorch_out - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, - atol=1e-5) + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + def test_linalg_vector_norm(): - + class VectorNorm0(torch.nn.Module): def forward(self, x): return torch.linalg.vector_norm(x, ord=1, dim=-1) @@ -67,19 +64,19 @@ def forward(self, x): class VectorNorm1(torch.nn.Module): def forward(self, x): return torch.linalg.vector_norm(x, ord=2, dim=2) - + class VectorNorm2(torch.nn.Module): def forward(self, x): return torch.linalg.vector_norm(x, ord=1, dim=-1) - + class VectorNorm3(torch.nn.Module): def forward(self, x): return torch.linalg.vector_norm(x, ord=2, dim=2) - - raw_data = np.random.randn(2,3,4,10).astype(np.float32) + + raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32) torch_module0 = VectorNorm0().eval() - torch_module1 = VectorNorm1().eval() + torch_module1 = VectorNorm1().eval() torch_module2 = VectorNorm2().eval() torch_module3 = VectorNorm3().eval() diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8ca335c2fe7a..77aac527bc06 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -82,7 +82,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) @@ -112,7 +112,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) @@ -135,7 +135,7 @@ def forward(self, input): class expected_clamp: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -163,7 +163,7 @@ def forward(self, input): class expected_dropout: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -191,7 +191,7 @@ def forward(self, input): class 
expected_gelu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -220,7 +220,7 @@ def forward(self, input): class expected_hardsigmoid: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -252,7 +252,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -294,7 +294,7 @@ def forward(self, input): class expected_relu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -323,7 +323,7 @@ def forward(self, input): class expected_sigmoid: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -352,7 +352,7 @@ def forward(self, input): class expected_silu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -388,7 +388,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( @@ -425,7 +425,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -456,7 +456,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -487,7 +487,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -512,7 +512,7 @@ def forward(self, input): class expected_tril: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input_1: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -531,7 +531,7 @@ def forward(self, input): class expected_triu: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input_1: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -795,7 +795,7 @@ def forward(self, input): class expected1: @R.function def main( - 
input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -883,7 +883,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1580,7 +1580,7 @@ def forward(self, x, y): class Expected1: @R.function def main( - inp_0: R.Tensor((4, 4), dtype="float32") + inp_0: R.Tensor((4, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") @@ -1827,7 +1827,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1856,7 +1856,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -1885,7 +1885,7 @@ def forward(self, input): class expected3: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2007,9 +2007,7 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2051,9 +2049,7 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2102,7 +2098,7 @@ def forward(self, input): class expected_bilinear: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32") + input: R.Tensor((1, 3, 112, 112), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2131,7 +2127,7 @@ def forward(self, input): class expected_nearest: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32") + input: R.Tensor((1, 3, 112, 112), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2170,7 +2166,7 @@ def forward(self, input: torch.Tensor): class Expected1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256,), dtype="float32")): with R.dataflow(): lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) @@ -2182,7 +2178,7 @@ def main( class Expected2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((256, 1), 
dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) @@ -2204,7 +2200,7 @@ def forward(self, x): class expected1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2238,7 +2234,7 @@ def forward(self, input): class expected_argmax1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256,), dtype="int64")): with R.dataflow(): lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) @@ -2250,7 +2246,7 @@ def main( class expected_argmax2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) @@ -2279,7 +2275,7 @@ def forward(self, input): class expected_argmin1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="int64")): with R.dataflow(): lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) @@ -2291,7 +2287,7 @@ def main( class expected_argmin2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) @@ -2362,7 +2358,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 2, 3, 4), dtype="float32") + input_1: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): # block 0 with R.dataflow(): @@ -2388,7 +2384,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2419,7 +2415,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): # block 0 with R.dataflow(): @@ -2445,7 +2441,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2483,7 +2479,7 @@ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="fl class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2511,7 +2507,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2533,7 +2529,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3, 10, 10), dtype="float32") + x: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -2574,7 +2570,7 @@ def forward(self, x): class 
expected2: @R.function def main( - x: R.Tensor((8, 16), dtype="float32") + x: R.Tensor((8, 16), dtype="float32"), ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( @@ -2619,9 +2615,7 @@ def forward(self, input): @tvm.script.ir_module class Expected: @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -2651,9 +2645,7 @@ def forward(self, data): @tvm.script.ir_module class expected1: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2695,9 +2687,7 @@ def forward(self, data): @tvm.script.ir_module class expected2: @R.function - def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -2749,7 +2739,7 @@ def forward(self, input): class Expected1: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) @@ -2765,7 +2755,7 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) @@ -2796,7 +2786,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2809,7 +2799,7 @@ def main( class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32") + x: R.Tensor((1, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2833,7 +2823,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2855,7 +2845,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -2872,7 +2862,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): # block 0 with R.dataflow(): @@ -2896,7 +2886,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 
0 with R.dataflow(): @@ -2918,7 +2908,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((20,), dtype="int32")): with R.dataflow(): lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") @@ -2939,7 +2929,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) @@ -2959,7 +2949,7 @@ def forward(self, input): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + inp_0: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.zeros( @@ -2982,7 +2972,7 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + inp_0: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( @@ -3005,7 +2995,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3), dtype="float32") + x: R.Tensor((1, 2, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -3034,7 +3024,7 @@ def forward(self, x): class expected_float: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3052,7 +3042,7 @@ def forward(self, x): class expected_half: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): # block 0 with R.dataflow(): @@ -3070,7 +3060,7 @@ def forward(self, x): class expected_type: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32") + x: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3086,7 +3076,7 @@ def forward(self, input): class expected_to1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") @@ -3102,7 +3092,7 @@ def forward(self, input): class expected_to2: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") @@ -3187,7 +3177,7 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32") + inp_0: R.Tensor((256, 256), dtype="float32"), ) -> R.Tensor((256, 256), dtype="float32"): with R.dataflow(): gv: R.Tensor((256, 256), dtype="float32") = inp_0 From ed372bc09836cbb477dbbaee5398521cc0b00451 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 03:11:58 -0400 Subject: [PATCH 3/5] cuda target in new test --- .../relax/test_from_exported_to_cuda.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 
deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index fee6d755254f..e92e3c2e6af3 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,22 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import tvm
-from tvm import relax
-import tvm.testing
 import numpy as np
 import torch
 from torch.export import export
+
+import tvm
+import tvm.testing
+from tvm import relax
 from tvm.relax.frontend.torch import from_exported_program
-from torch.nn import Softmax, Upsample
 
 
-def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
     """
     This util ensures that a torch module can successfully be exported to TVM
     using torch.export and that the resulting IR program gives the same result
     as PyTorch when run on CUDA.
     """
+    raw_data_for_tvm = raw_data.copy()  # In case the data is modified
     torch_data = torch.from_numpy(raw_data)
     example_args = (torch_data,)
 
@@ -39,13 +40,14 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
         mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
 
     tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
-    target = tvm.target.Target.from_device(tvm.cuda())
 
-    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax.get_default_pipeline(target))
-    dev = tvm.device("cuda", 0)
+    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
+    # TODO try pipeline below?
+    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
+    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
     vm = relax.VirtualMachine(ex, dev)
 
-    gpu_data = tvm.nd.array(raw_data, dev)
+    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
     gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
     gpu_out = vm["main"](gpu_data, *gpu_params)
 
@@ -55,7 +57,8 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module):
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
-def test_linalg_vector_norm():
+@tvm.testing.parametrize_targets("cuda")
+def test_linalg_vector_norm(target, dev):
 
     class VectorNorm0(torch.nn.Module):
         def forward(self, x):
             return torch.linalg.vector_norm(x, ord=1, dim=-1)
@@ -80,10 +83,10 @@ def forward(self, x):
     torch_module2 = VectorNorm2().eval()
     torch_module3 = VectorNorm3().eval()
 
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2)
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
 
 
 if __name__ == "__main__":
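The helper now receives target and dev from tvm.testing.parametrize_targets("cuda"), which generates one pytest case per listed target and skips targets that are not enabled on the machine instead of failing. A minimal, self-contained sketch of the pattern; the test body is a placeholder for illustration, not part of the patch:

    import numpy as np

    import tvm
    import tvm.testing


    @tvm.testing.parametrize_targets("cuda")
    def test_roundtrip(target, dev):
        # pytest supplies `target` (the target string) and `dev` (the matching
        # tvm.device); on machines without CUDA enabled the case is skipped.
        x = np.arange(6, dtype="float32")
        assert np.array_equal(tvm.nd.array(x, dev).numpy(), x)
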
From 748c60a1565fca7213f984206b67894c2e59c9ef Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Mon, 10 Mar 2025 03:41:22 -0400
Subject: [PATCH 4/5] ran Black Python formatter with version 22

---
 .../torch/base_fx_graph_translator.py         |  6 +++---
 .../relax/test_from_exported_to_cuda.py       |  1 -
 .../test_frontend_from_exported_program.py    | 20 ++++++++++++++----
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f62a96326958..f54c045f83ec 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -37,9 +37,9 @@ def __init__(self) -> None:
         self.env: Dict[fx.Node, relax.Expr] = {}
         self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.block_builder: relax.BlockBuilder = None
-        self.convert_map: Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]] = (
-            self.create_convert_map()
-        )
+        self.convert_map: Dict[
+            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
+        ] = self.create_convert_map()
 
     ########## Utilities ##########
 
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index e92e3c2e6af3..d39bb8e9fea3 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -59,7 +59,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
 
 @tvm.testing.parametrize_targets("cuda")
 def test_linalg_vector_norm(target, dev):
-
     class VectorNorm0(torch.nn.Module):
         def forward(self, x):
             return torch.linalg.vector_norm(x, ord=1, dim=-1)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 77aac527bc06..399739146359 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2007,7 +2007,9 @@ def forward(self, data):
         @tvm.script.ir_module
         class expected1:
             @R.function
-            def main(input_1: R.Tensor((3, 3, 10, 10), dtype="float32")) -> R.Tuple(
+            def main(
+                input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+            ) -> R.Tuple(
                 R.Tensor((3, 10, 10), dtype="float32"),
                 R.Tensor((3, 10, 10), dtype="float32"),
                 R.Tensor((3, 10,
10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), From 482eed1eac81640ea3902a2dca2d227baa28d2d2 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 04:30:03 -0400 Subject: [PATCH 5/5] restore unmodified frontend test --- .../test_frontend_from_exported_program.py | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 399739146359..8ca335c2fe7a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -82,7 +82,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) @@ -112,7 +112,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) @@ -135,7 +135,7 @@ def forward(self, input): class expected_clamp: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -163,7 +163,7 @@ def forward(self, input): class expected_dropout: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -191,7 +191,7 @@ def forward(self, input): class expected_gelu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -220,7 +220,7 @@ def forward(self, input): class expected_hardsigmoid: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -252,7 +252,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) @@ -294,7 +294,7 @@ def forward(self, input): class expected_relu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -323,7 +323,7 @@ def forward(self, input): class expected_sigmoid: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -352,7 +352,7 @@ def forward(self, input): class expected_silu: @R.function def main( - input_1: 
R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -388,7 +388,7 @@ def forward(self, input): class expected1: @R.function def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( @@ -425,7 +425,7 @@ def forward(self, input): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -456,7 +456,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -487,7 +487,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -512,7 +512,7 @@ def forward(self, input): class expected_tril: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32"), + input_1: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -531,7 +531,7 @@ def forward(self, input): class expected_triu: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32"), + input_1: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -795,7 +795,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -883,7 +883,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1580,7 +1580,7 @@ def forward(self, x, y): class Expected1: @R.function def main( - inp_0: R.Tensor((4, 4), dtype="float32"), + inp_0: R.Tensor((4, 4), dtype="float32") ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") @@ -1827,7 +1827,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -1856,7 +1856,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -1885,7 +1885,7 @@ def forward(self, input): class expected3: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): # block 0 with 
R.dataflow(): @@ -2102,7 +2102,7 @@ def forward(self, input): class expected_bilinear: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32"), + input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2131,7 +2131,7 @@ def forward(self, input): class expected_nearest: @R.function def main( - input: R.Tensor((1, 3, 112, 112), dtype="float32"), + input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): # block 0 with R.dataflow(): @@ -2170,7 +2170,7 @@ def forward(self, input: torch.Tensor): class Expected1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256,), dtype="float32")): with R.dataflow(): lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) @@ -2182,7 +2182,7 @@ def main( class Expected2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) @@ -2204,7 +2204,7 @@ def forward(self, x): class expected1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2238,7 +2238,7 @@ def forward(self, input): class expected_argmax1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256,), dtype="int64")): with R.dataflow(): lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) @@ -2250,7 +2250,7 @@ def main( class expected_argmax2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) @@ -2279,7 +2279,7 @@ def forward(self, input): class expected_argmin1: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((), dtype="int64")): with R.dataflow(): lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) @@ -2291,7 +2291,7 @@ def main( class expected_argmin2: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): with R.dataflow(): lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) @@ -2362,7 +2362,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 2, 3, 4), dtype="float32"), + input_1: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): # block 0 with R.dataflow(): @@ -2388,7 +2388,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -2419,7 +2419,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> 
R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): # block 0 with R.dataflow(): @@ -2445,7 +2445,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2483,7 +2483,7 @@ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="fl class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2511,7 +2511,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2533,7 +2533,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3, 10, 10), dtype="float32"), + x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -2574,7 +2574,7 @@ def forward(self, x): class expected2: @R.function def main( - x: R.Tensor((8, 16), dtype="float32"), + x: R.Tensor((8, 16), dtype="float32") ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( @@ -2749,7 +2749,7 @@ def forward(self, input): class Expected1: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) @@ -2765,7 +2765,7 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32"), + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) @@ -2796,7 +2796,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2809,7 +2809,7 @@ def main( class expected2: @R.function def main( - x: R.Tensor((1, 3), dtype="float32"), + x: R.Tensor((1, 3), dtype="float32") ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): @@ -2833,7 +2833,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): # block 0 with R.dataflow(): @@ -2855,7 +2855,7 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -2872,7 +2872,7 @@ def forward(self, input): class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): # block 0 with R.dataflow(): @@ -2896,7 +2896,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), 
dtype="float32") ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): # block 0 with R.dataflow(): @@ -2918,7 +2918,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32"), + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((20,), dtype="int32")): with R.dataflow(): lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") @@ -2939,7 +2939,7 @@ def forward(self, input): class Expected: @R.function def main( - input: R.Tensor((10, 10), dtype="float32"), + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) @@ -2959,7 +2959,7 @@ def forward(self, input): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32"), + inp_0: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.zeros( @@ -2982,7 +2982,7 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32"), + inp_0: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( @@ -3005,7 +3005,7 @@ def forward(self, x): class expected1: @R.function def main( - x: R.Tensor((1, 2, 3), dtype="float32"), + x: R.Tensor((1, 2, 3), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): # block 0 with R.dataflow(): @@ -3034,7 +3034,7 @@ def forward(self, x): class expected_float: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3052,7 +3052,7 @@ def forward(self, x): class expected_half: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): # block 0 with R.dataflow(): @@ -3070,7 +3070,7 @@ def forward(self, x): class expected_type: @R.function def main( - x: R.Tensor((1, 2, 3, 4), dtype="float32"), + x: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): # block 0 with R.dataflow(): @@ -3086,7 +3086,7 @@ def forward(self, input): class expected_to1: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") @@ -3102,7 +3102,7 @@ def forward(self, input): class expected_to2: @R.function def main( - inp_0: R.Tensor((1, 2, 3, 4), dtype="float32"), + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") @@ -3187,7 +3187,7 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((256, 256), dtype="float32"), + inp_0: R.Tensor((256, 256), dtype="float32") ) -> R.Tensor((256, 256), dtype="float32"): with R.dataflow(): gv: R.Tensor((256, 256), dtype="float32") = inp_0