From a7af9b8620c2b8bd8925d234a6fd78fde74549a6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 27 Oct 2023 14:03:40 -0500
Subject: [PATCH] [Unity][Transform] Extract partial-tuple-usage from FuseOps

Prior to this commit, the `FuseOps` pass explicitly tracked usage of tuple
arguments, to minimize the set of arguments provided to each kernel. The
additional tracking and handling of partially-used tuples made it difficult
to follow the primary transformation performed by `FuseOps`.

This commit implements the same functionality in terms of the
`ExpandTupleArguments` and `RemoveUnusedParameters` transforms, introduced in
https://github.com/apache/tvm/pull/16115 and
https://github.com/apache/tvm/pull/16116 respectively. By delegating to these
passes, which run immediately after the core fusion step, partial tuple usage
is handled without any explicit tracking inside `FuseOps`.

This commit is intended to minimize any changes to user-facing behavior, and
so these cleanup passes are currently applied internally by `FuseOps`. This
may be avoided in the future by pulling this internal delegation out into a
lowering pipeline.
---
 src/relax/transform/fuse_ops.cc               |  96 ++++------------
 tests/python/relax/test_transform_fuse_ops.py | 108 +++++++++---------
 .../test_transform_fuse_ops_by_pattern.py     |   1 -
 3 files changed, 74 insertions(+), 131 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index b0eeba399e90..4a03f2dc686a 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -411,21 +411,11 @@ class FunctionCreator : public ExprMutator {
 
           for (const Expr& arg : call->args) {
             CheckDefAndUpdateParam(arg);
-            if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
-              // The argument is fully referenced. Thus we remove it from the mapping.
-              partially_used_tuple_params_.erase(arg.get());
-            }
           }
         }
       } else if (var_binding->value.as<TupleGetItemNode>()) {
         const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
         CheckDefAndUpdateParam(tuple_item->tuple);
-
-        if (partially_used_tuple_params_.find(tuple_item->tuple.get()) !=
-            partially_used_tuple_params_.end()) {
-          // Appending get-item index to the mapping.
-          partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index);
-        }
       }
 
       // Mark the binding variable as defined.
@@ -461,51 +451,9 @@ class FunctionCreator : public ExprMutator {
 
     // Step 1. Start constructing a new dataflow block.
     builder_->BeginDataflowBlock();
 
-    // Step 2. Handing partially used tuple parameters: replacing entire tuple
-    // parameters with the parameters of its fields that are accessed in the
-    // function.
-    std::unordered_map<const ExprNode*, std::unordered_map<int, Var>> tuple_get_item_remap;
-    for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) {
-      ICHECK(!item_indices.empty());
-      int param_idx = tuple_param_idx_[tuple_arg];
-      Var param = params_[param_idx];
-      String param_name = params_[param_idx]->name_hint();
-      TupleStructInfo param_sinfo = Downcast<TupleStructInfo>(tuple_arg->struct_info_);
-
-      Array<Expr> item_args;
-      Array<Var> item_params;
-      item_args.reserve(item_indices.size());
-      item_params.reserve(item_indices.size());
-      for (int item_idx : item_indices) {
-        Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]);
-        item_args.push_back(TupleGetItem(GetRef<Expr>(tuple_arg), item_idx));
-        item_params.push_back(item_param);
-        tuple_get_item_remap[tuple_arg][item_idx] = item_param;
-      }
-      arguments_.erase(arguments_.begin() + param_idx);
-      arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end());
-      params_.erase(params_.begin() + param_idx);
-      params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end());
-    }
-    // Step 3. Visit each binding and collect outputs one by one.
     Array<Expr> outputs(output_vars_.size(), Expr());
     for (const Binding& binding : bindings_) {
-      // Special handing for TupleGetItem.
-      if (const auto* var_binding = binding.as<VarBindingNode>()) {
-        if (const auto* tuple_get_item = var_binding->value.as<TupleGetItemNode>()) {
-          auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get());
-          if (it != tuple_get_item_remap.end()) {
-            ICHECK(it->second.find(tuple_get_item->index) != it->second.end());
-            var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index];
-            if (auto output_idx = GetOutputIndex(binding->var)) {
-              outputs.Set(*output_idx, it->second[tuple_get_item->index]);
-            }
-            continue;
-          }
-        }
-      }
-
       if (auto output_idx = GetOutputIndex(binding->var)) {
         // Case 1. It is an output binding
         // We only allow VarBinding as output.
@@ -602,13 +550,6 @@ class FunctionCreator : public ExprMutator {
         arguments_.push_back(expr);
         params_.push_back(param);
       }
-
-      // Mark the tuple parameter is partially referenced in the beginning.
-      // We will remove it from the mapping once we find it is fully referenced.
-      if (param_sinfo->IsInstance<TupleStructInfoNode>()) {
-        partially_used_tuple_params_[expr.get()] = {};
-        tuple_param_idx_[expr.get()] = static_cast<int>(arguments_.size()) - 1;
-      }
     }
   }
 
@@ -631,13 +572,6 @@ class FunctionCreator : public ExprMutator {
   std::vector output_vars_;
   /*! \brief Whether or not to lift bound constants to parameters */
   bool lift_constant_;
-  /*! \brief Mapping from tuple parameter of the function to its position index */
-  std::unordered_map<const ExprNode*, int> tuple_param_idx_;
-  /*!
-   * \brief Mapping from partially referenced tuple parameter to the list of
-   * indices that the parameter is referred by TupleGetItem
-   */
-  std::unordered_map<const ExprNode*, std::vector<int>> partially_used_tuple_params_;
 };
 
 /*!
@@ -1323,10 +1257,17 @@ Pass FuseOps(int fuse_opt_level) {
         auto max_fuse_depth = pc->GetConfig<Integer>("relax.FuseOps.max_depth", Integer(kMaxFusedOps));
         return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue());
       };
-  return CreateModulePass(/*pass_function=*/pass_func,  //
-                          /*opt_level=*/0,              //
-                          /*pass_name=*/"FuseOps",      //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,   //
+                                     /*opt_level=*/0,               //
+                                     /*pass_name=*/"FuseOpsInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          inner_pass,
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+      },
+      "FuseOps");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
@@ -1337,10 +1278,17 @@ Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_const
       [=](IRModule m, PassContext pc) {
         return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
       };
-  return CreateModulePass(/*pass_function=*/pass_func,       //
-                          /*opt_level=*/0,                   //
-                          /*pass_name=*/"FuseOpsByPattern",  //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,            //
+                                     /*opt_level=*/0,                        //
+                                     /*pass_name=*/"FuseOpsByPatternInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          inner_pass,
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+      },
+      "FuseOpsByPattern");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 1a4a630e3e5a..1d60dfedad31 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1122,8 +1122,7 @@ def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1: R.Tensor((2,
                 lv31 = R.call_tir(cls.transpose, (w2,), out_sinfo=R.Tensor((1280, 320), dtype="float32"))
                 lv: R.Tensor((2, 320), dtype="float32") = cls.fused_matmul_add1(inp_1, lv31, b2)
                 lv35 = R.call_tir(cls.reshape1, (lv,), out_sinfo=R.Tensor((2, 320, 1, 1), dtype="float32"))
-                lv1: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
-                gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1
+                gv: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
                 R.output(gv)
             return gv
     # fmt: on
@@ -1156,16 +1155,16 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d
 
     @I.ir_module
     class Expected:
-        @T.prim_func(private=True)
-        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
-            T.func_attr({"op_pattern": 0, "tir.noalias": True})
-            # with T.block("root"):
-            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
-                with T.block("T_add"):
-                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
-                    T.writes(T_add[v_ax0, v_ax1])
-                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]
+        # @T.prim_func(private=True)
+        # def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
+        #     T.func_attr({"op_pattern": 0, "tir.noalias": True})
+        #     # with T.block("root"):
+        #     for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
+        #         with T.block("T_add"):
+        #             v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+        #             T.reads(rxplaceholder[v_ax0, v_ax1], 
rxplaceholder_1[v_ax1]) + # T.writes(T_add[v_ax0, v_ax1]) + # T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1] @T.prim_func(private=True) def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")): @@ -1178,18 +1177,18 @@ def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceh T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1] - @T.prim_func(private=True) - def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): - for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)): - with T.block("matmul"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) - T.writes(matmul_1[v_i0, v_i1]) - with T.init(): - matmul_1[v_i0, v_i1] = T.float32(0) - matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] + # @T.prim_func(private=True) + # def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")): + # T.func_attr({"op_pattern": 4, "tir.noalias": True}) + # # with T.block("root"): + # for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)): + # with T.block("matmul"): + # v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + # T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) + # T.writes(matmul_1[v_i0, v_i1]) + # with T.init(): + # matmul_1[v_i0, v_i1] = T.float32(0) + # matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] @T.prim_func(private=True) def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")): @@ -1204,27 +1203,27 @@ def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxpl matmul[v_i0, v_i1] = T.float32(0) matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] - @T.prim_func(private=True) - def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) - # with T.block("root"): - for i0, i1 in T.grid(T.int64(1), T.int64(128)): - with T.block("compute"): - v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[v_i0, v_i1]) - T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0)) - - @T.prim_func(private=True) - def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) - # with T.block("root"): - for ax0, ax1 in T.grid(T.int64(784), T.int64(128)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] + # @T.prim_func(private=True) + # def 
relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")): + # T.func_attr({"op_pattern": 0, "tir.noalias": True}) + # # with T.block("root"): + # for i0, i1 in T.grid(T.int64(1), T.int64(128)): + # with T.block("compute"): + # v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + # T.reads(rxplaceholder[v_i0, v_i1]) + # T.writes(compute[v_i0, v_i1]) + # compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0)) + + # @T.prim_func(private=True) + # def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")): + # T.func_attr({"op_pattern": 2, "tir.noalias": True}) + # # with T.block("root"): + # for ax0, ax1 in T.grid(T.int64(784), T.int64(128)): + # with T.block("T_transpose"): + # v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + # T.reads(rxplaceholder[v_ax1, v_ax0]) + # T.writes(T_transpose[v_ax0, v_ax1]) + # T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] @T.prim_func(private=True) def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")): @@ -1252,10 +1251,9 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32")) + # lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32")) lv4 = R.call_tir(cls.transpose1, (linear2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32")) - lv_1: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias) - gv: R.Tensor((1, 10), dtype="float32") = lv_1 + gv: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias) R.output(gv) return gv @@ -1319,7 +1317,7 @@ def main(s: R.Shape(["n"])): class Expected: @R.function(private=True) def fused_full_trilu_broadcast_to( - s: R.Shape(["n"]), + s: R.Prim(value="n"), ) -> R.Tensor([1, 1, "n", "n"], "float32"): R.func_attr({"Primitive": 1}) n = T.int64() @@ -1336,7 +1334,7 @@ def main(s: R.Shape(["n"])) -> R.Tensor((1, 1, "n", "n"), dtype="float32"): n = T.int64() with R.dataflow(): gv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to( - R.shape([n]) + R.prim_value(n) ) R.output(gv) return gv @@ -1367,7 +1365,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): class Expected: @R.function(private=True) def fused_full_trilu_broadcast_to( - s: R.Shape(["n"]), + s: R.Prim(value="n"), ) -> R.Tensor([1, 1, "n", "n"], "float32"): R.func_attr({"Primitive": 1}) n = T.int64() @@ -1384,7 +1382,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): n = T.int64() with R.dataflow(): lv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to( - R.shape([n]) + R.prim_value(n) ) gv = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", @@ -1406,10 +1404,9 @@ def main(A: R.Tensor((10, 20), dtype="float32")) -> R.Tensor(dtype="float32", nd m = T.int64() n = T.int64() with R.dataflow(): - lv: R.Tensor((m, n), dtype="float32") = R.match_cast( + gv: R.Tensor((m, n), dtype="float32") = R.match_cast( A, R.Tensor((m, n), dtype="float32") ) - gv: R.Tensor((m, n), dtype="float32") = lv R.output(gv) return gv @@ -1491,10 +1488,9 @@ def main( cls = Expected with R.dataflow(): lv: R.Tensor((2,), dtype="float32") = x[0] - lv1: R.Tensor((2,), dtype="float32") = 
cls.fused_add_divide( + gv: R.Tensor((2,), dtype="float32") = cls.fused_add_divide( lv, R.const(1, "float32"), R.const(1, "float32") ) - gv: R.Tensor((2,), dtype="float32") = lv1 R.output(gv) return gv diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index bd434864a081..3125a1d87c93 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -758,7 +758,6 @@ def main( gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( data, weight1 ) - relu: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(data) R.output(gv) return gv
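
For reference, the cleanup that `FuseOps` now delegates can be reproduced directly from Python by running the two transforms on their own. The sketch below is illustrative only: it assumes a Relax-enabled TVM build in which `relax.transform.ExpandTupleArguments` and `relax.transform.RemoveUnusedParameters` are exposed (as introduced by the two pull requests referenced above), and the toy module with its `inner` and `main` functions is invented for the example rather than taken from this patch. `ExpandTupleArguments` splits the tuple parameter of the private `inner` function into one parameter per field, and `RemoveUnusedParameters` then drops the field that `inner` never reads, which is the same post-processing the new `Sequential` applies to the functions produced by `FuseOpsInner`.

# Illustrative sketch only; assumes a TVM build with the Relax
# ExpandTupleArguments and RemoveUnusedParameters passes available.
# The module below is a toy example, not taken from this patch.
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function(private=True)
    def inner(
        x: R.Tuple(R.Tensor((2,), dtype="float32"), R.Tensor((3,), dtype="float32"))
    ) -> R.Tensor((2,), dtype="float32"):
        # Only x[0] is read; x[1] is the kind of partially-used tuple field
        # that FuseOps previously tracked by hand.
        with R.dataflow():
            lv = x[0]
            gv = R.add(lv, lv)
            R.output(gv)
        return gv

    @R.function
    def main(
        a: R.Tensor((2,), dtype="float32"), b: R.Tensor((3,), dtype="float32")
    ) -> R.Tensor((2,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.tuple(a, b)
            gv = cls.inner(lv)
            R.output(gv)
        return gv


# The same post-processing that FuseOps now applies to its fused functions:
seq = tvm.transform.Sequential(
    [
        relax.transform.ExpandTupleArguments(),
        relax.transform.RemoveUnusedParameters(),
    ]
)
# After the sequence, `inner` takes only the tensor field it actually uses,
# and the call site in `main` passes that field directly.
print(seq(Module).script())

Keeping this logic as a pair of standalone module passes is what lets `FuseOps` itself stay free of tuple bookkeeping, as described in the commit message.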