diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6a001b5a047c..d52b3d598f89 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,80 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
+    def _clamp(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        if not isinstance(a_min, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant min value for torch.clamp/clip, "
+                f"but got {a_min} with type {type(a_min)}"
+            )
+        if not isinstance(a_max, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant max value for torch.clamp/clip, "
+                f"but got {a_max} with type {type(a_max)}"
+            )
+        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+    def _gelu(self, node: fx.Node) -> relax.Expr:
+        approximate = node.kwargs.get("approximate", "none")
+        if approximate == "none":
+            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+        elif approximate == "tanh":
+            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+        else:
+            raise KeyError("Unrecognized approximate algorithm for gelu: {}.".format(approximate))
+
+    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
+
+    def _hardswish(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
+    def _leakyrelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
+        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+    def _log_softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+    def _round(self, node: fx.Node) -> relax.Expr:
+        if node.kwargs.get("decimals", 0) != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
+    def _softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 9af422d1c3ca..1ceddad7d79f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,13 +64,51 @@ def create_input_vars(
 
         return parameters_buffers_constants, user_inputs
 
+    ########## Unary Ops ##########
+
+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
+        max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
         return {
             # unary
+            "acos.default": self._unary_op(relax.op.acos),
+            "acosh.default": self._unary_op(relax.op.acosh),
+            "asin.default": self._unary_op(relax.op.asin),
+            "asinh.default": self._unary_op(relax.op.asinh),
+            "atan.default": self._unary_op(relax.op.atan),
+            "atanh.default": self._unary_op(relax.op.atanh),
+            "clamp.default": self._clamp,
+            "cos.default": self._unary_op(relax.op.cos),
+            "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "exp.default": self._unary_op(relax.op.exp),
+            "gelu.default": self._gelu,
+            "hardsigmoid.default": self._hardsigmoid,
+            "hardswish.default": self._hardswish,
+            "hardtanh.default": self._hardtanh,
+            "leaky_relu.default": self._leakyrelu,
+            "log_softmax.int": self._log_softmax,
+            "neg.default": self._unary_op(relax.op.negative),
             "relu.default": self._unary_op(relax.op.nn.relu),
+            "round.default": self._round,
+            "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "sigmoid.default": self._unary_op(relax.op.sigmoid),
+            "silu.default": self._unary_op(relax.op.nn.silu),
+            "sin.default": self._unary_op(relax.op.sin),
+            "sinh.default": self._unary_op(relax.op.sinh),
+            "softmax.int": self._softmax,
+            "sqrt.default": self._unary_op(relax.op.sqrt),
+            "tan.default": self._unary_op(relax.op.tan),
+            "tanh.default": self._unary_op(relax.op.tanh),
+            "tril.default": self._tril_triu(relax.op.tril),
+            "triu.default": self._tril_triu(relax.op.triu),
             # neural network
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "conv2d.default": self._conv2d,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ec53cf23edc5..6f7c6fa2c575 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,64 +62,12 @@ def _fetch_attr(self, model, target: str):
 
     ########## Unary Ops ##########
 
-    def _clamp(self, node: fx.Node) -> relax.Expr:
-        args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
-        if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
-            )
-        if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
-            )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
-    def _gelu(self, node: fx.Node) -> relax.Expr:
-        approximate = node.kwargs.get("approximate", "none")
-        if approximate == "none":
-            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
-        elif approximate == "tanh":
-            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
-        else:
KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) - - def _hardsigmoid(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) - - def _hardswish(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - x2 = relax.op.divide(x1, relax.const(6, dtype)) - return self.block_builder.emit(relax.op.multiply(x, x2)) - - def _leakyrelu(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) - return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _leakyrelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] alpha = module.negative_slope return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _log_softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _log_softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -127,17 +75,6 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _round(self, node: fx.Node) -> relax.Expr: - if node.kwargs.get("decimals", 0) != 0: - raise ValueError("specifying decimals for round is not supported yet") - arg = self.env[node.args[0]] - return self.block_builder.emit(relax.op.round(arg)) - - def _softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - def _softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -159,17 +96,6 @@ def convert(node: fx.Node) -> relax.Var: return convert - def _tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) - assert isinstance(k, int) - return self.block_builder.emit(op(x, k)) - - return convert - ########## Binary Ops ########## def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 112390fe6094..6c17d96004b6 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -39,6 +39,166 @@ def verify_model(torch_model, example_args, binding, expected): def test_unary(): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + # acos + class Acos(Module): + def forward(self, input): + return torch.acos(input) + + @tvm.script.ir_module + class expected_acos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 
10, 10), dtype="float32") = R.acos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acos(), example_args, {}, expected_acos) + + # acosh + class Acosh(Module): + def forward(self, input): + return torch.acosh(input) + + @tvm.script.ir_module + class expected_acosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acosh(), example_args, {}, expected_acosh) + + # asin + class Asin(Module): + def forward(self, input): + return torch.asin(input) + + @tvm.script.ir_module + class expected_asin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asin(), example_args, {}, expected_asin) + + # asinh + class Asinh(Module): + def forward(self, input): + return torch.asinh(input) + + @tvm.script.ir_module + class expected_asinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asinh(), example_args, {}, expected_asinh) + + # atan + class Atan(Module): + def forward(self, input): + return torch.atan(input) + + @tvm.script.ir_module + class expected_atan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atan(), example_args, {}, expected_atan) + + # atanh + class Atanh(Module): + def forward(self, input): + return torch.atanh(input) + + @tvm.script.ir_module + class expected_atanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atanh(), example_args, {}, expected_atanh) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected_cos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cos(), example_args, {}, expected_cos) + + # cosh + class Cosh(Module): + def forward(self, input): + return torch.cosh(input) + + @tvm.script.ir_module + class expected_cosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), 
dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cosh(), example_args, {}, expected_cosh) + # dropout class Dropout1(Module): def __init__(self): @@ -53,7 +213,7 @@ def forward(self, input): return torch.dropout(input, 0.5, train=True) @tvm.script.ir_module - class expected1: + class expected_dropout: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -64,8 +224,47 @@ def main( R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected1) - verify_model(Dropout2(), example_args, {}, expected1) + verify_model(Dropout1(), example_args, {}, expected_dropout) + verify_model(Dropout2(), example_args, {}, expected_dropout) + + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Exp(), example_args, {}, expected_exp) + + # neg + class Neg(Module): + def forward(self, input): + return -input + + @I.ir_module + class expected_neg: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.negative(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Neg(), example_args, {}, expected_neg) # relu class ReLU0(Module): @@ -81,7 +280,7 @@ def forward(self, input): return torch.nn.functional.relu(input) @tvm.script.ir_module - class expected: + class expected_relu: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -93,8 +292,502 @@ def main( R.output(gv) return gv - verify_model(ReLU0(), example_args, {}, expected) - verify_model(ReLU1(), example_args, {}, expected) + verify_model(ReLU0(), example_args, {}, expected_relu) + verify_model(ReLU1(), example_args, {}, expected_relu) + + # rsqrt + class Rsqrt(Module): + def forward(self, input): + return torch.rsqrt(input) + + @I.ir_module + class expected_rsqrt: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Rsqrt(), example_args, {}, expected_rsqrt) + + # sigmoid + class Sigmoid(Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, input): + return self.sigmoid(input) + + class Sigmoid2(Module): + def forward(self, input): + return torch.sigmoid(input) + + @tvm.script.ir_module + class expected_sigmoid: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) 
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sigmoid(), example_args, {}, expected_sigmoid) + verify_model(Sigmoid2(), example_args, {}, expected_sigmoid) + + # silu + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + class SiLU2(Module): + def forward(self, input): + return torch.nn.functional.silu(input) + + @tvm.script.ir_module + class expected_silu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(SiLU(), example_args, {}, expected_silu) + verify_model(SiLU2(), example_args, {}, expected_silu) + + # sin + class Sin(Module): + def forward(self, input: torch.Tensor): + return torch.sin(input) + + @tvm.script.ir_module + class expected_sin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sin(), example_args, {}, expected_sin) + + # sinh + class Sinh(Module): + def forward(self, input): + return torch.sinh(input) + + @tvm.script.ir_module + class expected_sinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sinh(), example_args, {}, expected_sinh) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected_sqrt: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sqrt(), example_args, {}, expected_sqrt) + + # tan + class Tan(Module): + def forward(self, input): + return torch.tan(input) + + @tvm.script.ir_module + class expected_tan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tan(), example_args, {}, expected_tan) + + # tanh + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected_tanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + 
verify_model(Tanh(), example_args, {}, expected_tanh) + + +def test_clamp(): + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected_clamp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Clamp(), example_args, {}, expected_clamp) + + +def test_gelu(): + class Gelu(Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, input): + return self.gelu(input) + + class Gelu2(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected_gelu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Gelu(), example_args, {}, expected_gelu) + verify_model(Gelu2(), example_args, {}, expected_gelu) + + +def test_hardsigmoid(): + class Hardsigmoid(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardsigmoid() + + def forward(self, input): + return self.hs(input) + + class Hardsigmoid2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardsigmoid(input) + + @tvm.script.ir_module + class expected_hardsigmoid: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) + verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) + + +def test_hardswish(): + class Hardswish(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardswish() + + def forward(self, input): + return self.hs(input) + + class Hardswish2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardswish(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = 
(torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardswish(), example_args, {}, expected1) + verify_model(Hardswish2(), example_args, {}, expected1) + + +def test_hardtanh(): + class Hardtanh(torch.nn.Module): + def __init__(self): + super().__init__() + self.ht = torch.nn.Hardtanh() + + def forward(self, input): + return self.ht(input) + + class Hardtanh2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardtanh(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0)) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardtanh(), example_args, {}, expected1) + verify_model(Hardtanh2(), example_args, {}, expected1) + + +def test_leakyrelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class LeakyReLU0(Module): + def __init__(self): + super().__init__() + self.leakyrelu = torch.nn.LeakyReLU(0.02) + + def forward(self, input): + return self.leakyrelu(input) + + class LeakyReLU1(Module): + def forward(self, input): + return torch.nn.functional.leaky_relu(input, 0.02) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LeakyReLU0(), example_args, {}, expected) + verify_model(LeakyReLU1(), example_args, {}, expected) + + +def test_logsoftmax(): + class LogSoftmax(Module): + def __init__(self): + super().__init__() + self.lsm = torch.nn.LogSoftmax(dim=1) + + def forward(self, input): + return self.lsm(input) + + class LogSoftmax2(Module): + def forward(self, input): + return torch.nn.functional.log_softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LogSoftmax(), example_args, {}, expected1) + verify_model(LogSoftmax2(), example_args, {}, expected1) + + +def test_round(): + class Round(Module): + def forward(self, input): + return torch.round(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Round(), example_args, {}, expected) + + +def test_softmax(): + class 
Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + class Softmax2(Module): + def forward(self, input): + return torch.nn.functional.softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Softmax(), example_args, {}, expected1) + verify_model(Softmax2(), example_args, {}, expected1) + + +def test_tril_triu(): + example_args = (torch.randn(10, 10, dtype=torch.float32),) + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected_tril: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tril(), example_args, {}, expected_tril) + + class Triu(Module): + def forward(self, input): + return torch.triu(input, 1) + + @tvm.script.ir_module + class expected_triu: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Triu(), example_args, {}, expected_triu) def test_adaptive_avgpool2d(): diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index ae5172f6caf0..a8032ce0d26d 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -98,6 +98,7 @@ def test_lstm(target_precision): ) +@pytest.mark.skip(reason="Flaky test") def test_lstm_float64(): """Tests if can handle other mixed precision types.