
Fix batch norm training op on CPU #6946

Merged
SherlockNoMad merged 6 commits into microsoft:master from pranav-prakash:master
May 1, 2021

Conversation

@pranav-prakash
Contributor

@pranav-prakash pranav-prakash commented Mar 9, 2021

Description: Add support for training-forward-mode of BatchNorm on CPU EP, and implement BN Gradient on CPU EP

Fixes the issue described in #6087

Motivation and Context

As mentioned in the above issue, currently batch norm is not implemented for training. This PR makes the following changes:

  • Blacklist the mean/variance tensors from having their gradients calculated. These are calculated directly from the batch during training and are not updated via backprop
  • The ONNX spec for BatchNorm currently doesn't provide an attribute to distinguish between training and inference mode (batch norm behaves differently in the two cases); see BatchNorm in OP_SET version 7 has no mode attribute onnx/onnx#1042. While we can assume that the presence of the optional outputs indicates training, it's technically valid for a model serialized for inference to contain them as well. The CUDA kernel nonetheless uses this as an indicator of training, and so do we.
  • The CPU implementation of the above was cribbed from caffe2 (the existing inference-only implementation also appears to come from there). They do something unusual: instead of outputting saved_variance they output inv_std_dev. Apparently this is for ease of interoperability with cuDNN, but it breaks the ONNX spec. Nonetheless, I've chosen to do the same, because the existing CUDA kernel for BatchNormGrad also relies on it actually being inv_std_dev.
  • Added a CPU implementation for BatchNormGrad. Again, to match the CUDA version's behavior, we too break the spec and assume that saved_var is actually inv_std_dev.

As a side note, I wonder whether the flaky CUDA tests for batch norm can also be resolved by d01006f
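The training-mode forward described in the bullets above can be sketched in plain C++ (an illustrative, stand-alone sketch, not the actual ORT kernel; the function and variable names here are mine). Note how saved_var is filled with inv_std_dev rather than the variance, matching the spec-breaking caffe2/cuDNN convention being discussed:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative sketch of BatchNorm training-mode forward over NCHW data.
// saved_var receives 1/sqrt(var + eps) (inv_std_dev), not the variance,
// mirroring the behavior described above rather than the ONNX spec.
void BatchNormTrainSketch(const std::vector<float>& x, int N, int C, int HW,
                          float momentum, float eps,
                          std::vector<float>& running_mean,
                          std::vector<float>& running_var,
                          std::vector<float>& saved_mean,
                          std::vector<float>& saved_var,  // actually inv_std_dev
                          std::vector<float>& y) {
  const int count = N * HW;
  for (int c = 0; c < C; ++c) {
    // Per-channel batch mean and (biased) variance.
    double sum = 0.0, sq_sum = 0.0;
    for (int n = 0; n < N; ++n)
      for (int i = 0; i < HW; ++i) {
        float v = x[(n * C + c) * HW + i];
        sum += v;
        sq_sum += double(v) * v;
      }
    float mean = float(sum / count);
    float var = float(sq_sum / count - double(mean) * mean);
    float inv_std = 1.0f / std::sqrt(var + eps);

    saved_mean[c] = mean;
    saved_var[c] = inv_std;  // spec-breaking: inv_std_dev, not saved_variance
    running_mean[c] = momentum * running_mean[c] + (1.f - momentum) * mean;
    running_var[c] = momentum * running_var[c] + (1.f - momentum) * var;

    for (int n = 0; n < N; ++n)
      for (int i = 0; i < HW; ++i) {
        int idx = (n * C + c) * HW + i;
        y[idx] = (x[idx] - mean) * inv_std;  // scale/bias omitted for brevity
      }
  }
}
```

The scale/bias application and epsilon handling are simplified; the point is only the statistics bookkeeping.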

@pranav-prakash pranav-prakash requested a review from a team as a code owner March 9, 2021 00:17
@neginraoof
Contributor

Hey @pranav-prakash
Can you also review this PR to fix onnx spec? onnx/onnx#3333

@pranav-prakash
Contributor Author

pranav-prakash commented Mar 23, 2021

@neginraoof I'm not a msft employee, so I don't have any more power than you do with regard to code review. Or did you want me to just look it over and join the discussion? If so, I've left my comments.

@snnn snnn added the training issues related to ONNX Runtime training; typically submitted using template label Mar 25, 2021
@pranav-prakash
Contributor Author

pranav-prakash commented Apr 2, 2021

Given that the training mode spec for BN is only fully fleshed out for opset 14, do we still need to support the case of training for non-spatial BN (which could happen with opset < 14)? Removing this codepath would simplify things greatly (we could still maintain correctness by guarding training on opset 14 with ORT_ENFORCE).

@SherlockNoMad
Contributor

Given that the training mode spec for BN is only fully fleshed out for opset 14, do we still need to support the case of training for non-spatial BN (which could happen with opset < 14)? Removing this codepath would simplify things greatly (we could still maintain correctness by guarding training on opset 14 with ORT_ENFORCE).

Yes. Since BN training mode was only officially added in opset 14, let's drop the burden of supporting non-spatial mode for training. :)

@SherlockNoMad
Contributor

Hi @pranav-prakash, thanks a lot for your contribution.
Are you planning to update the PR to deprecate support for non-spatial mode in training? I can help with reviewing and merging the PR.

@SherlockNoMad
Contributor

Hi @pranav-prakash, I am also wondering what your use scenario for ORT training is. Which company/product are you working on?

@pranav-prakash
Contributor Author

pranav-prakash commented Apr 5, 2021

@SherlockNoMad

update the PR to deprecate the support for non-spatial mode in training

Yes, I started the PR for this, but then saw that the spec was also updated to remove the saved_mean and saved_var outputs. With those two removed, it seems you would have to recompute the mean/var in the gradient op. I had asked about the motivation for this change on the associated PR, and it seems the intent was to allow "backends [to] transform the graph/node into a custom-op of their choice (for backward-propagation)."

In terms of ORT, though, does this mean we would define our own variant of the BatchNorm schema that includes outputs for saved_mean and saved_inv_std and use a graph-transform pass to rewrite nodes to it? If not, is there another way to avoid recomputing batch_mean/batch_var in the backward pass?

I am also wondering what's your use scenario of ORT training? Which company/product are you working on?

I'm associated with UC Berkeley's architecture research group, and we're working on an ORT EP for our RISC-V ML accelerator (Gemmini). In terms of use cases, at the moment we're primarily interested in training convolutional neural networks (in both bfloat16 and fp32), e.g. ResNet or MobileNet.

In terms of the exposed ORT training APIs we use: because running Python would be too much overhead (we're targeting edge devices), we call directly into the underlying C++ classes (TrainingRunner, training_session) rather than the Python bindings. (Although this path doesn't seem to be as fully documented or supported as the inference APIs.)

@SherlockNoMad
Contributor

Thanks a lot for the detailed introduction! We are really happy to see external contributions to ORT Training!

As for the BatchNorm problem, the plan is to write a custom op (say, BatchNormInternal) that outputs the saved_mean and saved_inv_std. We will substitute BatchNorm with BatchNormInternal before building the training graph. The BatchNormGrad op can then still assume that O(3) and O(4) are present, to speed up the computation.

The rationale behind the ONNX spec update is that saved_inv_std is an internal implementation detail (other frameworks may use saved_inv_var instead), so it's better to leave it out of the spec.

@SherlockNoMad
Contributor

Just curious, how far have you gotten with training convolutional neural networks, e.g. ResNet or MobileNet? Are they working yet? We are also exploring federated learning on mobile/edge devices. It would be nice to collaborate if our plans align well.

@pranav-prakash
Contributor Author

pranav-prakash commented Apr 6, 2021

@SherlockNoMad

the plan is to write a custom op (say, BatchNormInternal) that outputs the saved_mean and save_inv_std

I see – so for this PR, would you like me to create the schema for such a BatchNormTrainingInternal and move the training-mode calculations there? Or did you still want an implementation for training_mode = true in the BatchNorm kernel just for completeness (albeit one that will never be used if the op gets replaced with BatchNormTrainingInternal before training)?

how far have your reached for training convolutional neural networks

We just recently got a trainer for ResNet-50 working, although we haven't yet verified end-to-end correctness, since FPGA simulation of Gemmini is much too slow for training ResNet from scratch (we could likely compare results from a few dozen iterations against the CPU EP, though).

We are also exploring federated learning on mobile/edge devices. It would be nice to collaborate if our plans align well.

Yeah, federated learning and training at the edge are exactly the scenarios we envisioned Gemmini being a good fit for. We'd love to discuss our roadmap and see whether there's potential for collaboration here; feel free to email us at {pranavprakash,hngenc}@<university>.edu.

@SherlockNoMad
Contributor

BatchNormTrainingInternal and BatchNorm can share the same kernel. Most of the code should be the same, except for the handling of outputs 3 and 4. In the kernel, we can check whether ctx->Output(3, shape) and ctx->Output(4, shape) return nullptr.

If you have bandwidth, you can add the BatchNormalizationTraining schema to training_op_defs.cc.

Also, a sample of the replacement code can be found in concat_replacement.cc, which replaces Concat with ConcatTraining.
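The kernel-sharing idea can be illustrated in plain C++ (a sketch only, not the ORT API: nullable pointers stand in for ctx->Output(3, shape) / ctx->Output(4, shape) returning nullptr in the real kernel, and all names here are mine):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Sketch: one function serves both the inference and training variants.
// The optional training outputs are nullable; they are computed only when
// the caller (the training variant) actually wires them up.
void SharedBatchNormKernelSketch(const std::vector<float>& x, float eps,
                                 std::vector<float>* saved_mean,       // may be null
                                 std::vector<float>* saved_inv_std) {  // may be null
  double sum = 0.0, sq_sum = 0.0;
  for (float v : x) { sum += v; sq_sum += double(v) * v; }
  float mean = float(sum / x.size());
  float var = float(sq_sum / x.size() - double(mean) * mean);
  // Inference-only callers pass null, so no extra work happens for them.
  if (saved_mean) saved_mean->assign(1, mean);
  if (saved_inv_std) saved_inv_std->assign(1, 1.0f / std::sqrt(var + eps));
}
```

In the real kernel the same branch would guard the computation of outputs 3 and 4, so one implementation covers both ops.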

@SherlockNoMad
Contributor

@pranav-prakash Great to know that you got ResNet-50 working!
Actually, I am a bit surprised that you didn't run into more missing gradients.

AFAIK, gradients are missing for the LRN, Sum, and GlobalMaxPool ops, and GlobalAveragePool's gradient builder may also need a fix.

@pranav-prakash
Contributor Author

pranav-prakash commented Apr 6, 2021

@SherlockNoMad

BatchNormTrainingInternal and BatchNorm can share the same kernel

Sounds good, I'll update the PR accordingly. Not sure I'll have the bandwidth to add the BatchNormalizationTraining op as well, but that should be an easy follow-up PR.

bit surprised that you didn't find too many missing gradients

We converted Sum into Add before training; IIRC both the ResNet-50 structure from the model zoo and the exported PyTorch model use MaxPool instead of GlobalMaxPool. The PyTorch model does use GlobalAveragePool, but we didn't run into any issues with its gradient builder.

@pranav-prakash pranav-prakash force-pushed the master branch 3 times, most recently from 2f78c94 to a139854 Compare April 6, 2021 23:46
@pranav-prakash
Contributor Author

@SherlockNoMad
I've updated the BN kernel to support the opset-14 case. I don't think the onnx submodule has been bumped to the latest commit yet, though, so that will need to happen before this can be merged.

Since the batch norm grad cannot be implemented until the schema and graph transform for the internal BatchNormalizationTraining are added, I've reverted that part for now. For reference, those files can be found at
https://github.com/microsoft/onnxruntime/blob/8892ee4b6d343109699ab292e66c2c7a5e41925a/orttraining/orttraining/training_ops/cpu/nn/batch_norm_grad.h
https://github.com/microsoft/onnxruntime/blob/8892ee4b6d343109699ab292e66c2c7a5e41925a/orttraining/orttraining/training_ops/cpu/nn/batch_norm_grad.cc
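For reference, the spatial BatchNorm gradient can be sketched per channel as follows (an illustrative C++ sketch under the inv_std_dev convention discussed earlier, not the code in those files; names are mine and scale/bias handling is simplified):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Sketch of spatial BatchNormGrad for one channel with m elements.
// Inputs include saved_mean and saved inv_std_dev (NOT saved_variance),
// so nothing has to be recomputed from the batch in the backward pass.
void BatchNormGradSketch(const std::vector<float>& dy,
                         const std::vector<float>& x,
                         float scale, float mean, float inv_std,
                         std::vector<float>& dx,
                         float& dscale, float& dbias) {
  const int m = int(x.size());
  dscale = 0.f;
  dbias = 0.f;
  for (int i = 0; i < m; ++i) {
    float x_hat = (x[i] - mean) * inv_std;  // normalized input
    dscale += dy[i] * x_hat;                // dL/dscale = sum(dy * x_hat)
    dbias += dy[i];                         // dL/dbias  = sum(dy)
  }
  dx.resize(m);
  for (int i = 0; i < m; ++i) {
    float x_hat = (x[i] - mean) * inv_std;
    // Standard BN backward; every term reuses the saved inv_std_dev.
    dx[i] = scale * inv_std * (dy[i] - dbias / m - x_hat * dscale / m);
  }
}
```

This is why saving inv_std_dev (or at least the batch statistics) matters: without outputs 3 and 4, the backward op would have to recompute mean and variance from x.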

@SherlockNoMad
Contributor

#7177 is in progress.

@SherlockNoMad
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU x64 NoContribops CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,MacOS CI Pipeline,MacOS NoContribops CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline

@SherlockNoMad
Contributor

/azp run orttraining-linux-ci-pipeline,orttraining-mac-ci-pipeline,orttraining-linux-gpu-ci-pipeline,centos7_cpu,Linux CPU Minimal Build E2E CI Pipeline,Linux Nuphar CI Pipeline,MacOS NoContribops CI Pipeline,Linux OpenVINO CI Pipeline,orttraining-distributed

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@pranav-prakash
Contributor Author

@SherlockNoMad
Now that the onnx submodule has been bumped to opset 14, I've updated this PR to enable the opset-14 BN. Also fixed the previous CI failure.

@SherlockNoMad
Contributor

see line 44 in batch_norm_op_test.cc

  std::unordered_set<std::string> excluded_eps = {kTensorrtExecutionProvider};
  if (spatial_mode == 0) {
    excluded_eps.insert(kOpenVINOExecutionProvider);
  }

I think it's correct to add them to excluded_eps.

SherlockNoMad
SherlockNoMad previously approved these changes Apr 30, 2021
Contributor

@SherlockNoMad SherlockNoMad left a comment


Thanks a lot for this PR.
I think it's ready to be merged, after the TRT and OpenVINO fixes.

@pranav-prakash
Contributor Author

@SherlockNoMad fixed.

One other question I had was about the old CUDA test ForwardTrainingTestWithSavedOutputsOpset9. The old test had

  test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
  test.AddOutput<float>("saved_mean", channel_dims, {-0.306f, 0.115f});

That is, it had running_mean equal to saved_mean, when according to the formula it should have been momentum*mean + saved_mean*(1-momentum). Was the old test incorrect, or did I miss something obvious?
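The update in question can be checked numerically (a tiny sketch of the interpolation formula quoted above; the helper name is mine, and `running_mean` plays the role of `mean` in that formula):

```cpp
#include <cassert>
#include <cmath>

// running_mean update as quoted above: momentum*mean + saved_mean*(1-momentum),
// where mean is the incoming running mean and saved_mean the current batch mean.
float UpdateRunningMean(float running_mean, float saved_mean, float momentum) {
  return momentum * running_mean + (1.f - momentum) * saved_mean;
}
```

With the test's saved_mean of -0.306 and a fresh running mean of 0, any momentum other than 0 gives a running_mean different from saved_mean, so the two outputs should not have been equal.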

@SherlockNoMad
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU x64 NoContribops CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,MacOS CI Pipeline,MacOS NoContribops CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline

@SherlockNoMad
Contributor

/azp run orttraining-linux-ci-pipeline,orttraining-mac-ci-pipeline,orttraining-linux-gpu-ci-pipeline,centos7_cpu,Linux CPU Minimal Build E2E CI Pipeline,Linux Nuphar CI Pipeline,MacOS NoContribops CI Pipeline,Linux OpenVINO CI Pipeline,orttraining-distributed

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@SherlockNoMad
Contributor

I think the old test case was incorrect.

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

SherlockNoMad
SherlockNoMad previously approved these changes Apr 30, 2021
@SherlockNoMad
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU x64 NoContribops CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,MacOS CI Pipeline,MacOS NoContribops CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline

@SherlockNoMad
Contributor

/azp run orttraining-linux-ci-pipeline,orttraining-mac-ci-pipeline,orttraining-linux-gpu-ci-pipeline,centos7_cpu,Linux CPU Minimal Build E2E CI Pipeline,Linux Nuphar CI Pipeline,MacOS NoContribops CI Pipeline,Linux OpenVINO CI Pipeline,orttraining-distributed

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@SherlockNoMad
Contributor

/azp run orttraining-amd-gpu-ci-pipeline, orttraining-ortmodule, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

@SherlockNoMad
Contributor

/azp run orttraining-amd-gpu-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@SherlockNoMad SherlockNoMad merged commit 8ba6ed9 into microsoft:master May 1, 2021
mindest added a commit that referenced this pull request Apr 8, 2023
**Description**: Register an implementation for BatchNormInternal and
add a CPU kernel for BatchNormGradient. This is the third in a series of
PRs to implement BN training on CPU (first was #6946, second was #7539).

**Motivation and Context**
Support training networks with BatchNorm (e.g. convnets). Also note that
there exists a CUDA kernel for BN (forward training & backwards) but
it's currently disabled due to flaky failures; someone more familiar
with those parts can register the implementation for BNInternal on CUDA
(gradient kernel doesn't have to change).

---------

Co-authored-by: Simon Zirui Guo <simonguozirui@berkeley.edu>
Co-authored-by: mindest <linminuser@gmail.com>
Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>