From aa037e5e95f244d4c151932d150657a45a51a1f1 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sun, 21 Jan 2024 09:07:23 -0500
Subject: [PATCH] Revert "[Unity] Split DecomposeOpsForTraining into two steps"

---
 include/tvm/ir/transform.h                    |  25 ---
 src/ir/transform.cc                           |  31 ----
 src/relax/transform/decompose_ops.cc          | 156 ++++++++++--------
 .../relax/test_transform_decompose_ops.py     |  71 ++++----
 4 files changed, 121 insertions(+), 162 deletions(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index adf332525020..ec151d9d7589 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -525,31 +525,6 @@ TVM_DLL Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
     String name, Array<runtime::String> required, bool traceable = false);

-/*
- * \brief Utility to apply a pass to specific functions in an IRModule
- *
- * TVM uses IRModule to IRModule transformations at all stages of
- * lowering. These transformations may be useful when hand-writing an
- * optimized model, or to perform optimizations on specific kernels
- * within an IRModule. This utility allows a pass to be applied to a
- * specified function, without altering other functions in the module.
- *
- * \param pass The IRModule to IRModule pass to be applied.
- *
- * \param func_name_regex A regex used to select the functions to be
- * updated. The pass will be applied to all functions whose name
- * matches the regex.
- *
- * \param error_if_no_function_matches_regex Specifies the behavior if
- * an IRModule does not contain any function matching the provided
- * regex. If true, an error will be raised. If false (default),
- * the IRModule will be returned unmodified.
- *
- * \return The modified IRModule to IRModule pass.
- */
-TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
-                                 bool error_if_no_function_matches_regex = false);
-
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
  * \param header The header to be attached to the output.
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 3bae6be9ba34..f83812094312 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -31,7 +31,6 @@
 #include
 #include
-#include <regex>
 #include
 #include

@@ -532,36 +531,6 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
   return ModulePass(pass_func, pass_info);
 }

-Pass ApplyPassToFunction(Pass pass, String func_name_regex,
-                         bool error_if_no_function_matches_regex) {
-  auto pass_name =
-      static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex)
-          .str();
-
-  std::regex regex(func_name_regex.operator std::string());
-
-  auto pass_func = [pass, regex](IRModule mod, PassContext) -> IRModule {
-    IRModule subset;
-
-    for (const auto& [gvar, func] : mod->functions) {
-      std::string name = gvar->name_hint;
-      if (std::regex_match(name, regex)) {
-        subset->Add(gvar, func);
-      }
-    }
-
-    if (subset->functions.size()) {
-      IRModule new_subset = pass(subset);
-      if (!new_subset.same_as(subset)) {
-        mod.CopyOnWrite()->Update(new_subset);
-      }
-    }
-
-    return mod;
-  };
-
-  return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
 TVM_REGISTER_NODE_TYPE(PassInfoNode);

 TVM_REGISTER_GLOBAL("transform.PassInfo")
diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc
index 1a4cd216256b..899c80c1c454 100644
--- a/src/relax/transform/decompose_ops.cc
+++ b/src/relax/transform/decompose_ops.cc
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
   return expand_dims(data, expand_axes);
 }

-Tuple DecomposeBatchNorm(const Call& call) {
+Tuple SimplifyBatchNormInference(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);

@@ -75,18 +75,14 @@ Tuple DecomposeBatchNorm(const Call& call) {
   return Tuple({out, call->args[3], call->args[4]});
 }

-Expr MutateBatchNormForTraining(Call call) {
+Tuple SimplifyBatchNormTraining(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
-  ICHECK_EQ(call->args.size(), 5);

   Expr data = call->args[0];
+  TensorStructInfo sinfo = MatchTensorStructInfo(data);
   Expr gamma = call->args[1];
   Expr beta = call->args[2];
-  Expr moving_mean = call->args[3];
-  Expr moving_var = call->args[4];
-
-  TensorStructInfo sinfo = MatchTensorStructInfo(data);

   Array<Integer> reduce_axes;
   for (int i = 0; i < sinfo->ndim; ++i) {
@@ -96,21 +92,35 @@
   Expr data_mean = mean(data, reduce_axes, false);
+  Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
   Expr data_var = variance(data, reduce_axes, false);
+  Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});

-  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
-  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
+  // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
+  Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
+  Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
+  Expr out = divide(subtract(data, data_mean_rs), sqrt_var);

-  Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
-  Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
+  if (attrs->scale) {
+    out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
+  }
+  if (attrs->center) {
+    out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
+  }

-  call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
-  // return call;
+  Expr moving_mean = call->args[3];
+  Expr moving_var = call->args[4];
+  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
+  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);

-  return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
+  return Tuple({
+      out,
+      add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
+      add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
+  });
 }

-Expr DecomposeLayerNorm(const Call& call) {
+Expr SimplifyLayerNorm(const Call& call) {
   auto attrs = call->attrs.as<LayerNormAttrs>();
   ICHECK_NOTNULL(attrs);

@@ -162,92 +172,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
   return ShapeExpr(shape_var);
 }

-/*! \brief Update operators that have a training-specific form
- *
- * Some operators, such as relax.op.batch_norm, need additional
- * processing when being run for training. This mutator applies any mutations required
- */
-class TrainingOperatorMutator : public ExprMutator {
- private:
-  using ExprMutator::VisitExpr_;
+class OpDecomposer : public ExprMutator {
+ public:
+  constexpr static const char* kModeInference = "inference";
+  constexpr static const char* kModeTraining = "training";

-  Expr VisitExpr_(const CallNode* call_node) final {
-    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
-    if (call->op == batch_norm_op_) {
-      return MutateBatchNormForTraining(call);
-    } else if (call->op == layer_norm_op_) {
-      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
-      // In the future maybe we can also remove this decomposition during training.
-      return DecomposeLayerNorm(call);
-    } else {
-      return call;
-    }
+  explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
+    CHECK(mode == kModeInference || mode == kModeTraining)
+        << "The argument mode must be one of the following values: \"inference\", \"training\".";
   }

-  /* composite opeartor list */
-  const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
-  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
-};
-
-class OpDecomposer : public ExprMutator {
  private:
   using ExprMutator::VisitExpr_;

   Expr VisitExpr_(const CallNode* call_node) final {
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
     if (call->op == batch_norm_op_) {
-      return DecomposeBatchNorm(call);
+      if (mode_ == kModeInference) {
+        return SimplifyBatchNormInference(call);
+      } else {
+        ICHECK_EQ(mode_, kModeTraining);
+        return SimplifyBatchNormTraining(call);
+      }
+    } else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
+      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
+      // In the future maybe we can also remove this decomposition during training.
+      return SimplifyLayerNorm(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call, builder_);
     }
     return call;
   }

+  const String mode_;
+
   /* composite opeartor list */
   const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
+  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
 };

-namespace transform {
+IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
+  auto op_decomposer = OpDecomposer(mode);

-Pass MutateOpsForTraining() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    TrainingOperatorMutator mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"MutateOpsForTraining",
-                            /*required=*/{});
-}
+  IRModuleNode* new_module = mod.CopyOnWrite();

-Pass DecomposeOps() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    OpDecomposer mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"DecomposeOps",
-                            /*required=*/{});
+  if (!func_name.defined()) {  // simplify all functions
+    Map<GlobalVar, BaseFunc> functions = mod->functions;
+    for (const auto& func_pr : functions) {
+      if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+        Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
+        new_module->Update(func_pr.first, f);
+      }
+    }
+  } else {  // simplify specified function
+    auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
+    CHECK(func_ptr) << func_name.value() << "is not a Relax Function";
+    auto gvar = mod->GetGlobalVar(func_name.value());
+    auto func = GetRef<Function>(func_ptr);
+    func = Downcast<Function>(op_decomposer(func));
+    new_module->Update(gvar, func);
+  }
+
+  return GetRef<IRModule>(new_module);
 }

+namespace transform {
 Pass DecomposeOpsForInference(Optional<String> func_name) {
-  if (func_name) {
-    return ApplyPassToFunction(DecomposeOps(), func_name.value());
-  } else {
-    return DecomposeOps();
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeInference);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"DecomposeOpsForInference",
+                          /*required=*/{});
 }

 Pass DecomposeOpsForTraining(Optional<String> func_name) {
-  auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
-                                                "DecomposeOpsForTraining");
-  if (func_name) {
-    return ApplyPassToFunction(module_pass, func_name.value());
-  } else {
-    return module_pass;
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeTraining);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"DecomposeOpsForTraining",
+                          /*required=*/{});
 }

 TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py
index 4e5bcb82e979..85657ab245ea 100644
--- a/tests/python/relax/test_transform_decompose_ops.py
+++ b/tests/python/relax/test_transform_decompose_ops.py
@@ -137,39 +137,44 @@ def main(
         R.Tensor((64,), dtype="float32"),
     ):
         with R.dataflow():
-            # This portion is training-specific, computing the
-            # mean/variance of the dataset.
-            lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
-            lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)
-
-            # This portion is identical to the batch_norm run during inference
-            lv1 = R.expand_dims(lv, axis=[0, 2, 3])
-            lv2 = R.subtract(x, lv1)
-            lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
-            lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
-            lv6 = R.sqrt(lv5)
-            lv7 = R.divide(lv2, lv6)
-            lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
-            lv9 = R.multiply(lv7, lv8)
-            lv10 = R.expand_dims(beta, axis=[0, 2, 3])
-            lv11 = R.add(lv9, lv10)
-            inner_tuple = (lv11, lv, lv3)
-            # This is the result that would be returned from a
-            # batch_norm at inference.
-
-            # However, at training we need to update the moving
-            # mean/variance, and to return those updated values.
-            inner_res = inner_tuple[0]
-            lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
-            lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
-            lv14 = R.add(lv12, lv13)
-            lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
-            lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
-            lv17 = R.add(lv15, lv16)
-            bn = (inner_res, lv14, lv17)
-            gv0 = bn[0]
-            gv1 = bn[1]
-            gv2 = bn[2]
+            lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
+            lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
+            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
+            lv3: R.Tensor((64,), dtype="float32") = R.variance(
+                x, axis=[0, 2, 3], keepdims=False
+            )
+            lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
+            lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
+                lv4, R.const(9.9999997473787516e-06, "float32")
+            )
+            lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
+            lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
+            lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
+            lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
+            lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
+            lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
+            lv12: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.89999997615814209, "float32"), moving_mean
+            )
+            lv13: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.10000000149011612, "float32"), lv
+            )
+            lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
+            lv15: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.89999997615814209, "float32"), moving_var
+            )
+            lv16: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.10000000149011612, "float32"), lv3
+            )
+            lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
+            bn: R.Tuple(
+                R.Tensor((1, 64, 112, 112), dtype="float32"),
+                R.Tensor((64,), dtype="float32"),
+                R.Tensor((64,), dtype="float32"),
+            ) = (lv11, lv14, lv17)
+            gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
+            gv1: R.Tensor((64,), dtype="float32") = bn[1]
+            gv2: R.Tensor((64,), dtype="float32") = bn[2]
             R.output(gv0, gv1, gv2)
         return (gv0, gv1, gv2)
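
Note (not part of the patch): the training-mode expansion restored above follows the usual batch-norm arithmetic, out = (x - mean) / sqrt(var + epsilon) * gamma + beta, with the moving statistics blended by momentum. The sketch below is a minimal NumPy illustration of that arithmetic, assuming an NCHW input with axis=1 and both scale and center enabled; the helper name batch_norm_training is illustrative only and is not part of TVM.

    # Minimal NumPy sketch of the rewrite emitted by SimplifyBatchNormTraining.
    import numpy as np

    def batch_norm_training(x, gamma, beta, moving_mean, moving_var,
                            axis=1, epsilon=1e-5, momentum=0.1):
        # Reduce over every axis except the channel axis (reduce_axes in the C++ pass).
        reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
        data_mean = x.mean(axis=reduce_axes)
        data_var = x.var(axis=reduce_axes)

        # Broadcast the per-channel statistics back to the input rank
        # (the role of ExpandToMatchInput).
        shape = [1] * x.ndim
        shape[axis] = -1
        out = (x - data_mean.reshape(shape)) / np.sqrt(data_var.reshape(shape) + epsilon)
        out = out * gamma.reshape(shape) + beta.reshape(shape)

        # Moving statistics are updated as (1 - momentum) * old + momentum * batch.
        new_moving_mean = (1 - momentum) * moving_mean + momentum * data_mean
        new_moving_var = (1 - momentum) * moving_var + momentum * data_var
        return out, new_moving_mean, new_moving_var

With momentum=0.1 and epsilon=1e-5 these are exactly the constants (0.9/0.1 and ~1e-5) that appear in the expected TVMScript above, and the three returned values correspond to gv0, gv1, and gv2. The restored passes are exposed to Python as relax.transform.DecomposeOpsForInference and relax.transform.DecomposeOpsForTraining, each taking an optional function name.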