Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "orttraining/core/graph/gradient_builder_registry.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/batchnorm_replacement.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"

Expand All @@ -32,11 +33,10 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
auto rule_based_graph_transformer =
std::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
rule_based_graph_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
rule_based_graph_transformer->Register(std::make_unique<AdjustBatchNormOutputs>());
rule_based_graph_transformer->Register(std::make_unique<BatchNormReplacement>());

graph_transformation_mgr_.Register(std::move(rule_based_graph_transformer),
TransformerLevel::Level2);

auto forward_reachable_nodes = BFSWithStopGradient(x_node_arg_names);

for (const auto& name : y_node_arg_names) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Not", {0}},
{"And", {0, 1}},
{"BatchNormalization", {3, 4}},
{"BatchNormInternal", {3, 4}},
{"Or", {0, 1}},
{"Xor", {0, 1}},
{"Equal", {0, 1}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("FastGelu", GetFastGeluGradient);
REGISTER_GRADIENT_BUILDER("LayerNormalization", GetLayerNormalizationGradient);
REGISTER_GRADIENT_BUILDER("SimplifiedLayerNormalization", GetSimplifiedLayerNormalizationGradient);
REGISTER_GRADIENT_BUILDER("BatchNormalization", GetBatchNormalizationGradient);
REGISTER_GRADIENT_BUILDER("BatchNormInternal", GetBatchNormalizationGradient);
REGISTER_GRADIENT_BUILDER("MegatronF", GetMegatronFGradient);
REGISTER_GRADIENT_BUILDER("MegatronG", GetMegatronGGradient);
REGISTER_GRADIENT_BUILDER("Slice", GetSliceGradient);
Expand Down
65 changes: 65 additions & 0 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,71 @@ Return true if all elements are true and false otherwise.
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.");

ONNX_CONTRIB_OPERATOR_SCHEMA(BatchNormInternal)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("Variant of BatchNormalization with additional output for saved_mean/inv_std_dev.")
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("epsilon", "epsilon value", AttributeProto::FLOAT, 1e-5f)
.Attr("momentum", "momentum value", AttributeProto::FLOAT, 0.9f)
.Attr("training_mode", "true if training", AttributeProto::INT, static_cast<int64_t>(1))
.Input(0, "X", "Input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(1, "scale", "Scale tensor of shape (C).", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(2, "B", "Bias tensor of shape (C).", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(3, "input_mean", "running mean tensor of shape (C).", "U", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(4, "input_var", "running variance tensor of shape (C).", "U", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Output(0, "Y", "The output tensor of the same shape as X", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Output(1,"running_mean", "The running mean after BN.", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable)
.Output(2, "running_var", "Running var after BN", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable)
.Output(3, "saved_mean", "Mean of the batch", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable)
.Output(4, "saved_inv_std", "Inverse standard deviation for the batch", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeConstraint(
"U",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain mean and variance types to float tensors. It allows all float type for U.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateShapeAndTypeFromFirstInput(ctx);
propagateShapeFromInputToOutput(ctx, 0, 0);

Dim num_channels;

unifyInputDim(ctx, 0, 1, num_channels);
unifyInputDim(ctx, 1, 0, num_channels);
unifyInputDim(ctx, 2, 0, num_channels);
unifyInputDim(ctx, 3, 0, num_channels);
unifyInputDim(ctx, 4, 0, num_channels);

if (ctx.getAttribute("training_mode") &&
static_cast<int>(ctx.getAttribute("training_mode")->i()) != 0) {
if (ctx.getNumOutputs() != 5)
fail_shape_inference(
"This number of op outputs should be 5 when Training_mode = True, but it is not.");
} else {
if (ctx.getNumOutputs() != 1)
fail_shape_inference(
"This number of op outputs should be 1 when Training_mode = False, but it is not.");
}

if (ctx.getNumOutputs() > 1) {
ONNX_NAMESPACE::TensorShapeProto outputs_shape;
*outputs_shape.add_dim() = num_channels; // channel

propagateElemTypeFromInputToOutput(ctx, 3, 1);
updateOutputShape(ctx, 1, outputs_shape);
propagateElemTypeFromInputToOutput(ctx, 4, 2);
updateOutputShape(ctx, 2, outputs_shape);
propagateElemTypeFromInputToOutput(ctx, 3, 3);
updateOutputShape(ctx, 3, outputs_shape);
propagateElemTypeFromInputToOutput(ctx, 4, 4);
updateOutputShape(ctx, 4, outputs_shape);
}
});


ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceAllL2)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
56 changes: 56 additions & 0 deletions orttraining/orttraining/core/optimizer/batchnorm_replacement.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "orttraining/core/optimizer/batchnorm_replacement.h"

#include "core/common/logging/logging.h"
#include "core/optimizer/rewrite_rule.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"

namespace onnxruntime {

Status BatchNormReplacement::Apply(Graph& graph, Node& bn_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
const auto& bn_inputs = bn_node.MutableInputDefs();
auto& bn_outputs = bn_node.MutableOutputDefs();
const NodeArg* scale_input_def = bn_inputs[1];
auto scale_input_def_type_proto = scale_input_def->TypeAsProto();

if (bn_outputs.size() == 1) {
NodeArg& running_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_mean_def"), scale_input_def_type_proto);
NodeArg& running_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_var_def"), scale_input_def_type_proto);
bn_outputs.push_back(&running_mean_def);
bn_outputs.push_back(&running_var_def);
}

if (bn_outputs.size() == 3) {
NodeArg& saved_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean_def"), scale_input_def_type_proto);
NodeArg& saved_inv_std = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std"), scale_input_def_type_proto);
bn_outputs.push_back(&saved_inv_std);
bn_outputs.push_back(&saved_mean_def);
}

// check Batch Normalization node has 5 output node args for training mode
ORT_ENFORCE(bn_node.OutputDefs().size() == 5);

Node& batchnorm_internal_node = graph.AddNode(graph.GenerateNodeName(bn_node.Name() + "_BatchNormInternal"),
"BatchNormInternal",
"BatchNormalization with saved mean/inv_std_dev",
bn_inputs,
bn_outputs,
&bn_node.GetAttributes(),
kMSDomain);
batchnorm_internal_node.AddAttribute("training_mode", static_cast<int64_t>(1));
// Assign provider to this new node. Provider should be same as the provider for old node.
batchnorm_internal_node.SetExecutionProviderType(bn_node.GetExecutionProviderType());
graph_utils::FinalizeNodeFusion(graph, batchnorm_internal_node, bn_node);
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
return Status::OK();
}

bool BatchNormReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const {
return true;
}

} // namespace onnxruntime
30 changes: 30 additions & 0 deletions orttraining/orttraining/core/optimizer/batchnorm_replacement.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {

/**
@Class BatchNorm Replacement

Rewrite rule that replaces BatchNorm with BatchNormInternal, that has additional outputs
for saved_mean and saved_std_dev
*/
class BatchNormReplacement : public RewriteRule {
public:
BatchNormReplacement() noexcept : RewriteRule("BatchNormReplacement") {}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"BatchNormalization"};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "orttraining/core/framework/distributed_run_context.h"
#include "core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/concat_replacement.h"
#include "orttraining/core/optimizer/batchnorm_replacement.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
Expand Down Expand Up @@ -74,7 +75,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
compatible_eps);
rule_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
rule_transformer->Register(std::make_unique<AdjustBatchNormOutputs>());
rule_transformer->Register(std::make_unique<BatchNormReplacement>());
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
rule_transformer->Register(std::make_unique<ExpandElimination>());
rule_transformer->Register(std::make_unique<CastElimination>());
Expand Down
30 changes: 0 additions & 30 deletions orttraining/orttraining/core/optimizer/insert_output_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,4 @@ bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph*
return false;
}

Status AdjustBatchNormOutputs::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
auto& outputs = node.MutableOutputDefs();
const auto& inputs = node.InputDefs();
const NodeArg* scale_input_def = inputs[1];
auto scale_input_def_type_proto = scale_input_def->TypeAsProto();

NodeArg& running_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_mean_def"), scale_input_def_type_proto);
NodeArg& running_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_var_def"), scale_input_def_type_proto);
NodeArg& saved_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean_def"), scale_input_def_type_proto);
NodeArg& saved_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_var_def"), scale_input_def_type_proto);

outputs.push_back(&running_mean_def);
outputs.push_back(&running_var_def);
outputs.push_back(&saved_mean_def);
outputs.push_back(&saved_var_def);

// check Batch Normalization node has 5 output node args for training mode
ORT_ENFORCE(node.OutputDefs().size() == 5);

rule_effect = RewriteRuleEffect::kUpdatedCurrentNode;
return Status::OK();
}

bool AdjustBatchNormOutputs::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
if (node.OutputDefs().size() == 1) {
return true;
}
return false;
}

} // namespace onnxruntime
17 changes: 0 additions & 17 deletions orttraining/orttraining/core/optimizer/insert_output_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,4 @@ class InsertSoftmaxCrossEntropyLossOutput : public RewriteRule {
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

// Rewrite rule that adjust Batch Normalization nodes to have 5 outputs for training mode
// instead of 1 for inference mode
class AdjustBatchNormOutputs : public RewriteRule {
public:
AdjustBatchNormOutputs() noexcept
: RewriteRule("AdjustBatchNormOutputs") {
}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"BatchNormalization"};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime
Loading