FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm#1274
FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm#1274crcrpar merged 11 commits intoNVIDIA:masterfrom
Conversation
Some benchmark data on A100: [benchmark table not preserved]
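Since the table itself did not survive, here is a minimal timing harness of the kind that could produce such numbers; the shapes, dtype, and iteration counts below are illustrative assumptions, not the setup actually used in this PR:

```python
import torch
from apex.normalization import FusedRMSNorm

# Illustrative configuration -- not the exact setup behind the A100 numbers.
batch, seq, hidden = 32, 512, 1024
norm = FusedRMSNorm(normalized_shape=hidden).cuda().half()
x = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.half)

# Warm-up iterations exclude one-time launch/initialization cost.
for _ in range(10):
    norm(x)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    norm(x)
end.record()
torch.cuda.synchronize()
print(f"FusedRMSNorm forward: {start.elapsed_time(end) / 100:.3f} ms/iter")
```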
A bunch of suggestions to remove commented-out lines. You can batch the suggestions into one if you like; see https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request#applying-suggested-changes.
What do you think about splitting apex/normalization/fused_layer_norm.py into fused_layer_norm.py and fused_rms_norm.py?
```python
class FusedRMSNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        #def forward(ctx, input, weight, bias, normalized_shape, eps):
```

Suggested change: remove the commented-out line `#def forward(ctx, input, weight, bias, normalized_shape, eps):`.

```python
        #)
        output, invvar = fused_layer_norm_cuda.rms_forward_affine(
            input_, ctx.normalized_shape, weight_, ctx.eps)
        #ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
```

Suggested change: remove the commented-out line `#ctx.save_for_backward(input_, weight_, bias_, mean, invvar)`.
```cpp
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    // at::Tensor* beta,
```

Suggested change: remove the commented-out line `// at::Tensor* beta,`.

```cpp
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma)
    // at::Tensor* grad_beta
```

Suggested change: remove the commented-out line `// at::Tensor* grad_beta`.
```cpp
    using accscalar_t = at::acc_type<scalar_t_in, true>;
    HostRMSNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        // mean->DATA_PTR<accscalar_t>(),
```

Suggested change: remove the commented-out line `// mean->DATA_PTR<accscalar_t>(),`.

```cpp
        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
        // if gamma Tensor is NULL on input.
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        // gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
```

Suggested change: remove the commented-out line `// gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,`.

```cpp
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL);
        // gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
```

Suggested change: remove the commented-out line `// gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);`.
```python
native = apex.normalization.FusedRMSNorm(
    normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
)
fused = apex.normalization.FusedRMSNorm(
    normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
).cuda()
return native, fused
```
this testing won't do much good as it's comparing to itself :)
Since there isn't a torch.nn.RMSNorm, perhaps write one out in plain Python?
It's a bit opaque here, but the "native" version is computed on CPU, which dispatches to a manual plain-Python version sourced from T5LayerNorm:
apex/apex/normalization/fused_layer_norm.py, line 410 at 028ef04
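For context, that plain-Python fallback follows the T5LayerNorm formulation. A minimal sketch of it (reconstructed from that formulation; the class name is hypothetical and this is not a verbatim quote of the code at that line):

```python
import torch


class ManualRMSNorm(torch.nn.Module):
    """T5LayerNorm-style RMSNorm: no mean subtraction and no bias term."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # RMS statistic: mean of squares over the normalized dimension.
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x
```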
Thank you for explaining this nuance, @eqy. I can see it now.
@crcrpar this is now refactored to use the existing FusedLayerNorm implementation via an added …
crcrpar left a comment:
I think this is the last iteration
Oof, nice catch, thanks.
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
```
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
```
this looks like a copy-n-paste error, as this version has no bias and no mean subtraction in the math formula.
I think the note below needs updating as well w.r.t. bias.
oh! I missed that one - thank you!
looks much better now - with only one small issue - commented in the other PR.
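For reference, with the mean subtraction and bias dropped, and epsilon placed inside the root (as the follow-up doc fix NVIDIA#1295 documents), the corrected RMSNorm formula would read along these lines (a sketch of the standard RMSNorm definition, not a quote of the final docstring):

```latex
y = \frac{x}{\sqrt{\mathrm{E}[x^2] + \epsilon}} * \gamma
```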
Squashed commits:
* FusedRMSNorm based on FusedLayerNorm
* refactor duplicated kernels
* delete comments
* delete comments
* cleanup
* cleanup
* cleanup, fixed clobbering forward_affine_mixed_dtypes
* fix pybind naming and add MixedFused test
* undo skipping
* check elementwise_affine
* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
  (Oof, nice catch, thanks. Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>)

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
* FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (NVIDIA#1274) * FusedRMSNorm based on FusedLayerNorm * refactor duplicated kernels * delete comments * delete comments * cleanup * cleanup * cleanup, fixed clobbering forward_affine_mixed_dtypes * fix pybind naming and add MixedFused test * undo skipping * check elementwise_affine * Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py Oof, nice catch, thanks Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com> Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com> * fix and generate docs for FusedRMSNorm (NVIDIA#1285) * [FusedRMSNorm doc] document where epsilon is added (NVIDIA#1295) * [FusedRMSNorm doc] add epsilon to formula * correct * better wording * Fix some bugs * Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs * Fix NaN issues in FusedRMSNorm * Update test_fused_layer_norm.py * Skip test_fused_layer_norm.TestAutocastFusedRMSNorm on ROCm * Use at::cuda::warp_size() instead of at::cuda::getCurrentDeviceProperties()->warpSize Co-authored-by: eqy <eddiey@nvidia.com> Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
#1271
Pattern-matched implementation of FusedRMSNorm based on FusedLayerNorm. Tests are passing (needed a threshold adjustment for float16); awaiting benchmark results and cleanup.
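For completeness, a minimal usage sketch of the module this PR adds (normalized_shape and elementwise_affine appear in the test code quoted above; the eps value shown is an assumed default):

```python
import torch
from apex.normalization import FusedRMSNorm

# Normalize over the last (hidden) dimension; eps=1e-5 is an assumed default.
rms_norm = FusedRMSNorm(normalized_shape=1024, eps=1e-5, elementwise_affine=True).cuda()

x = torch.randn(8, 128, 1024, device="cuda")
y = rms_norm(x)  # output has the same shape as x, scaled by the learnable gamma
```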