25 changes: 0 additions & 25 deletions include/tvm/ir/transform.h
@@ -525,31 +525,6 @@ TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);

/*!
* \brief Utility to apply a pass to specific functions in an IRModule
*
* TVM uses IRModule to IRModule transformations at all stages of
* lowering. These transformations may be useful when hand-writing an
* optimized model, or to perform optimizations on specific kernels
* within an IRModule. This utility allows a pass to be applied to a
* specified function, without altering other functions in the module.
*
* \param pass The IRModule to IRModule pass to be applied.
*
* \param func_name_regex A regex used to select the functions to be
* updated. The pass will be applied to all functions whose name
* matches the regex.
*
* \param error_if_no_function_matches_regex Specifies the behavior if
* an IRModule does not contain any function matching the provided
* regex. If true, an error will be raised. If false (default),
* the IRModule will be returned unmodified.
*
* \return The modified IRModule to IRModule pass.
*/
TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
bool error_if_no_function_matches_regex = false);
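For reference, here is a minimal sketch of how the utility removed above could be composed with CreateModulePass, using only the two signatures shown in this header. The helper name ScopedIdentityPass, the identity pass body, and the "main" regex are illustrative assumptions, not part of the PR:

// Sketch only; assumes <tvm/ir/transform.h> is included.
tvm::transform::Pass ScopedIdentityPass() {
  using namespace tvm;
  // A placeholder module pass that returns the module unchanged.
  auto pass_func = [](IRModule mod, transform::PassContext) -> IRModule { return mod; };
  transform::Pass inner =
      transform::CreateModulePass(pass_func, /*opt_level=*/0, "IdentityPass", /*required=*/{});
  // Apply `inner` only to functions whose names match the regex "main".
  return transform::ApplyPassToFunction(inner, "main");
}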

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
31 changes: 0 additions & 31 deletions src/ir/transform.cc
@@ -31,7 +31,6 @@

#include <chrono>
#include <iomanip>
#include <regex>
#include <stack>
#include <unordered_set>

@@ -532,36 +531,6 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
return ModulePass(pass_func, pass_info);
}

Pass ApplyPassToFunction(Pass pass, String func_name_regex,
bool error_if_no_function_matches_regex) {
auto pass_name =
static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex)
.str();
std::regex regex(func_name_regex.operator std::string());

auto pass_func = [pass, regex](IRModule mod, PassContext) -> IRModule {
IRModule subset;

for (const auto& [gvar, func] : mod->functions) {
std::string name = gvar->name_hint;
if (std::regex_match(name, regex)) {
subset->Add(gvar, func);
}
}

if (subset->functions.size()) {
IRModule new_subset = pass(subset);
if (!new_subset.same_as(subset)) {
mod.CopyOnWrite()->Update(new_subset);
}
}

return mod;
};

return CreateModulePass(pass_func, 0, pass_name, {});
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_GLOBAL("transform.PassInfo")
156 changes: 83 additions & 73 deletions src/relax/transform/decompose_ops.cc
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
return expand_dims(data, expand_axes);
}

Tuple DecomposeBatchNorm(const Call& call) {
Tuple SimplifyBatchNormInference(const Call& call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);

@@ -75,18 +75,14 @@ Tuple DecomposeBatchNorm(const Call& call) {
return Tuple({out, call->args[3], call->args[4]});
}

Expr MutateBatchNormForTraining(Call call) {
Tuple SimplifyBatchNormTraining(const Call& call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
ICHECK_NOTNULL(attrs);

ICHECK_EQ(call->args.size(), 5);
Expr data = call->args[0];
TensorStructInfo sinfo = MatchTensorStructInfo(data);
Expr gamma = call->args[1];
Expr beta = call->args[2];
Expr moving_mean = call->args[3];
Expr moving_var = call->args[4];

TensorStructInfo sinfo = MatchTensorStructInfo(data);

Array<Integer> reduce_axes;
for (int i = 0; i < sinfo->ndim; ++i) {
@@ -96,21 +92,35 @@
}

Expr data_mean = mean(data, reduce_axes, false);
Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
Expr data_var = variance(data, reduce_axes, false);
Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});

Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
// output = (x - mean) / sqrt(var + epsilon) * gamma + beta
Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
Expr out = divide(subtract(data, data_mean_rs), sqrt_var);

Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
if (attrs->scale) {
out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
}
if (attrs->center) {
out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
}

call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
// return call;
Expr moving_mean = call->args[3];
Expr moving_var = call->args[4];
Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);

return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
return Tuple({
out,
add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
});
}
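To make the rewrite above concrete, here is a scalar arithmetic cross-check (values arbitrary; gamma/beta applied as if scale and center are enabled). The inference form differs only in normalizing with the stored moving statistics and leaving them unchanged:

// Standalone sketch, not part of the PR.
#include <cmath>
#include <cstdio>

int main() {
  double x = 2.0, gamma = 1.5, beta = 0.5;
  double batch_mean = 1.0, batch_var = 4.0, epsilon = 1e-5;
  double moving_mean = 0.0, moving_var = 1.0, momentum = 0.1;

  // Training: normalize with the batch statistics.
  // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
  double out = (x - batch_mean) / std::sqrt(batch_var + epsilon) * gamma + beta;

  // Moving statistics are blended toward the batch statistics.
  double new_moving_mean = (1 - momentum) * moving_mean + momentum * batch_mean;
  double new_moving_var = (1 - momentum) * moving_var + momentum * batch_var;

  std::printf("out=%f new_moving_mean=%f new_moving_var=%f\n", out, new_moving_mean, new_moving_var);
  return 0;
}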

Expr DecomposeLayerNorm(const Call& call) {
Expr SimplifyLayerNorm(const Call& call) {
auto attrs = call->attrs.as<LayerNormAttrs>();
ICHECK_NOTNULL(attrs);

@@ -162,92 +172,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
return ShapeExpr(shape_var);
}

/*! \brief Update operators that have a training-specific form
*
* Some operators, such as relax.op.batch_norm, need additional
 * processing when being run for training. This mutator applies any mutations required.
*/
class TrainingOperatorMutator : public ExprMutator {
private:
using ExprMutator::VisitExpr_;
class OpDecomposer : public ExprMutator {
public:
constexpr static const char* kModeInference = "inference";
constexpr static const char* kModeTraining = "training";

Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == batch_norm_op_) {
return MutateBatchNormForTraining(call);
} else if (call->op == layer_norm_op_) {
// Here we only decompose LayerNorm in training because it is more efficient as a single op.
// In the future maybe we can also remove this decomposition during training.
return DecomposeLayerNorm(call);
} else {
return call;
}
explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
CHECK(mode == kModeInference || mode == kModeTraining)
<< "The argument mode must be one of the following values: \"inference\", \"training\".";
}

  /* composite operator list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
};

class OpDecomposer : public ExprMutator {
private:
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == batch_norm_op_) {
return DecomposeBatchNorm(call);
if (mode_ == kModeInference) {
return SimplifyBatchNormInference(call);
} else {
ICHECK_EQ(mode_, kModeTraining);
return SimplifyBatchNormTraining(call);
}
} else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
      // LayerNorm is only decomposed in training mode; at inference it is more
      // efficient to keep it as a single op.  In the future this training-time
      // decomposition may also be removed.
return SimplifyLayerNorm(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call, builder_);
}
return call;
}

const String mode_;

  /* composite operator list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
};

namespace transform {
IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
auto op_decomposer = OpDecomposer(mode);

Pass MutateOpsForTraining() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
TrainingOperatorMutator mutator;
return Downcast<Function>(mutator(func));
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"MutateOpsForTraining",
/*required=*/{});
}
IRModuleNode* new_module = mod.CopyOnWrite();

Pass DecomposeOps() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
OpDecomposer mutator;
return Downcast<Function>(mutator(func));
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOps",
/*required=*/{});
if (!func_name.defined()) { // simplify all functions
Map<GlobalVar, BaseFunc> functions = mod->functions;
for (const auto& func_pr : functions) {
if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
new_module->Update(func_pr.first, f);
}
}
} else { // simplify specified function
auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
    CHECK(func_ptr) << func_name.value() << " is not a Relax Function";
auto gvar = mod->GetGlobalVar(func_name.value());
auto func = GetRef<Function>(func_ptr);
func = Downcast<Function>(op_decomposer(func));
new_module->Update(gvar, func);
}

return GetRef<IRModule>(new_module);
}

namespace transform {
Pass DecomposeOpsForInference(Optional<String> func_name) {
if (func_name) {
return ApplyPassToFunction(DecomposeOps(), func_name.value());
} else {
return DecomposeOps();
}
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return Decompose(mod, func_name, OpDecomposer::kModeInference);
};
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOpsForInference",
/*required=*/{});
}

Pass DecomposeOpsForTraining(Optional<String> func_name) {
auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
"DecomposeOpsForTraining");
if (func_name) {
return ApplyPassToFunction(module_pass, func_name.value());
} else {
return module_pass;
}
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return Decompose(mod, func_name, OpDecomposer::kModeTraining);
};
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOpsForTraining",
/*required=*/{});
}
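For orientation, a minimal sketch of invoking the rebuilt module passes from C++. It assumes the passes are declared in <tvm/relax/transform.h> (not shown in this diff) and that the caller supplies the IRModule:

// Sketch only, not part of the PR.
#include <tvm/ir/module.h>
#include <tvm/relax/transform.h>  // assumed location of the pass declarations

tvm::IRModule DecomposeForTraining(tvm::IRModule mod) {
  // NullOpt rewrites every Relax function; passing a function name instead
  // restricts the rewrite to that function, matching the branch in Decompose().
  tvm::transform::Pass pass = tvm::relax::transform::DecomposeOpsForTraining(tvm::NullOpt);
  return pass(mod);
}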

TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
71 changes: 38 additions & 33 deletions tests/python/relax/test_transform_decompose_ops.py
@@ -137,39 +137,44 @@ def main(
R.Tensor((64,), dtype="float32"),
):
with R.dataflow():
# This portion is training-specific, computing the
# mean/variance of the dataset.
lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)

# This portion is identical to the batch_norm run during inference
lv1 = R.expand_dims(lv, axis=[0, 2, 3])
lv2 = R.subtract(x, lv1)
lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
lv6 = R.sqrt(lv5)
lv7 = R.divide(lv2, lv6)
lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
lv9 = R.multiply(lv7, lv8)
lv10 = R.expand_dims(beta, axis=[0, 2, 3])
lv11 = R.add(lv9, lv10)
inner_tuple = (lv11, lv, lv3)
# This is the result that would be returned from a
# batch_norm at inference.

# However, at training we need to update the moving
# mean/variance, and to return those updated values.
inner_res = inner_tuple[0]
lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
lv14 = R.add(lv12, lv13)
lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
lv17 = R.add(lv15, lv16)
bn = (inner_res, lv14, lv17)
gv0 = bn[0]
gv1 = bn[1]
gv2 = bn[2]
lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
lv3: R.Tensor((64,), dtype="float32") = R.variance(
x, axis=[0, 2, 3], keepdims=False
)
lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
lv4, R.const(9.9999997473787516e-06, "float32")
)
lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
lv12: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.89999997615814209, "float32"), moving_mean
)
lv13: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.10000000149011612, "float32"), lv
)
lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
lv15: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.89999997615814209, "float32"), moving_var
)
lv16: R.Tensor((64,), dtype="float32") = R.multiply(
R.const(0.10000000149011612, "float32"), lv3
)
lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
bn: R.Tuple(
R.Tensor((1, 64, 112, 112), dtype="float32"),
R.Tensor((64,), dtype="float32"),
R.Tensor((64,), dtype="float32"),
) = (lv11, lv14, lv17)
gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
gv1: R.Tensor((64,), dtype="float32") = bn[1]
gv2: R.Tensor((64,), dtype="float32") = bn[2]
R.output(gv0, gv1, gv2)
return (gv0, gv1, gv2)
