diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 746010a4dc8a..52122ce33369 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -472,6 +472,31 @@ def _masked_fill(self, node: fx.Node) -> relax.Var:
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _masked_scatter(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        source = self.env[node.args[2]]
+        ndim = len(mask.struct_info.shape)
+        if ndim == 1:
+            index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            gathered_source = self.block_builder.emit(relax.op.take(source, index, axis=0))
+        else:
+            f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+            index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            source_shape = [-1] + [
+                s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim
+            ]
+            f_source = self.block_builder.emit(relax.op.reshape(source, source_shape))
+            gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0))
+            gathered_source = self.block_builder.emit(
+                relax.op.reshape(gathered_source, x.struct_info.shape)
+            )
+        if ndim != len(x.struct_info.shape):
+            mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
+        return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
+
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
@@ -695,6 +720,7 @@ def create_convert_map(
             "index_select": self._index_select,
             "masked_fill_": self._inplace_masked_fill,
             "masked_fill": self._masked_fill,
+            "masked_scatter": self._masked_scatter,
             "new_ones": self._new_ones,
             "ones": self._ones,
             "tensor": self._tensor,
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
index abb7dfbd5e02..27115cb13065 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -704,7 +704,9 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S
 }
 
 void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
-  AddNode(GetRef<relax::Constant>(op));
+  if (!expr_tensor_map_.count(GetRef<relax::Constant>(op))) {
+    AddNode(GetRef<relax::Constant>(op));
+  }
 }
 
 void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc
index a3902a44bfaa..f3504d772338 100644
--- a/src/contrib/msc/core/transform/set_expr_layout.cc
+++ b/src/contrib/msc/core/transform/set_expr_layout.cc
@@ -492,9 +492,16 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   if (indices_layout->layout.defined()) {
-    size_t indices_size = indices_layout->layout.ndim();
-    LayoutDecision output_layout =
-        LayoutUtils::ExpandLayout(indices_layout, std::vector<size_t>{indices_size});
+    std::vector<size_t> expand_axes;
+    for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) {
+      expand_axes.push_back(i);
+    }
+    LayoutDecision output_layout;
+    if (expand_axes.size() == 0) {
+      output_layout = indices_layout;
+    } else {
+      output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes);
+    }
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   return InferLayoutOutput();
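
The new `_masked_scatter` lowering needs no dedicated scatter kernel: `cumsum(mask) - 1` assigns every `True` position its rank among the `True` entries, `take` gathers the matching `source` element, and `where` keeps the original value everywhere else. A minimal NumPy sketch of the same decomposition (illustrative only, not part of the patch):

```python
import numpy as np

def masked_scatter_reference(data, mask, source):
    """Mirror of the cumsum/take/where lowering emitted by _masked_scatter."""
    flat_mask = mask.reshape(-1)
    # cumsum(mask) - 1 ranks each True position among the True entries;
    # False positions get a stale index, which is harmless because
    # where() discards those lanes anyway.
    index = np.cumsum(flat_mask.astype("int32")) - 1
    gathered = source.reshape(-1)[index].reshape(data.shape)
    # Keep the original value wherever the mask is False.
    return np.where(mask, gathered, data)

data = np.zeros((2, 5), dtype="float32")
mask = np.array([[0, 1, 0, 1, 1], [1, 0, 0, 0, 1]], dtype=bool)
source = np.arange(15, dtype="float32").reshape(3, 5)
# True positions receive 0, 1, 2, ... in row-major scan order,
# matching torch.Tensor.masked_scatter.
print(masked_scatter_reference(data, mask, source))
```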
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index abac3682fbb1..f5784efe3d26 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -224,6 +224,12 @@ class TorchConstantCodeGen : public TorchOpCode {
       } else if (dtype == "float32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
       }
+    } else if (dtype == "bool") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.BoolTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
     } else if (dtype == "int32") {
       stack_.func_call("register_buffer", "", "self")
           .call_arg(DocUtils::ToStr(ref_name))
@@ -658,6 +664,18 @@ class TorchStridedSliceCodeGen : public TorchOpCode {
   }
 };
 
+class TorchTakeCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1)));
+  }
+};
+
 class TorchTriCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen)
 
@@ -738,6 +756,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("subtract", std::make_shared<TorchSimpleCodeGen>("", "torch.subtract"));
   map->emplace("tan", std::make_shared<TorchSimpleCodeGen>("", "torch.tan"));
   map->emplace("tanh", std::make_shared<TorchSimpleCodeGen>("", "torch.tanh"));
+  map->emplace("where", std::make_shared<TorchSimpleCodeGen>("", "torch.where"));
 
   // reduce ops
   map->emplace("max", std::make_shared<TorchReduceAxisCodeGen>("", "torch.max"));
@@ -771,6 +790,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
+  map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
 
   // create ops
   map->emplace("constant", std::make_shared<TorchConstantCodeGen>("nn.Parameter", ""));
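
`TorchTakeCodeGen` widens `int32` indices to `int64` before indexing because PyTorch's advanced indexing expects long index tensors. A hand-written illustration of the pattern the generated `forward()` follows (not emitted code):

```python
import torch

src = torch.arange(6.0)
idx = torch.tensor([0, 2, 4], dtype=torch.int32)
# Indexing with int32 tensors is rejected on at least older PyTorch
# releases ("tensors used as indices must be long, ..."), so the
# generated code first inserts an explicit cast:
idx = idx.to(torch.int64)
out = src[idx]  # tensor([0., 2., 4.]) -- the DocUtils::ToIndex form above
```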
"layout": "" if dynamic else "BA", + }, + { + "name": "inp_2", + "shape": [3, dim], + "dtype": "float32", + "layout": "" if dynamic else "BA", + }, + ], + "outputs": [ + { + "name": "where", + "shape": [2, dim], + "dtype": "float32", + "layout": "" if dynamic else "BA", + } + ], + "nodes": { + "total": 11, + "input": 3, + "reshape": 3, + "cumsum": 1, + "constant": 1, + "subtract": 1, + "take": 1, + "where": 1, + }, + } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2} + + verify_model( + MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1 + ) + verify_model( + MaskedScatter2(), + [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")], + expected2, + ) + + def test_put(): """test graph builder for index_put""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 27a02844e19d..d8f746d68822 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -1193,6 +1193,29 @@ def forward(self, data, index, src): verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")]) +def test_masked_scatter(): + """test relax translator for masked_scatter""" + + class MaskedScatter1(Module): + def __init__(self): + super().__init__() + self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH) + + def forward(self, data, src): + return data.masked_scatter(self.mask, src) + + class MaskedScatter2(Module): + def __init__(self): + super().__init__() + self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH) + + def forward(self, data, src): + return data.masked_scatter(self.mask, src) + + verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")]) + verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")]) + + def test_put(): """test relax translator for index_put""" diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 6ed28c0ac0b7..6535ef66c8b3 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -1173,6 +1173,29 @@ def forward(self, data, index, src): ) +def test_masked_scatter(): + """test torch translator for masked_scatter""" + + class MaskedScatter1(Module): + def __init__(self): + super().__init__() + self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH) + + def forward(self, data, src): + return data.masked_scatter(self.mask, src) + + class MaskedScatter2(Module): + def __init__(self): + super().__init__() + self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH) + + def forward(self, data, src): + return data.masked_scatter(self.mask, src) + + verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True) + verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True) + + def test_put(): """test torch translator for index_put""" diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 08331f08612b..d9857723b1f5 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4023,5 +4023,66 @@ def main( verify_model(Scatter(), input_info, {}, expected) +def test_masked_scatter(): + class MaskedScatter1(Module): + def forward(self, data, mask, src): 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 08331f08612b..d9857723b1f5 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4023,5 +4023,66 @@ def main(
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_masked_scatter():
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((5,), dtype="bool"),
+            inp_2: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tensor((5,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="int32") = R.cumsum(
+                    inp_1, axis=0, dtype="int32", exclusive=False
+                )
+                lv1: R.Tensor((5,), dtype="int32") = R.subtract(lv, R.const(1, "int32"))
+                lv2: R.Tensor((5,), dtype="float32") = R.take(inp_2, lv1, axis=0)
+                lv3: R.Tensor((5,), dtype="float32") = R.where(inp_1, lv2, inp_0)
+                gv: R.Tensor((5,), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 5), dtype="float32"),
+            inp_1: R.Tensor((2, 5), dtype="bool"),
+            inp_2: R.Tensor((3, 5), dtype="float32"),
+        ) -> R.Tensor((2, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="bool") = R.reshape(inp_1, R.shape([10]))
+                lv1: R.Tensor((10,), dtype="int32") = R.cumsum(
+                    lv, axis=0, dtype="int32", exclusive=False
+                )
+                lv2: R.Tensor((10,), dtype="int32") = R.subtract(lv1, R.const(1, "int32"))
+                lv3: R.Tensor((15,), dtype="float32") = R.reshape(inp_2, R.shape([15]))
+                lv4: R.Tensor((10,), dtype="float32") = R.take(lv3, lv2, axis=0)
+                lv5: R.Tensor((2, 5), dtype="float32") = R.reshape(lv4, R.shape([2, 5]))
+                lv6: R.Tensor((2, 5), dtype="float32") = R.where(inp_1, lv5, inp_0)
+                gv: R.Tensor((2, 5), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    verify_model(
+        MaskedScatter1(), [([5], "float32"), ([5], "bool"), ([10], "float32")], {}, expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")],
+        {},
+        expected2,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
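
To inspect the emitted modules outside the test harness, the FX importer can be driven directly; a quick repro sketch (usage inferred from the surrounding tests):

```python
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class MaskedScatter2(torch.nn.Module):
    def forward(self, data, mask, src):
        return data.masked_scatter(mask, src)

graph_model = fx.symbolic_trace(MaskedScatter2())
mod = from_fx(graph_model, [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")])
mod.show()  # prints the R.reshape / R.cumsum / R.take / R.where chain from expected2
```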