def _expand(self, node: fx.Node) -> relax.Var:
    """Translate ``Tensor.expand`` into ``relax.op.broadcast_to``.

    The target sizes may arrive as separate positional arguments
    (``x.expand(4, -1)``), as one sequence (``x.expand((4, -1))``) or as a
    single integer (``x.expand(4)``).  A size of ``-1`` keeps the extent of
    the corresponding input dimension.

    Following the PyTorch contract, the target sizes are aligned with the
    input shape from the right: when the target has more dimensions than
    the input, the extra ones are new leading dimensions and cannot be
    ``-1``.

    Parameters
    ----------
    node : fx.Node
        The ``expand`` node.  Its first argument is the tensor to expand,
        the remaining arguments describe the target shape.

    Returns
    -------
    relax.Var
        The variable bound to the emitted ``broadcast_to`` call.

    Raises
    ------
    ValueError
        If ``-1`` is requested for a newly added leading dimension.
    """
    args = self.retrieve_args(node)
    x = args[0]
    sizes = args[1:] if len(args) > 2 else args[1]
    # ``x.expand(n)`` hands over one bare integer rather than a sequence.
    if isinstance(sizes, int):
        sizes = (sizes,)
    sizes = list(sizes)

    in_shape = self.shape_of(x)
    # Count of leading dimensions that exist only in the target shape.
    # Clamped at zero so a (malformed) shorter target keeps the historical
    # one-to-one indexing instead of reading from a negative offset.
    num_new_dims = max(len(sizes) - len(in_shape), 0)

    broadcast_shape = []
    for idx, size in enumerate(sizes):
        if isinstance(size, int) and size == -1:
            if idx < num_new_dims:
                raise ValueError(
                    "expand: -1 is not allowed for a new leading dimension "
                    f"(target dimension {idx})"
                )
            # Right-aligned lookup: skip the newly added leading dims.
            broadcast_shape.append(in_shape[idx - num_new_dims])
        else:
            broadcast_shape.append(size)
    return self.block_builder.emit(relax.op.broadcast_to(x, broadcast_shape))
def _split(self, node: fx.Node) -> relax.Var:
    """Translate ``torch.split`` into ``relax.op.split``.

    Parameters
    ----------
    node : fx.Node
        The split node.  ``node.args[1]`` is either a single chunk size or a
        list/tuple of per-chunk sizes; the axis is the third positional
        argument or the ``dim`` keyword (default ``0``).

    Returns
    -------
    relax.Var
        The variable bound to the emitted ``split`` call.
    """
    data = self.env[node.args[0]]
    size_or_sizes = node.args[1]
    if len(node.args) > 2:
        axis = node.args[2]
    else:
        axis = node.kwargs.get("dim", 0)

    if not isinstance(size_or_sizes, (list, tuple)):
        # One chunk size: pass the number of chunks, ceil(extent / size).
        # NOTE(review): assumes relax.op.split's integer mode yields the same
        # chunk sizes as torch when the extent is not divisible -- confirm.
        extent = self.shape_of(data)[axis].value
        indices_or_sections = (extent + size_or_sizes - 1) // size_or_sizes
    else:
        # Explicit chunk sizes: turn them into cumulative split offsets.  The
        # last size is dropped because the final chunk runs to the end.
        indices_or_sections = []
        running = 0
        for length in size_or_sizes[:-1]:
            running = length + running
            indices_or_sections.append(running)
    return self.block_builder.emit(relax.op.split(data, indices_or_sections, axis))
def _empty(self, node: fx.Node) -> relax.Var:
    """Translate ``torch.empty`` into ``relax.op.zeros``.

    ``torch.empty`` leaves the contents unspecified, so a zero-filled tensor
    of the requested shape and dtype is emitted.

    Parameters
    ----------
    node : fx.Node
        The ``empty`` node.  ``node.args[0]`` is the requested shape; the
        optional ``dtype`` keyword selects the element type.

    Returns
    -------
    relax.Var
        The variable bound to the emitted ``zeros`` call.
    """
    import torch  # type: ignore

    # ``dtype`` is optional in PyTorch: when it is missing or None the global
    # default dtype applies (same fallback that ``_arange`` uses).
    torch_dtype = node.kwargs.get("dtype", None)
    if torch_dtype is None:
        torch_dtype = torch.get_default_dtype()
    dtype = self._convert_data_type(str(torch_dtype), self.env)
    return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
def _slice(self, node: fx.Node) -> relax.Var:
    """Translate ``slice.Tensor`` into ``relax.op.strided_slice``.

    Positional arguments are ``(input, dim, start, end, step)``.  Trailing
    arguments may be omitted from the node, and ``start``/``end`` may be
    ``None``; in those cases the operator defaults apply: ``dim=0``,
    ``start=0``, ``end`` = "to the end of the axis" and ``step=1``.

    Parameters
    ----------
    node : fx.Node
        The slice node.

    Returns
    -------
    relax.Var
        The variable bound to the emitted ``strided_slice`` call.
    """
    args = node.args
    x = self.env[args[0]]
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else None
    end = args[3] if len(args) > 3 else None
    step = args[4] if len(args) > 4 else 1

    if start is None:
        start = 0
    if end is None:
        # INT64 max is the sentinel exported graphs use for an open-ended
        # slice, so reuse it when the bound is left out.
        end = (1 << 63) - 1
    return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [end], [step]))
"permute.default": self._permute, + "repeat.default": self._repeat, + "select.int": self._select, + "slice.Tensor": self._slice, + "split.Tensor": self._split, + "squeeze.default": self._squeeze, + "squeeze.dim": self._squeeze, + "tile.default": self._tile, + "transpose.int": self._transpose, + "unsqueeze.default": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), "view.default": self._reshape, + # tensor creation + "_to_copy.default": self._to_copy, + "arange.start": self._arange, + "clone.default": lambda node: self.env[node.args[0]], + "empty.memory_format": self._empty, + "fill.Scalar": self._fill, + "new_ones.default": self._new_ones, # other "getitem": self._getitem, } diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c60c7c3953b4..9fbc95fa7c00 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -380,41 +380,12 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: ########## Manipulation ########## - def _cat(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) - def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _cumsum(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - if "dtype" in node.kwargs: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - else: - dtype = None - if "out" in node.kwargs: - raise ValueError("specifying out for cumsum is not supported yet") - - return self.block_builder.emit(relax.op.cumsum(x, 
dim, dtype)) - - def _expand(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - sizes = args[1:] if len(args) > 2 else args[1] - broadcast_shape, in_shape = [], self.shape_of(args[0]) - for idx, i in enumerate(sizes): - if isinstance(i, int) and i == -1: - broadcast_shape.append(in_shape[idx]) - else: - broadcast_shape.append(i) - return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var: shape = self.shape_of(x) start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim @@ -440,22 +411,6 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) - def _permute(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.permute_dims(x, dims)) - - def _repeat(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.tile(x, dims)) - def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -466,87 +421,8 @@ def _size(self, node: fx.Node) -> relax.Expr: idx = node.args[1] return self.shape_of(x)[idx].value - def _split(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - split_size = node.args[1] - dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - return self.block_builder.emit(relax.op.split(x, n_section, 
dim)) - - def _squeeze(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - return self.block_builder.emit(relax.op.squeeze(x, dim)) - - def _tile(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.tile(x, dims)) - - def _transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - full_idx = list(range(len(self.shape_of(args[0])))) - full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] - return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - ########## Creation ########## - def _arange(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - start_end_step = [None, None, None] - if "start" in node.kwargs: - start_end_step[0] = node.kwargs["start"] - if "end" in node.kwargs: - start_end_step[1] = node.kwargs["end"] - if "step" in node.kwargs: - start_end_step[2] = node.kwargs["step"] - - if len(node.args) == 1: - assert start_end_step[1] is None - start_end_step[1] = node.args[0] - elif len(node.args) == 2: - assert start_end_step[0] is None - assert start_end_step[1] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - elif len(node.args) == 3: - assert start_end_step[0] is None - assert start_end_step[1] is None - assert start_end_step[2] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - start_end_step[2] = node.args[2] - - if start_end_step[0] is None: - start_end_step[0] = 0 - if start_end_step[2] is None: - start_end_step[2] = 1 - - if "dtype" in node.kwargs: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - elif any([isinstance(x, float) for x in start_end_step]): - dtype = self._convert_data_type(torch.get_default_dtype()) - else: - 
dtype = "int64" - start_end_step = [ - self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step - ] - return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - - def _empty(self, node: fx.Node) -> relax.Var: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) - return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) - def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -596,21 +472,6 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: values = self.block_builder.emit(relax.op.full_like(x, rx_value)) return self.block_builder.emit(relax.op.where(mask, values, x)) - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1] if isinstance(args[1], (list, tuple)) else args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 7c887d9b9610..65890ff6971b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2734,6 +2734,582 @@ def main( verify_model(Argmin2(), example_args, {}, expected_argmin2) +def test_cat_concat(): + class Cat0(Module): + def forward(self, x, y): + return torch.cat((x, y)) + + class Cat1(Module): + def forward(self, x, y): + return torch.cat((x, y), dim=1) + + class Cat2(Module): + def forward(self, x, y): + return torch.cat((x, y), 1) + + class Cat3(Module): + def forward(self, x, y): + return torch.concat((x, y), dim=0) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((2, 3), 
dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0) + gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1) + gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) + verify_model(Cat0(), example_args, {}, Expected1) + verify_model(Cat1(), example_args, {}, Expected2) + verify_model(Cat2(), example_args, {}, Expected2) + verify_model(Cat3(), example_args, {}, Expected1) + + +def test_cumsum(): + class Cumsum(Module): + def forward(self, input): + return torch.cumsum(input, dim=1, dtype=torch.int32) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Cumsum(), example_args, {}, expected1) + + +def test_expand(): + class Expand1(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + class Expand2(Module): + def forward(self, x): + return x.expand(4, -1, -1, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): + # block 0 + with 
def test_repeat():
    """``Tensor.repeat`` should be imported as a single ``R.tile`` call."""

    class Repeat1(Module):
        def forward(self, x: torch.Tensor):
            return x.repeat(2)

    class Repeat2(Module):
        def forward(self, x: torch.Tensor):
            return x.repeat(4, 2)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
                gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((1, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # 1-D input repeated along its only axis.
    example_args = (torch.randn(3, dtype=torch.float32),)
    verify_model(Repeat1(), example_args, {}, expected1)

    # 2-D input repeated along both axes (verified once; the original ran
    # this identical check twice).
    example_args = (torch.randn(1, 3, dtype=torch.float32),)
    verify_model(Repeat2(), example_args, {}, expected2)
dtype="float32") = R.strided_slice( + lv, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(9223372036854775807),), + (R.prim_value(2),), + assume_inbound=False, + ) + lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice( + lv1, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(2),), + (R.prim_value(0),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + @I.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((8, 16), dtype="float32") + ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( + x, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1]) + lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2]) + lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(3),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4]) + gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Slice1(), example_args, {}, expected1) + + example_args = (torch.randn(8, 16, dtype=torch.float32),) + verify_model(Slice2(), example_args, {}, expected2) + + +def test_split(): + class Chunk(Module): + def forward(self, input): + return 
torch.chunk(input, 3, dim=1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1] + lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2] + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = (lv1, lv2, lv3) + R.output(gv) + return gv + + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((0, 3, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), 
dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 0, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) + lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = 
(lv8, lv9, lv10) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Chunk(), example_args, {}, Expected) + + example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) + verify_model(Unbind1(), example_args, {}, expected1) + verify_model(Unbind2(), example_args, {}, expected2) + + +def test_squeeze(): + class Squeeze1(Module): + def forward(self, input): + return input.squeeze(1) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) + gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Squeeze2(Module): + def forward(self, input): + return input.squeeze() + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) + gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) + + verify_model(Squeeze1(), example_args, {}, Expected1) + verify_model(Squeeze2(), example_args, {}, Expected2) + + +def test_tile(): + class Tile1(Module): + def forward(self, x): + return x.tile((2,)) + + class Tile2(Module): + def forward(self, x): + return x.tile(4, 2) + + class Tile3(Module): + def forward(self, x): + return torch.tile(x, (4, 2)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) + gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + 
@tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) + gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, dtype=torch.float32),) + verify_model(Tile1(), example_args, {}, expected1) + verify_model(Tile2(), example_args, {}, expected2) + verify_model(Tile3(), example_args, {}, expected2) + + +def test_transpose(): + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Transpose(), example_args, {}, expected1) + + +def test_unsqueeze(): + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 
10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + verify_model(Unsqueeze1(), example_args, {}, expected1) + verify_model(Unsqueeze2(), example_args, {}, expected2) + + def test_view(): class View(Module): def forward(self, x): @@ -2756,6 +3332,211 @@ def main( verify_model(View(), example_args, {}, expected1) +def test_arange(): + class Arange(Module): + def forward(self, input): + return torch.arange(0, 20, dtype=torch.int32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((20,), dtype="int32")): + with R.dataflow(): + lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") + gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Arange(), example_args, {}, Expected) + + +def test_clone(): + class Clone(Module): + def forward(self, input): + return torch.clone(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Clone(), example_args, {}, Expected) + + +def test_empty(): + class Empty(Module): + def forward(self, input): + return torch.empty((10, 10), dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.zeros( + R.shape([10, 10]), dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), 
dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Empty(), example_args, {}, Expected) + + +def test_fill(): + class Fill(Module): + def forward(self, input: torch.Tensor): + return torch.fill(input, 1.5) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full( + R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(Fill(), example_args, {}, Expected) + + +def test_new_ones(): + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) + verify_model(NewOnes(), example_args, {}, expected1) + + +def test_to_copy(): + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected_float: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + 
class expected_half: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) + R.output(gv) + return gv + + # type + class Type(Module): + def forward(self, x): + return x.type(torch.float32) + + @tvm.script.ir_module + class expected_type: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + # block 0 + with R.dataflow(): + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,) + R.output(gv) + return gv + + class To1(Module): + def forward(self, input): + return input.to(torch.float16) + + @I.ir_module + class expected_to1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) + R.output(gv) + return gv + + class To2(Module): + def forward(self, input): + return input.to("cpu") + + @I.ir_module + class expected_to2: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") + gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(ToFloat(), example_args, {}, expected_float) + verify_model(ToHalf(), example_args, {}, expected_half) + verify_model(Type(), example_args, {}, expected_type) + verify_model(To1(), example_args, {}, expected_to1) + verify_model(To2(), example_args, {}, expected_to2) + + def test_keep_params(): 
class Conv2D1(Module): def __init__(self):