From 29286466f5c25bd6ad6daf710020ff81ac8d1575 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 25 Jul 2024 21:54:23 -0700
Subject: [PATCH 1/2] [Relax] Disable fusion for fetching from the packed params in FuseOps

The order of bindings in the fusion result is determined by the first
binding in each partition group. When the packed param tuple is used, the
function usually begins with a number of `TupleGetItem` bindings that
unpack the param tuple. Previously `TupleGetItem` was treated as
`kInjective`, which caused any operation that relies purely on these
params to be moved to the beginning of the function and increased the
memory usage of the intermediate results.
---
 src/relax/transform/fuse_ops.cc               | 19 ++++++-
 tests/python/relax/test_transform_fuse_ops.py | 54 +++++++++++++++++--
 2 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 2be7ad41f3e1..b8149ae47dfd 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -146,6 +146,12 @@ class GraphCreator : public ExprVisitor {
       SetNodePattern(param_node, OpPatternKind::kOpaque);
       AddToPostDFSOrder(param_node, param.get());
     }
+    if (auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput)) {
+      for (int i = static_cast<int>(opt_num_input.value()->value);
+           i < static_cast<int>(func->params.size()); ++i) {
+        input_params_.insert(func->params[i].get());
+      }
+    }
     ExprVisitor::VisitExpr_(func);
   }
 
@@ -223,8 +229,15 @@ class GraphCreator : public ExprVisitor {
                          IndexedForwardGraph::Node* binding_var_node) {
     ICHECK_NOTNULL(binding_var_node);
 
-    SetNodePattern(binding_var_node, OpPatternKind::kInjective);
-    VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+    auto pattern = OpPatternKind::kInjective;
+    if (input_params_.count(tuple_item->tuple.as<VarNode>())) {
+      // A TupleGetItem that fetches a parameter from the packed param tuple is treated as opaque
+      // and won't be fused. This prevents the use of the packed param tuple from changing the
+      // order of the fusion result, as the function usually begins with fetching the parameters.
+      pattern = OpPatternKind::kOpaque;
+    }
+    SetNodePattern(binding_var_node, pattern);
+    VisitLeaf(tuple_item->tuple, binding_var_node, pattern);
   }
 
   void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) {
@@ -353,6 +366,8 @@ class GraphCreator : public ExprVisitor {
   IndexedForwardGraph graph_;
   /*! \brief The graph nodes whose patterns are set */
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+  /*! \brief The model params in the function input */
+  std::unordered_set<const VarNode*> input_params_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 3cd608d8ee8f..32ad4d4e440f 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1633,14 +1633,62 @@ def main(
     ) -> R.Tensor((10, 20), dtype="float32"):
         cls = Expected
         with R.dataflow():
-            gv1: R.Tensor(
-                (10, 20), dtype="float32"
-            ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
+            gv1: R.Tensor((10, 20), dtype="float32") = (
+                cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
+            )
             R.output(gv1)
             return gv1
 
     _check(Module, Expected)
 
 
+def test_packed_params():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(16), T.int64(16)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(lv[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Cast("float32", lv[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)):
+                with T.block("T_matmul"):
+                    v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+                    T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1])
+                    T.writes(T_matmul[v_ax0, v_ax1])
+                    with T.init():
+                        T_matmul[v_ax0, v_ax1] = T.float32(0)
+                    T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + x[v_ax0, v_k] * lv2[v_k, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor((16, 16), dtype="float16"), R.Tensor((16, 16), dtype="float16"))) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((16, 16), dtype="float16") = packed_params[0]
+                lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1]
+                lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv5
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    Expected = Before
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 295bdc0ddf7538d5a0a517e83c6a31a50d125e7d Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Fri, 26 Jul 2024 09:01:46 -0700
Subject: [PATCH 2/2] lint

---
 tests/python/relax/test_transform_fuse_ops.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 32ad4d4e440f..17bf58613294 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1633,9 +1633,9 @@ def main(
     ) -> R.Tensor((10, 20), dtype="float32"):
         cls = Expected
         with R.dataflow():
-            gv1: R.Tensor((10, 20), dtype="float32") = (
-                cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
-            )
+            gv1: R.Tensor(
+                (10, 20), dtype="float32"
+            ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
             R.output(gv1)
             return gv1
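
---

Reviewer note (not part of the patches): a minimal sketch of what the new test exercises. It assumes a TVM build with Relax and reuses the `Before` module defined in `test_packed_params` from the patch above. Like the test's `_check` helper, it applies `relax.transform.FuseOps` and expects the module to come back unchanged, since the `TupleGetItem` bindings that unpack `packed_params` are now treated as opaque and are no longer pulled into fused groups.

# Sketch only; assumes the `Before` IRModule from test_packed_params is in scope.
import tvm
from tvm import relax

# Run the operator-fusion pass on the module that takes a packed param tuple.
after = relax.transform.FuseOps()(Before)

# With this change, unpacking `packed_params` is opaque, so no fused functions
# are created and the module is structurally identical to the input.
tvm.ir.assert_structural_equal(after, Before)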