diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6ff..f54c045f83ec 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -306,6 +306,42 @@ def call_binary_op(op, lhs, rhs):
 
         return convert
 
+    ########## Linear Algebra ##########
+
+    def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
+
+        args = self.retrieve_args(node)
+
+        data = args[0]
+        # Default ord=2 if not supplied
+        ord_val = args[1] if len(args) > 1 else 2.0
+        dim = args[2] if len(args) > 2 else None
+        keepdim = args[3] if len(args) > 3 else False
+
+        # If ord_val is a Python float/int, wrap it in a Relax const
+        # so that it matches data's dtype.
+        dtype = data.struct_info.dtype
+        ord_expr = (
+            ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
+        )
+        # Reciprocal
+        reci_expr = (
+            relax.op.divide(relax.const(1.0, dtype), ord_expr)
+            if isinstance(ord_val, relax.Expr)
+            else relax.const(1.0 / float(ord_val), dtype)
+        )
+
+        # abs(data)
+        abs_data = self.block_builder.emit(relax.op.abs(data))
+        # abs_data^ord
+        abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
+        # sum over dim
+        reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
+        # (sum(...))^(1/ord)
+        norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
+
+        return norm_val
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c8d9d12505c6..1e5f7d6006de 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,8 @@ def create_convert_map(
             "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
             "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
+            # linear algebra
+            "linalg_vector_norm.default": self._linalg_vector_norm,
             # neural network
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
new file mode 100644
index 000000000000..d39bb8e9fea3
--- /dev/null
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import torch
+from torch.export import export
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.torch import from_exported_program
+
+
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
+    """
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
+    as PyTorch when run on CUDA.
+    """
+    raw_data_for_tvm = raw_data.copy()  # In case the data is modified
+    torch_data = torch.from_numpy(raw_data)
+    example_args = (torch_data,)
+
+    with torch.no_grad():
+        exported_program = export(torch_module, example_args)
+        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
+
+    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
+
+    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
+    # TODO: try the pipeline below instead?
+    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
+    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
+    vm = relax.VirtualMachine(ex, dev)
+
+    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
+    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
+    gpu_out = vm["main"](gpu_data, *gpu_params)
+
+    pytorch_out = torch_module(torch_data).detach().numpy()
+    actual = gpu_out[0].numpy()
+    desired = pytorch_out
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_linalg_vector_norm(target, dev):
+    class VectorNorm0(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm1(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    class VectorNorm2(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm3(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
+
+    torch_module0 = VectorNorm0().eval()
+    torch_module1 = VectorNorm1().eval()
+    torch_module2 = VectorNorm2().eval()
+    torch_module3 = VectorNorm3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
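
For reference, a minimal standalone sketch (not part of the diff) of the decomposition that _linalg_vector_norm lowers to: abs, raise to ord, sum over dim, then raise to 1/ord. It needs only numpy and torch, and checks the chain against torch.linalg.vector_norm for the same ord/dim pairs exercised by the test modules above.

# Standalone sanity check (not part of the diff): the abs -> power(ord) ->
# sum(dim) -> power(1/ord) chain emitted by _linalg_vector_norm should agree
# numerically with torch.linalg.vector_norm for the ord/dim pairs tested above.
import numpy as np
import torch

x = torch.from_numpy(np.random.randn(2, 3, 4, 10).astype(np.float32))
for ord_val, dim in [(1, -1), (2, 2)]:
    expected = torch.linalg.vector_norm(x, ord=ord_val, dim=dim)
    decomposed = x.abs().pow(ord_val).sum(dim=dim).pow(1.0 / ord_val)
    torch.testing.assert_close(decomposed, expected, rtol=1e-5, atol=1e-5)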