From e2ba74fb974f10e70a2ef5c06188f5a93bf5444b Mon Sep 17 00:00:00 2001
From: Pranav Prakash
Date: Fri, 30 Apr 2021 21:57:36 -0700
Subject: [PATCH 1/5] Add transformer for BatchNorm -> BN Internal

---
 .../core/framework/gradient_graph_builder.cc  |  1 -
 .../core/graph/training_op_defs.cc            | 64 +++++++++++++++++++
 .../core/optimizer/batchnorm_replacement.cc   | 55 ++++++++++++++++
 .../core/optimizer/batchnorm_replacement.h    | 30 +++++++++
 .../core/optimizer/graph_transformer_utils.cc |  3 +-
 .../core/optimizer/insert_output_rewriter.cc  | 30 ---------
 .../core/optimizer/insert_output_rewriter.h   | 17 -----
 7 files changed, 151 insertions(+), 49 deletions(-)
 create mode 100644 orttraining/orttraining/core/optimizer/batchnorm_replacement.cc
 create mode 100644 orttraining/orttraining/core/optimizer/batchnorm_replacement.h

diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
index 5f4fd90aeacf4..5f82da6966d21 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
@@ -32,7 +32,6 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
   auto rule_based_graph_transformer =
       onnxruntime::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
   rule_based_graph_transformer->Register(make_unique<InsertMaxPoolOutput>());
-  rule_based_graph_transformer->Register(make_unique<AdjustBatchNormOutputs>());
 
   graph_transformation_mgr_.Register(std::move(rule_based_graph_transformer),
                                      TransformerLevel::Level2);

diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index 9bbf962ee4320..def943a9585b5 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -1947,6 +1947,70 @@ Return true if all elements are true and false otherwise.
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(BatchNormInternal) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("Variant of BatchNormalization with additional output for saved_mean/inv_std_dev.") + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("epsilon", "epsilon value", AttributeProto::FLOAT, 1e-5f) + .Attr("momentum", "momentum value", AttributeProto::FLOAT, 0.9f) + .Attr("training_mode", "true if training", AttributeProto::INT, static_cast(1)) + .Input(0, "X", "Input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(1, "scale", "Scale tensor of shape (C).", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(2, "B", "Bias tensor of shape (C).", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(3, "input_mean", "running mean tensor of shape (C).", "U", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(4, "input_var", "running variance tensor of shape (C).", "U", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(0, "Y", "The output tensor of the same shape as X", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(1,"running_mean", "The running mean after BN.", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(2, "running_var", "Running var after BN", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(3, "saved_mean", "Mean of the batch", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(4, "saved_inv_std", "Inverse standard deviation for the batch", "U", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint( + "U", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain mean and variance types to float tensors. 
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        propagateShapeAndTypeFromFirstInput(ctx);
+        propagateShapeFromInputToOutput(ctx, 0, 0);
+
+        Dim num_channels;
+
+        unifyInputDim(ctx, 0, 1, num_channels);
+        unifyInputDim(ctx, 1, 0, num_channels);
+        unifyInputDim(ctx, 2, 0, num_channels);
+        unifyInputDim(ctx, 3, 0, num_channels);
+        unifyInputDim(ctx, 4, 0, num_channels);
+
+        if (ctx.getAttribute("training_mode") &&
+            static_cast<int>(ctx.getAttribute("training_mode")->i()) != 0) {
+          if (ctx.getNumOutputs() != 3)
+            fail_shape_inference(
+                "The number of op outputs should be 3 when training_mode = True, but it is not.");
+        } else {
+          if (ctx.getNumOutputs() != 1)
+            fail_shape_inference(
+                "The number of op outputs should be 1 when training_mode = False, but it is not.");
+        }
+
+        if (ctx.getNumOutputs() > 1) {
+          ONNX_NAMESPACE::TensorShapeProto outputs_shape;
+          *outputs_shape.add_dim() = num_channels;  // channel
+
+          propagateElemTypeFromInputToOutput(ctx, 3, 1);
+          updateOutputShape(ctx, 1, outputs_shape);
+
+          if (ctx.getNumOutputs() > 2) {
+            propagateElemTypeFromInputToOutput(ctx, 4, 2);
+            updateOutputShape(ctx, 2, outputs_shape);
+          }
+        }
+      });
+
   ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceAllL2)
       .SetDomain(kMSDomain)
       .SinceVersion(1)

diff --git a/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc
new file mode 100644
index 0000000000000..24de7c171d5ae
--- /dev/null
+++ b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "orttraining/core/optimizer/batchnorm_replacement.h"
+
+#include "core/common/logging/logging.h"
+#include "core/optimizer/rewrite_rule.h"
+#include "core/optimizer/utils.h"
+#include "core/graph/graph.h"
+#include "core/graph/graph_utils.h"
+
+namespace onnxruntime {
+
+Status BatchNormReplacement::Apply(Graph& graph, Node& bn_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
+  const auto& bn_inputs = bn_node.MutableInputDefs();
+  auto& bn_outputs = bn_node.MutableOutputDefs();
+  const NodeArg* scale_input_def = bn_inputs[1];
+  auto scale_input_def_type_proto = scale_input_def->TypeAsProto();
+
+  // Guard against a BatchNorm that already has optional outputs present for some reason
+  if (bn_outputs.size() == 1) {
+    NodeArg& running_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_mean_def"), scale_input_def_type_proto);
+    NodeArg& running_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_var_def"), scale_input_def_type_proto);
+    bn_outputs.push_back(&running_mean_def);
+    bn_outputs.push_back(&running_var_def);
+
+    NodeArg& saved_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean_def"), scale_input_def_type_proto);
+    NodeArg& saved_inv_std = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std"), scale_input_def_type_proto);
+    bn_outputs.push_back(&saved_mean_def);
+    bn_outputs.push_back(&saved_inv_std);
+  }
+
+  // Check that the BatchNormalization node has 5 output node args for training mode.
+  ORT_ENFORCE(bn_node.OutputDefs().size() == 5);
+
+  Node& batchnorm_internal_node = graph.AddNode(graph.GenerateNodeName("BatchNormInternal"),
+                                                "BatchNormInternal",
+                                                "BatchNormalization with saved mean/inv_std_dev",
+                                                bn_inputs,
+                                                bn_outputs,
+                                                &bn_node.GetAttributes(),
+                                                kMSDomain);
batchnorm_internal_node.AddAttribute("is_training", static_cast(1)); + // Assign provider to this new node. Provider should be same as the provider for old node. + batchnorm_internal_node.SetExecutionProviderType(bn_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, batchnorm_internal_node, bn_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} + +bool BatchNormReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/batchnorm_replacement.h b/orttraining/orttraining/core/optimizer/batchnorm_replacement.h new file mode 100644 index 0000000000000..2d0705bc7a6f1 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/batchnorm_replacement.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class BatchNorm Replacement + +Rewrite rule that replaces BatchNorm with BatchNormInternal, that has additional outputs +for saved_mean and saved_std_dev +*/ +class BatchNormReplacement : public RewriteRule { + public: + BatchNormReplacement() noexcept : RewriteRule("BatchNormReplacement") {} + + std::vector TargetOpTypes() const noexcept override { + return {"BatchNormalization"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 13c5961e109b4..9c3de0b734bfe 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -46,6 +46,7 @@ #include "orttraining/core/framework/distributed_run_context.h" #include "core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/concat_replacement.h" +#include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/transformer_layer_recompute.h" @@ -73,7 +74,7 @@ std::vector> GeneratePreTrainingTransformers( onnxruntime::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), compatible_eps); rule_transformer->Register(make_unique()); - rule_transformer->Register(make_unique()); + rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index b240aa9ef35d5..2aade8c9bc1f9 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -60,34 +60,4 @@ bool InsertSoftmaxCrossEntropyLossOutput::SatisfyCondition(const Graph& /*graph* return false; } -Status AdjustBatchNormOutputs::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { - auto& outputs = node.MutableOutputDefs(); - 
const auto& inputs = node.InputDefs(); - const NodeArg* scale_input_def = inputs[1]; - auto scale_input_def_type_proto = scale_input_def->TypeAsProto(); - - NodeArg& running_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_mean_def"), scale_input_def_type_proto); - NodeArg& running_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_var_def"), scale_input_def_type_proto); - NodeArg& saved_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean_def"), scale_input_def_type_proto); - NodeArg& saved_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_var_def"), scale_input_def_type_proto); - - outputs.push_back(&running_mean_def); - outputs.push_back(&running_var_def); - outputs.push_back(&saved_mean_def); - outputs.push_back(&saved_var_def); - - // check Batch Normalization node has 5 output node args for training mode - ORT_ENFORCE(node.OutputDefs().size() == 5); - - rule_effect = RewriteRuleEffect::kUpdatedCurrentNode; - return Status::OK(); -} - -bool AdjustBatchNormOutputs::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const { - if (node.OutputDefs().size() == 1) { - return true; - } - return false; -} - } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h index 3c1858228c15a..20eb79452739f 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h @@ -42,21 +42,4 @@ class InsertSoftmaxCrossEntropyLossOutput : public RewriteRule { Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -// Rewrite rule that adjust Batch Normalization nodes to have 5 outputs for training mode -// instead of 1 for inference mode -class AdjustBatchNormOutputs : public RewriteRule { - public: - AdjustBatchNormOutputs() noexcept - : RewriteRule("AdjustBatchNormOutputs") { - } - - std::vector TargetOpTypes() const noexcept override { - return {"BatchNormalization"}; - } - - private: - bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; -}; } // namespace onnxruntime From 8142bbbe351da338c4fcd107e75dae5681ad7add Mon Sep 17 00:00:00 2001 From: Pranav Prakash Date: Mon, 3 May 2021 14:23:46 -0700 Subject: [PATCH 2/5] Add test for BN replacement transformer --- .../core/graph/training_op_defs.cc | 15 ++++--- .../core/optimizer/batchnorm_replacement.cc | 2 +- .../test/optimizer/graph_transform_test.cc | 44 +++++++++++++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index def943a9585b5..3a5d3d572fa91 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1987,9 +1987,9 @@ Return true if all elements are true and false otherwise. 
if (ctx.getAttribute("training_mode") && static_cast(ctx.getAttribute("training_mode")->i()) != 0) { - if (ctx.getNumOutputs() != 3) + if (ctx.getNumOutputs() != 5) fail_shape_inference( - "This number of op outputs should be 3 when Training_mode = True, but it is not."); + "This number of op outputs should be 5 when Training_mode = True, but it is not."); } else { if (ctx.getNumOutputs() != 1) fail_shape_inference( @@ -2002,11 +2002,12 @@ Return true if all elements are true and false otherwise. propagateElemTypeFromInputToOutput(ctx, 3, 1); updateOutputShape(ctx, 1, outputs_shape); - - if (ctx.getNumOutputs() > 2) { - propagateElemTypeFromInputToOutput(ctx, 4, 2); - updateOutputShape(ctx, 2, outputs_shape); - } + propagateElemTypeFromInputToOutput(ctx, 4, 2); + updateOutputShape(ctx, 2, outputs_shape); + propagateElemTypeFromInputToOutput(ctx, 3, 3); + updateOutputShape(ctx, 3, outputs_shape); + propagateElemTypeFromInputToOutput(ctx, 4, 4); + updateOutputShape(ctx, 4, outputs_shape); } }); diff --git a/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc index 24de7c171d5ae..0838c871262ec 100644 --- a/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc +++ b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc @@ -40,7 +40,7 @@ Status BatchNormReplacement::Apply(Graph& graph, Node& bn_node, RewriteRuleEffec bn_outputs, &bn_node.GetAttributes(), kMSDomain); - batchnorm_internal_node.AddAttribute("is_training", static_cast(1)); + batchnorm_internal_node.AddAttribute("training_mode", static_cast(1)); // Assign provider to this new node. Provider should be same as the provider for old node. batchnorm_internal_node.SetExecutionProviderType(bn_node.GetExecutionProviderType()); graph_utils::FinalizeNodeFusion(graph, batchnorm_internal_node, bn_node); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 9d0d1119bac04..5f77caeedf4d7 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -16,6 +16,7 @@ #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "orttraining/core/optimizer/concat_replacement.h" +#include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" @@ -33,6 +34,49 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") +TEST_F(GraphTransformationTests, BatchNormReplacement) { + Model model("BatchNormReplacement", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 14}, {"com.microsoft", 1}}, + {}, *logger_); + auto& graph = model.MainGraph(); + + std::vector inputs; + std::vector outputs; + + // 1x3x3x3 + TypeProto input_tensor_type; + input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + + TypeProto scale_tensor_type; + 
+  scale_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  scale_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+
+  auto& input_X = graph.GetOrCreateNodeArg("X", &input_tensor_type);
+  auto& input_scale = graph.GetOrCreateNodeArg("scale", &scale_tensor_type);
+  auto& input_B = graph.GetOrCreateNodeArg("B", &scale_tensor_type);
+  auto& input_mean = graph.GetOrCreateNodeArg("input_mean", &scale_tensor_type);
+  auto& input_var = graph.GetOrCreateNodeArg("input_var", &scale_tensor_type);
+
+  auto& output_Y = graph.GetOrCreateNodeArg("Y", &input_tensor_type);
+  graph.AddNode("BN", "BatchNormalization", "", {&input_X, &input_scale, &input_B, &input_mean, &input_var}, {&output_Y});
+
+  auto status = graph.Resolve();
+  EXPECT_EQ(status, Status::OK());
+
+  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("BatchNormReplacement");
+  rule_transformer_L1->Register(onnxruntime::make_unique<BatchNormReplacement>());
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  ASSERT_TRUE(graph.NumberOfNodes() == 1);
+  // Make sure that BN was updated to add outputs
+  ASSERT_TRUE(graph.Nodes().begin()->MutableOutputDefs().size() == 5);
+}
+
 TEST_F(GraphTransformationTests, DropoutWithZeroRatioElimination) {
   auto model_uri = MODEL_FOLDER "dropout_ratio.onnx";
   std::shared_ptr<Model> model;

From 8dddf56ee5cd0ed32ba041e476ef0e5437765db2 Mon Sep 17 00:00:00 2001
From: Pranav Prakash
Date: Thu, 6 May 2021 16:45:23 -0700
Subject: [PATCH 3/5] Resolve comments

---
 .../core/framework/gradient_graph_builder.cc           | 10 ----------
 .../core/framework/gradient_graph_builder.h            |  2 --
 .../orttraining/test/optimizer/graph_transform_test.cc |  1 +
 3 files changed, 1 insertion(+), 12 deletions(-)

diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
index 5f82da6966d21..45f20598a3bd5 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
@@ -29,13 +29,6 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
       loss_node_arg_name_(loss_node_arg_name),
       gradient_graph_config_(gradient_graph_config),
       logger_(logger) {
-  auto rule_based_graph_transformer =
-      onnxruntime::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
-  rule_based_graph_transformer->Register(make_unique<InsertMaxPoolOutput>());
-
-  graph_transformation_mgr_.Register(std::move(rule_based_graph_transformer),
-                                     TransformerLevel::Level2);
-
   auto forward_reachable_nodes = BFSWithStopGradient(x_node_arg_names);
 
   for (const auto& name : y_node_arg_names) {
@@ -186,9 +179,6 @@ Status GradientGraphBuilder::CheckNodeArgsReachable() const {
 }
 
 Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_initializer_names_to_preserve) {
-  auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logger_);
-  ORT_RETURN_IF_ERROR(opt_ret);
-
   GraphAugmenter::GraphDefs gradient_graph_defs;
   // add "gradient of the loss" node, always 1.
if (loss_node_arg_name_ != "") { diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 25967e0c77721..f157f5951f5f1 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -110,8 +110,6 @@ class GradientGraphBuilder { const logging::Logger& logger_; - onnxruntime::GraphTransformerManager graph_transformation_mgr_{5}; - // key: ArgDef for the gradient after accumulation // value: ArgDef for the gradients to be accumulated struct ArgDefHasher { diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 5f77caeedf4d7..1900bb76f3e57 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -75,6 +75,7 @@ TEST_F(GraphTransformationTests, BatchNormReplacement) { ASSERT_TRUE(graph.NumberOfNodes() == 1); // Make sure that BN was updated to add outputs ASSERT_TRUE(graph.Nodes().begin()->MutableOutputDefs().size() == 5); + ASSERT_TRUE(graph.Nodes().begin()->OpType().compare("BatchNormInternal") == 0); } TEST_F(GraphTransformationTests, DropoutWithZeroRatioElimination) { From 772dc3fd5465ef8a825f450e34a58e6827d6023c Mon Sep 17 00:00:00 2001 From: Pranav Prakash Date: Thu, 6 May 2021 17:19:42 -0700 Subject: [PATCH 4/5] Resolve comments --- .../core/framework/gradient_graph_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 2 +- .../core/optimizer/batchnorm_replacement.cc | 5 +- .../core/optimizer/graph_transformer_utils.cc | 3 +- .../test/optimizer/graph_transform_test.cc | 103 +++++++++++++++++- 5 files changed, 107 insertions(+), 7 deletions(-) diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 958e67d9ec044..130f42654789c 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -30,6 +30,7 @@ static std::unordered_map> {"Not", {0}}, {"And", {0, 1}}, {"BatchNormalization", {3, 4}}, + {"BatchNormInternal", {3, 4}}, {"Or", {0, 1}}, {"Xor", {0, 1}}, {"Equal", {0, 1}}, diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 7399f96308cc8..ec84e20d4498c 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -88,7 +88,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("FastGelu", GetFastGeluGradient); REGISTER_GRADIENT_BUILDER("LayerNormalization", GetLayerNormalizationGradient); REGISTER_GRADIENT_BUILDER("SimplifiedLayerNormalization", GetSimplifiedLayerNormalizationGradient); - REGISTER_GRADIENT_BUILDER("BatchNormalization", GetBatchNormalizationGradient); + REGISTER_GRADIENT_BUILDER("BatchNormInternal", GetBatchNormalizationGradient); REGISTER_GRADIENT_BUILDER("MegatronF", GetMegatronFGradient); REGISTER_GRADIENT_BUILDER("MegatronG", GetMegatronGGradient); REGISTER_GRADIENT_BUILDER("Slice", GetSliceGradient); diff --git a/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc index 0838c871262ec..28e6cb93cdb45 100644 --- 
+++ b/orttraining/orttraining/core/optimizer/batchnorm_replacement.cc
@@ -17,13 +17,14 @@ Status BatchNormReplacement::Apply(Graph& graph, Node& bn_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
   const NodeArg* scale_input_def = bn_inputs[1];
   auto scale_input_def_type_proto = scale_input_def->TypeAsProto();
 
-  // Guard against a BatchNorm that already has optional outputs present for some reason
   if (bn_outputs.size() == 1) {
     NodeArg& running_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_mean_def"), scale_input_def_type_proto);
     NodeArg& running_var_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("running_var_def"), scale_input_def_type_proto);
     bn_outputs.push_back(&running_mean_def);
     bn_outputs.push_back(&running_var_def);
+  }
 
+  if (bn_outputs.size() == 3) {
     NodeArg& saved_mean_def = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean_def"), scale_input_def_type_proto);
     NodeArg& saved_inv_std = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std"), scale_input_def_type_proto);
     bn_outputs.push_back(&saved_mean_def);
     bn_outputs.push_back(&saved_inv_std);
@@ -33,7 +34,7 @@ Status BatchNormReplacement::Apply(Graph& graph, Node& bn_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
   // Check that the BatchNormalization node has 5 output node args for training mode.
   ORT_ENFORCE(bn_node.OutputDefs().size() == 5);
 
-  Node& batchnorm_internal_node = graph.AddNode(graph.GenerateNodeName("BatchNormInternal"),
+  Node& batchnorm_internal_node = graph.AddNode(graph.GenerateNodeName(bn_node.Name() + "_BatchNormInternal"),
                                                 "BatchNormInternal",
                                                 "BatchNormalization with saved mean/inv_std_dev",
                                                 bn_inputs,

diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index e226fc7964669..cb83f48458f53 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -75,8 +75,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level), compatible_eps);
   rule_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
-  rule_transformer->Register(make_unique<BatchNormReplacement>());
-  rule_transformer->Register(std::make_unique<AdjustBatchNormOutputs>());
+  rule_transformer->Register(std::make_unique<BatchNormReplacement>());
   rule_transformer->Register(std::make_unique<InsertSoftmaxCrossEntropyLossOutput>());
   rule_transformer->Register(std::make_unique());
   rule_transformer->Register(std::make_unique());

diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
index f00774d310bbf..764a945f212f7 100644
--- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc
+++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
@@ -66,8 +66,107 @@ TEST_F(GraphTransformationTests, BatchNormReplacement) {
   auto status = graph.Resolve();
   EXPECT_EQ(status, Status::OK());
 
-  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("BatchNormReplacement");
-  rule_transformer_L1->Register(onnxruntime::make_unique<BatchNormReplacement>());
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("BatchNormReplacement");
+  rule_transformer_L1->Register(std::make_unique<BatchNormReplacement>());
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  ASSERT_TRUE(graph.NumberOfNodes() == 1);
+  // Make sure that BN was updated to add outputs
+  ASSERT_TRUE(graph.Nodes().begin()->MutableOutputDefs().size() == 5);
+  ASSERT_TRUE(graph.Nodes().begin()->OpType().compare("BatchNormInternal") == 0);
+}
+
+TEST_F(GraphTransformationTests, BatchNormReplacementWithOptionalOutputPresentOpset14) {
+  Model model("BatchNormReplacement", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              {{"", 14}, {"com.microsoft", 1}}, {}, *logger_);
+  auto& graph = model.MainGraph();
+
+  std::vector<NodeArg*> inputs;
+  std::vector<NodeArg*> outputs;
+
+  // 1x3x3x3
+  TypeProto input_tensor_type;
+  input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+
+  TypeProto scale_tensor_type;
+  scale_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  scale_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+
+  auto& input_X = graph.GetOrCreateNodeArg("X", &input_tensor_type);
+  auto& input_scale = graph.GetOrCreateNodeArg("scale", &scale_tensor_type);
+  auto& input_B = graph.GetOrCreateNodeArg("B", &scale_tensor_type);
+  auto& input_mean = graph.GetOrCreateNodeArg("input_mean", &scale_tensor_type);
+  auto& input_var = graph.GetOrCreateNodeArg("input_var", &scale_tensor_type);
+
+  auto& output_Y = graph.GetOrCreateNodeArg("Y", &input_tensor_type);
+  auto& output_running_mean = graph.GetOrCreateNodeArg("running_mean", &scale_tensor_type);
+  auto& output_running_var = graph.GetOrCreateNodeArg("running_var", &scale_tensor_type);
+  auto& bn_node = graph.AddNode("BN", "BatchNormalization", "", {&input_X, &input_scale, &input_B, &input_mean, &input_var},
+                                {&output_Y, &output_running_mean, &output_running_var});
+  bn_node.AddAttribute("training_mode", static_cast<int64_t>(1));
+
+  auto status = graph.Resolve();
+  EXPECT_EQ(status, Status::OK());
+
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("BatchNormReplacement");
+  rule_transformer_L1->Register(std::make_unique<BatchNormReplacement>());
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  ASSERT_TRUE(graph.NumberOfNodes() == 1);
+  // Make sure that BN was updated to add outputs
+  ASSERT_TRUE(graph.Nodes().begin()->MutableOutputDefs().size() == 5);
+  ASSERT_TRUE(graph.Nodes().begin()->OpType().compare("BatchNormInternal") == 0);
+}
+
+TEST_F(GraphTransformationTests, BatchNormReplacementWithOptionalOutputPresentOpset9) {
+  Model model("BatchNormReplacement", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              {{"", 9}, {"com.microsoft", 1}}, {}, *logger_);
+  auto& graph = model.MainGraph();
+
+  std::vector<NodeArg*> inputs;
+  std::vector<NodeArg*> outputs;
+
+  // 1x3x3x3
+  TypeProto input_tensor_type;
+  input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+
+  TypeProto scale_tensor_type;
+  scale_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  scale_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+
+  auto& input_X = graph.GetOrCreateNodeArg("X", &input_tensor_type);
+  auto& input_scale = graph.GetOrCreateNodeArg("scale", &scale_tensor_type);
+  auto& input_B = graph.GetOrCreateNodeArg("B", &scale_tensor_type);
+  auto& input_mean = graph.GetOrCreateNodeArg("input_mean", &scale_tensor_type);
+  auto& input_var = graph.GetOrCreateNodeArg("input_var", &scale_tensor_type);
+
+  auto& output_Y = graph.GetOrCreateNodeArg("Y", &input_tensor_type);
+  auto& output_running_mean = graph.GetOrCreateNodeArg("running_mean", &scale_tensor_type);
+  auto& output_running_var = graph.GetOrCreateNodeArg("running_var", &scale_tensor_type);
+  auto& saved_mean = graph.GetOrCreateNodeArg("saved_mean", &scale_tensor_type);
+  auto& saved_var = graph.GetOrCreateNodeArg("saved_var", &scale_tensor_type);
+  graph.AddNode("BN", "BatchNormalization", "", {&input_X, &input_scale, &input_B, &input_mean, &input_var},
+                {&output_Y, &output_running_mean, &output_running_var, &saved_mean, &saved_var});
+
+  auto status = graph.Resolve();
+  EXPECT_EQ(status, Status::OK());
+
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("BatchNormReplacement");
+  rule_transformer_L1->Register(std::make_unique<BatchNormReplacement>());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
   graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

From abc6b3530fd90d43b311fe4f57aa6dbc331337fd Mon Sep 17 00:00:00 2001
From: Pranav Prakash
Date: Fri, 7 May 2021 16:09:49 -0700
Subject: [PATCH 5/5] Revert removal of InsertMaxpoolOutput in gradient_graph_builder

---
 .../core/framework/gradient_graph_builder.cc  | 11 +++++++++++
 .../core/framework/gradient_graph_builder.h   |  2 ++
 2 files changed, 13 insertions(+)

diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
index 45f20598a3bd5..b254abc3700eb 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc
@@ -9,6 +9,7 @@
 #include "orttraining/core/graph/gradient_builder_registry.h"
 #include "orttraining/core/graph/gradient_config.h"
 #include "orttraining/core/optimizer/insert_output_rewriter.h"
+#include "orttraining/core/optimizer/batchnorm_replacement.h"
 #include "core/optimizer/gelu_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
@@ -29,6 +30,13 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
       loss_node_arg_name_(loss_node_arg_name),
       gradient_graph_config_(gradient_graph_config),
       logger_(logger) {
+  auto rule_based_graph_transformer =
+      std::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
+  rule_based_graph_transformer->Register(std::make_unique<InsertMaxPoolOutput>());
+  rule_based_graph_transformer->Register(std::make_unique<BatchNormReplacement>());
+
+  graph_transformation_mgr_.Register(std::move(rule_based_graph_transformer),
+                                     TransformerLevel::Level2);
   auto forward_reachable_nodes = BFSWithStopGradient(x_node_arg_names);
 
   for (const auto& name : y_node_arg_names) {
@@ -179,6 +187,9 @@ Status GradientGraphBuilder::CheckNodeArgsReachable() const {
 }
 
 Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_initializer_names_to_preserve) {
+  auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logger_);
+  ORT_RETURN_IF_ERROR(opt_ret);
+
   GraphAugmenter::GraphDefs gradient_graph_defs;
   // add "gradient of the loss" node, always 1.
   if (loss_node_arg_name_ != "") {

diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h
index 130f42654789c..e5085a830e8b8 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.h
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h
@@ -112,6 +112,8 @@ class GradientGraphBuilder {
 
   const logging::Logger& logger_;
 
+  onnxruntime::GraphTransformerManager graph_transformation_mgr_{5};
+
   // key: ArgDef for the gradient after accumulation
   // value: ArgDef for the gradients to be accumulated
   struct ArgDefHasher {
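
For reference, the sketch below summarizes the per-channel arithmetic that BatchNormInternal is expected to perform in training mode, and what its two extra outputs cache for the backward pass. It is a minimal illustration assuming the standard ONNX BatchNormalization semantics (momentum-blended running statistics and saved_inv_std = 1 / sqrt(batch_var + epsilon)); the function and struct names are illustrative only and are not part of the patches above.

// Minimal reference sketch (not part of the patches) of the per-channel math
// BatchNormInternal is expected to produce in training mode. Assumes standard
// ONNX BatchNormalization semantics; all names here are illustrative.
#include <cmath>
#include <cstddef>
#include <vector>

struct BnTrainingOutputs {
  std::vector<float> y;              // Y: normalized output, same shape as X
  std::vector<float> running_mean;   // output 1: momentum-blended running mean
  std::vector<float> running_var;    // output 2: momentum-blended running variance
  std::vector<float> saved_mean;     // output 3: batch mean, cached for backward
  std::vector<float> saved_inv_std;  // output 4: 1 / sqrt(batch_var + epsilon)
};

// x is laid out N x C, with any spatial dimensions folded into N for brevity.
BnTrainingOutputs BatchNormInternalReference(const std::vector<float>& x,
                                             size_t n, size_t c,
                                             const std::vector<float>& scale,
                                             const std::vector<float>& bias,
                                             const std::vector<float>& input_mean,
                                             const std::vector<float>& input_var,
                                             float epsilon, float momentum) {
  BnTrainingOutputs out{std::vector<float>(n * c), std::vector<float>(c),
                        std::vector<float>(c), std::vector<float>(c),
                        std::vector<float>(c)};
  for (size_t ch = 0; ch < c; ++ch) {
    // Batch statistics over the N axis for this channel.
    float mean = 0.0f;
    for (size_t i = 0; i < n; ++i) mean += x[i * c + ch];
    mean /= static_cast<float>(n);
    float var = 0.0f;
    for (size_t i = 0; i < n; ++i) {
      const float d = x[i * c + ch] - mean;
      var += d * d;
    }
    var /= static_cast<float>(n);

    out.saved_mean[ch] = mean;
    out.saved_inv_std[ch] = 1.0f / std::sqrt(var + epsilon);

    // ONNX momentum convention: new = momentum * old + (1 - momentum) * batch.
    out.running_mean[ch] = momentum * input_mean[ch] + (1.0f - momentum) * mean;
    out.running_var[ch] = momentum * input_var[ch] + (1.0f - momentum) * var;

    for (size_t i = 0; i < n; ++i) {
      out.y[i * c + ch] =
          scale[ch] * (x[i * c + ch] - mean) * out.saved_inv_std[ch] + bias[ch];
    }
  }
  return out;
}

With saved_mean and saved_inv_std cached, the backward pass can recompute the normalized activations as (x - saved_mean) * saved_inv_std without re-reducing over the batch, and the running statistics never need gradients, which is why inputs 3 and 4 of BatchNormInternal are listed in the stop-gradient table in patch 4.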