From 6656fb6b4541ca849f1efe4fcc947d4da8cbd6ac Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 2 Nov 2023 15:52:01 -0500
Subject: [PATCH 1/3] [Unity][Transform] Extract partial-tuple-usage from FuseTIR

Prior to this commit, the `FuseTIR` pass explicitly tracked usage of
tuple arguments, to minimize the set of arguments provided to each
kernel.  The additional tracking and handling of partially-used tuples
makes it difficult to follow the primary changes being made by
`FuseTIR`.

This commit implements the same functionality in terms of the
`ExpandTupleArguments` and `RemoveUnusedParameters` transforms,
introduced in https://github.com/apache/tvm/pull/16115 and
https://github.com/apache/tvm/pull/16116 respectively.  By using these
passes before the main `FuseTIR` changes, partial tuple usage is
already handled at that point.

This commit is intended to minimize any changes to user-facing
behavior, and so these pre-process passes are currently used internally
by `FuseTIR`.  This may be avoided in the future by pulling this
internal delegation out into a lowering pipeline.  (An illustrative
sketch of the new pre-processing flow follows the patch series below.)
---
 src/relax/transform/fuse_tir.cc               | 252 +++++++-----------
 tests/python/relax/test_transform_fuse_tir.py |  14 +-
 2 files changed, 107 insertions(+), 159 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 1059791c8a51..999886202abf 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -385,58 +385,45 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Step 1. Create buffers for function params
-
-    // Record which fields in a tuple passed as a parameter are actually accessed by the function.
-    std::unordered_set<const Object*> tuple_param;
-    for (auto param : func->params) {
-      if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
-        tuple_param.insert(param.get());
-      }
-    }
-
-    PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
-      if (auto tup_get = e.as<TupleGetItemNode>();
-          tup_get && tuple_param.count(tup_get->tuple.get())) {
-        func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
-      }
-    });
-
+    std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
-      auto sinfo = GetStructInfo(relax_param);
-      if (sinfo->IsInstance<ShapeStructInfoNode>()) {
-        // It's a symbolic shape var, no need to alloc Buffers.
-        continue;
-      }
-
-      auto [params, buffers] = [=]() {
-        if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
-          // Add only those tuple fields which are actually used by the function body into the
-          // function parameters.
-          int index = 0;
-          Array<tir::Var> params;
-          Array<tir::Buffer> buffers;
-          for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
-            auto [ret_params, ret_buffers] =
-                CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
-            ICHECK_EQ(ret_params.size(), ret_buffers.size());
-            // Adding tuple field results to the end of params and buffers.
-            params.insert(params.end(), ret_params.begin(), ret_params.end());
-            buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-            index += ret_params.size();
+      size_t size_before = prim_func_params.size();
+      CollectPrimFuncParams(relax_param, &prim_func_params);
+
+      auto param_buffers = [&]() -> Array<tir::Buffer> {
+        Array<tir::Buffer> out;
+        for (size_t i = size_before; i < prim_func_params.size(); i++) {
+          if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
+            out.push_back(buf.value());
           }
-          return std::make_pair(params, buffers);
-        } else {
-          return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
         }
+        return out;
       }();
-      ICHECK_EQ(params.size(), buffers.size());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.buffer_map.Set(params[i], buffers[i]);
-        func_info_.params.push_back(params[i]);
+      func_info_.expr2buffers.Set(relax_param, param_buffers);
+    }
+
+    // Move all scalar params after buffer params.
+    std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
+                     [](const auto& a, const auto& b) {
+                       bool a_is_var = a.template as<tir::VarNode>();
+                       bool b_is_var = b.template as<tir::VarNode>();
+                       return a_is_var < b_is_var;
+                     });
+
+    for (const auto& param : prim_func_params) {
+      if (auto opt = param.as<tir::Buffer>()) {
+        auto buffer = opt.value();
+        // Differentiate buffer name and param name by adding prefix
+        // `p_` to the buffer name.  Every symbol should be unique in
+        // TVMScript, and while they can be de-duplicated when
+        // printed, it's more readable when done explicitly.  Since
+        // the buffer is used more often than the param, the buffer
+        // gets the more readable name.
+        tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle()));
+        func_info_.params.push_back(param);
+        func_info_.buffer_map.Set(param, buffer);
       }
-      func_info_.expr2buffers.Set(relax_param, buffers);
     }
 
     // Step 2. Visit Function body and create intermediate buffers
@@ -458,13 +445,9 @@ class FusedTIRConstructor : public ExprVisitor {
     }
 
     // Step 4. Append symbolic vars
-    const relax::Var& last_relax_param = func->params.back();
-    if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
-      auto [params, buffers] =
-          CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
-      ICHECK(buffers.empty());
-      for (size_t i = 0; i < params.size(); ++i) {
-        func_info_.params.push_back(params[i]);
+    for (const auto& param : prim_func_params) {
+      if (auto var = param.as<tir::Var>()) {
+        func_info_.params.push_back(var.value());
       }
     }
 
@@ -548,12 +531,7 @@ class FusedTIRConstructor : public ExprVisitor {
     int end_buf_idx = 0;
     const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
     for (int i = 0; i < tuple_get_item->index; ++i) {
-      auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
-      // If this tuple is not passed as a parameter, or if the field at the index i is actually
-      // used, the corresponding buffer needs to be taken into account by this function.
-      if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
-        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
-      }
+      begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
     }
     end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
     func_info_.expr2buffers.Set(
@@ -719,64 +697,46 @@ class FusedTIRConstructor : public ExprVisitor {
   }
 
   /*!
-   * \brief Create an TIR func params and buffers with specified relax type and shape
+   * \brief Collect TIR func params and buffers with specified relax type and shape
    * \param struct_info The struct info
    * \param name_hint The name hint for params and buffers
-   * \param index The index used for unique name_hint if type is Tuple.
-   *              -1 means no need to add postfix since the relax param is not a Tuple.
-   * \return The created TIR func params and buffers
+   * \param out The vector into which to collect the params/buffers
    */
-  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
-      StructInfo struct_info, const String& name_hint, int index = -1) {
-    Array<tir::Var> params;
-    Array<tir::Buffer> buffers;
-    // The symbolic shape params must be defined at the end of the param list.
-    bool symbolic_shape_param_started = false;
+  static void CollectPrimFuncParams(const Var& relax_param,
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+    auto struct_info = GetStructInfo(relax_param);
+
+    CHECK(!struct_info.as<TupleStructInfoNode>())
+        << "InternalError: "
+        << "All tuple parameters should be expanded before this point in FuseTIR. "
+        << "However, parameter " << relax_param << " has struct info " << struct_info;
+
+    auto name_hint = relax_param->name_hint();
+
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a Tensor, we directly create a tir var and buffer
+      // Case 1. The relax param is a Tensor, we directly create a tir var and buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
-      ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
+      ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
-      // Differentiate buffer name and param name by adding prefix `v_` to param
-      // Every symbol should be unique in TVMScript, and Buffer is used more than param
-      // So we decide to make sure buffer names have better readability.
-      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
-      params.push_back(std::move(param));
-      buffers.push_back(std::move(buffer));
-    } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
-      // Enable postfix
-      CHECK(!symbolic_shape_param_started)
-          << "The symbolic shape params must be defined at the end of the param "
-             "list.";
-      if (index == -1) index = 0;
-      for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
-        ICHECK_EQ(ret_params.size(), ret_buffers.size());
-        // Adding tuple field results to the end of params and buffers.
-        params.insert(params.end(), ret_params.begin(), ret_params.end());
-        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
-        index += ret_params.size();
-      }
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      out->push_back(std::move(buffer));
+
+    } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
+      // Case 2. The relax param is a scalar, we directly create a tir var
+      ICHECK(prim_value->value->IsInstance<tir::VarNode>());
+      out->push_back(Downcast<tir::Var>(prim_value->value));
+
     } else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) {
-      // Case 3. the relax param is a scalar, we directly create a tir var
-      symbolic_shape_param_started = true;
-      ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a Tuple field.";
+      // Case 3. The relax param is a tuple of scalars, each represented as a tir var
       for (const auto& var : shape_expr->values.value()) {
         ICHECK(var->IsInstance<tir::VarNode>());
-        params.push_back(Downcast<tir::Var>(var));
+        out->push_back(Downcast<tir::Var>(var));
       }
     } else {
       ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
                        "ShapeExpr, but got "
                     << struct_info->GetTypeKey();
     }
-    return std::make_pair(params, buffers);
   }
 
   /*!
@@ -870,9 +830,6 @@ class FusedTIRConstructor : public ExprVisitor {
     /*! \brief The map from symbolic var to its corresponding var in the fused function */
     tir::SymbolicMatcher symbolic_var_matcher =
         tir::SymbolicMatcher(&analyzer, &symbolic_var_remap);
-
-    /*! \brief Record indices of tuple fields that are actually accessed. */
-    std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
   };
 
   /*! \brief The IRModule */
@@ -987,34 +944,35 @@ class TIRFuseMutator : public ExprMutator {
     Array<PrimExpr> tir_vars;
     for (size_t i = 0; i < call->args.size(); ++i) {
       auto arg = call->args[i];
-      Array<Expr> flattened;
-      if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
-        // Add only those tuple fields which are actually used by the function body
-        auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
-        for (size_t tup_get_ind : tup_get_indices) {
-          auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
-          flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
+      auto sinfo = GetStructInfo(arg);
+
+      ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
+             !sinfo.as<TupleStructInfoNode>())
+          << "InternalError: "
+          << "All tuple parameters should be expanded before this point in FuseTIR. "
+          << "However, argument " << arg << " with struct info " << arg->struct_info_
+          << " is passed as argument " << i << " to Primitive Relax function " << old_gv
+          << ", which expects parameter " << relax_func->params[i] << " to have struct info "
+          << relax_func->params[i]->struct_info_;
+
+      if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
+        CHECK(shape->values.defined())
+            << "FuseTIR requires all shape input has struct_info value.";
+        for (const PrimExpr& prim_value : shape->values.value()) {
+          CHECK(prim_value->IsInstance<tir::VarNode>())
+              << "All shape inputs are expected to be single tir var.";
+          tir_vars.push_back(prim_value);
         }
-      } else {
-        flattened.push_back(arg);
-      }
+      } else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
+        CHECK(prim_value->value.defined())
+            << "FuseTIR requires all R.Prim arguments to have a known value.";
+        PrimExpr expr = prim_value->value.value();
+        CHECK(expr->IsInstance<tir::VarNode>())
+            << "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
+        tir_vars.push_back(expr);
 
-      for (const Expr& e : flattened) {
-        StructInfo sinfo = GetStructInfo(e);
-        if (sinfo->IsInstance<TensorStructInfoNode>()) {
-          arg_list.push_back(e);
-        } else if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
-          CHECK(shape->values.defined())
-              << "FuseTIR requires all shape input has struct_info value.";
-          for (const PrimExpr& prim_value : shape->values.value()) {
-            CHECK(prim_value->IsInstance<tir::VarNode>())
-                << "All shape inputs are expected to be single tir var.";
-            tir_vars.push_back(prim_value);
-          }
-        } else {
-          LOG(FATAL) << "The flattened arg is expected to be either tensor or shape, but got "
-                     << sinfo->GetTypeKey();
-        }
+      } else {
+        arg_list.push_back(arg);
       }
     }
     // Step b. Create call_tir
@@ -1042,23 +1000,6 @@ class TIRFuseMutator : public ExprMutator {
     return call;
   }
 
-  /********** Helper Functions **********/
-
-  /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
-  Array<Expr> FlattenArg(const Expr& arg) {
-    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
-      Array<Expr> arg_list;
-      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
-        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
-        Array<Expr> flattened = FlattenArg(new_arg);
-        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
-      }
-      return arg_list;
-    } else {
-      return {arg};
-    }
-  }
-
  private:
  /*! \brief The IRModule */
  const IRModule& mod_;
@@ -1076,10 +1017,17 @@ namespace transform {
 
 Pass FuseTIR() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
-  return CreateModulePass(/*pass_function=*/pass_func,  //
-                          /*opt_level=*/0,              //
-                          /*pass_name=*/"FuseTIR",      //
-                          /*required=*/{});
+  auto inner_pass = CreateModulePass(/*pass_function=*/pass_func,  //
+                                     /*opt_level=*/0,              //
+                                     /*pass_name=*/"FuseTIRInner",  //
+                                     /*required=*/{});
+  return tvm::transform::Sequential(
+      {
+          ExpandTupleArguments(),
+          RemoveUnusedParameters(),
+          inner_pass,
+      },
+      "FuseTIR");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);

diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 991ed5254b18..c8f98e872413 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -205,7 +205,7 @@ def fused_exp_squeeze(x):
         with bb.function("main", [x]):
             with bb.dataflow():
                 lv = bb.emit_te(fused_exp_squeeze, x)
-                lv2 = bb.emit_te(fused_exp_squeeze, lv)
+                lv2 = bb.call_te(fused_exp_squeeze, lv)
                 gv = bb.emit_output(lv2)
             bb.emit_func_output(gv)
         return bb.get()
@@ -245,7 +245,7 @@ def fused_exp_exp_squeeze(x):
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         with bb.function("main", [x]):
             with bb.dataflow():
-                lv = bb.emit_te(fused_exp_exp_squeeze, x)
+                lv = bb.call_te(fused_exp_exp_squeeze, x)
                 gv = bb.emit_output(lv)
             bb.emit_func_output(gv)
         return bb.get()
@@ -257,7 +257,7 @@ def test_fuse_with_tuple_as_param():
     def before():
         bb = relax.BlockBuilder()
         x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
-        with bb.function("fused_exp_add", [x], attrs={"Primitive": True}):
+        with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
@@ -300,7 +300,7 @@ def test_fuse_with_nested_tuple_as_param():
     def before():
         bb = relax.BlockBuilder()
         x = relax.Var("x", tuple_struct_info)
-        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
+        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv0_exp = bb.emit_te(topi.exp, lv0)
@@ -373,7 +373,7 @@ def fused_exp_squeeze(x):
         with bb.function("main", [x]):
             with bb.dataflow():
                 lv = bb.emit_te(fused_exp_squeeze, x)
-                lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
+                lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
                 gv = bb.emit_output(lv2)
             bb.emit_func_output(gv)
         return bb.get()
@@ -414,7 +414,7 @@ def fused_add_exp_squeeze(x, y):
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
         with bb.function("main", [x]):
             with bb.dataflow():
-                lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
+                lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
                gv = bb.emit_output(lv)
             bb.emit_func_output(gv)
         return bb.get()
@@ -1268,7 +1268,7 @@ def reshape(
                     (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
                 ]

-    @R.function
+    @R.function(private=True)
     def fused_reshape(
         lv: R.Tuple(
             R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")

From 6ce4442508cec7a524a5e18492672f01d19c2fec Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 19 Dec 2023 21:09:54 -0600
Subject: [PATCH 2/3] Updated based on review comments

---
 src/relax/transform/fuse_tir.cc | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 999886202abf..e5e2883a295e 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -403,7 +403,9 @@ class FusedTIRConstructor : public ExprVisitor {
       func_info_.expr2buffers.Set(relax_param, param_buffers);
     }
 
-    // Move all scalar params after buffer params.
+    // Move all scalar params after buffer params.  To ensure that the
+    // order is deterministic and predictable for testing purposes,
+    // std::stable_sort is used instead of std::sort.
     std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
                      [](const auto& a, const auto& b) {
                        bool a_is_var = a.template as<tir::VarNode>();
@@ -733,9 +735,10 @@ class FusedTIRConstructor : public ExprVisitor {
         out->push_back(Downcast<tir::Var>(var));
       }
     } else {
-      ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
-                       "ShapeExpr, but got "
-                    << struct_info->GetTypeKey();
+      LOG(FATAL) << "TypeError: "
+                 << "The param type of PrimFunc is expected to be "
+                 << "Tensor, PrimValue, or ShapeExpr, "
+                 << "but got " << struct_info->GetTypeKey();
     }
   }

From 12f2386468bbd9eadf52bf286a24f78342adcef6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 29 Dec 2023 10:13:40 -0600
Subject: [PATCH 3/3] ci bump
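
For reference, here is the illustrative sketch of the new pre-processing flow mentioned in the commit message of PATCH 1/3.  It is not part of the patches: it assumes the `relax.transform.ExpandTupleArguments` and `relax.transform.RemoveUnusedParameters` passes from the two linked PRs, the module construction mirrors the `BlockBuilder` style used in `test_transform_fuse_tir.py`, and the function name `fused_exp` is made up for the example.  Exact printed output depends on the TVM build.

```python
from tvm import relax, topi
from tvm.script import relax as R

tuple_sinfo = R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])

bb = relax.BlockBuilder()
x = relax.Var("x", tuple_sinfo)
with bb.function("fused_exp", [x], attrs={"Primitive": True}, private=True):
    with bb.dataflow():
        # Only field 0 of the tuple parameter is used; field 1 is dead.
        lv0 = bb.emit(relax.TupleGetItem(x, 0))
        gv = bb.emit_output(bb.emit_te(topi.exp, lv0))
    bb.emit_func_output(gv)
fused = bb.get().get_global_var("fused_exp")

y = relax.Var("y", tuple_sinfo)
with bb.function("main", [y]):
    with bb.dataflow():
        gv = bb.emit_output(bb.emit(relax.Call(fused, [y])))
    bb.emit_func_output(gv)
mod = bb.get()

# ExpandTupleArguments rewrites the tuple parameter of fused_exp into one
# parameter per field (updating the call site in main), and
# RemoveUnusedParameters then drops the parameter for the dead field, so
# the TIR fusion itself never sees a partially-used tuple.
mod = relax.transform.ExpandTupleArguments()(mod)
mod = relax.transform.RemoveUnusedParameters()(mod)
mod = relax.transform.FuseTIR()(mod)
mod.show()
```

Since `FuseTIR` now runs both pre-processing passes internally (via the `Sequential` added in PATCH 1/3), applying them explicitly here is redundant; it is done only to make the intermediate, tuple-free module observable before fusion.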