diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 2be7ad41f3e1..b8149ae47dfd 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -146,6 +146,12 @@ class GraphCreator : public ExprVisitor {
       SetNodePattern(param_node, OpPatternKind::kOpaque);
       AddToPostDFSOrder(param_node, param.get());
     }
+    if (auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput)) {
+      for (int i = static_cast<int>(opt_num_input.value()->value);
+           i < static_cast<int>(func->params.size()); ++i) {
+        input_params_.insert(func->params[i].get());
+      }
+    }
     ExprVisitor::VisitExpr_(func);
   }
 
@@ -223,8 +229,15 @@
                          IndexedForwardGraph::Node* binding_var_node) {
     ICHECK_NOTNULL(binding_var_node);
 
-    SetNodePattern(binding_var_node, OpPatternKind::kInjective);
-    VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+    auto pattern = OpPatternKind::kInjective;
+    if (input_params_.count(tuple_item->tuple.as<VarNode>())) {
+      // A TupleGetItem that fetches a parameter from the packed param tuple is treated as opaque
+      // and won't be fused. This prevents the use of the packed param tuple from changing the
+      // order of the fusion result, since the function usually begins by fetching the parameters.
+      pattern = OpPatternKind::kOpaque;
+    }
+    SetNodePattern(binding_var_node, pattern);
+    VisitLeaf(tuple_item->tuple, binding_var_node, pattern);
   }
 
   void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) {
@@ -353,6 +366,8 @@ class GraphCreator : public ExprVisitor {
   IndexedForwardGraph graph_;
   /*! \brief The graph nodes whose patterns are set */
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+  /*! \brief The model params in the function input */
+  std::unordered_set<const VarNode*> input_params_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 3cd608d8ee8f..17bf58613294 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1642,5 +1642,53 @@ def main(
     _check(Module, Expected)
 
 
+def test_packed_params():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(16), T.int64(16)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(lv[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Cast("float32", lv[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)):
+                with T.block("T_matmul"):
+                    v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+                    T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1])
+                    T.writes(T_matmul[v_ax0, v_ax1])
+                    with T.init():
+                        T_matmul[v_ax0, v_ax1] = T.float32(0)
+                    T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + x[v_ax0, v_k] * lv2[v_k, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor((16, 16), dtype="float16"), R.Tensor((16, 16), dtype="float16"))) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((16, 16), dtype="float16") = packed_params[0]
+                lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1]
+                lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv5
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    Expected = Before
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
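
Note on the expected result: Expected = Before asserts that FuseOps leaves the module structurally unchanged, because the TupleGetItem bindings that unpack packed_params are now marked opaque and therefore cannot be pulled into a fused group. The following standalone sketch performs the same check outside the test harness; it assumes the Before module above and that _check applies the FuseOps pass and compares structurally, which is how the rest of this test file uses it:

    import tvm
    from tvm import relax

    # Run the Relax operator-fusion pass on the module under test.
    mod = relax.transform.FuseOps()(Before)
    # With this patch, the TupleGetItem bindings on packed_params are opaque,
    # so the pass is expected to return the module structurally unchanged.
    tvm.ir.assert_structural_equal(mod, Before)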