Conversation
remove bn replacement in grad graph builder
Resolved review threads:
orttraining/orttraining/training_ops/cuda/nn/batch_norm_grad.cc
orttraining/orttraining/training_ops/cuda/nn/batch_norm_internal.cc
std::vector<float> running_mean = {-0.1754f, 0.303106f};
std::vector<float> running_var = {0.7812f, 1.5865f};
std::vector<float> saved_mean = {-0.306f, 0.114562f};
std::vector<float> saved_inv_std = {1.2288f, 0.861317f};
Let's rename this to saved_inv_var to reflect the reality.
And add a comment that this test data will only work for CUDA, not CPU.
If cudnn is actually returning saved_inv_std, would this UT also work for the CPU impl?
It should, but I also infer from the results of the calculation that:
- when calculating saved_inv_std and y, cudnn uses the biased std/var
- when calculating running_var, it uses the unbiased std/var

As for the CPU implementation, it always uses the biased one. Not sure why cudnn has such an inconsistency itself, or which is more reasonable.
This difference makes the running_var output differ between CPU and CUDA given the same input data.
ic... thank you for the detailed investigation.
Could you please also help document this subtle difference in the kernel and the UT comments?
I think the CPU impl is the correct one: in the ONNX spec, we explicitly mention that the variance should be the population variance, aka the biased variance.
When the number of samples is large, the difference between the biased and unbiased variance is small. Let's note this down and move on.
Hi @mindest, thanks a lot for the PR. Please ask Vincent for sign-off if I am not online.
Description: Implement BatchNormInternal for CUDA
Motivation and Context