Fix DAct input ordering of gradient input and activation input#1460

Closed
jberchtold-nvidia wants to merge 2 commits into NVIDIA:release_v2.0 from jberchtold-nvidia:dev/jberchtold/fix-dact-with-zero-grad-input
jberchtold-nvidia (Collaborator) commented Feb 6, 2025

Description

Fixes an issue in nvte_quantize_dbias_ and cast_fp8_2D where the gradient input and the activation input were swapped in their usages.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fix the internal functions used by nvte_quantize_dbias_ so that it works correctly when called with the gradient input followed by the activation input
  • Update dact_fn in activation_template.h to undo the swapping of parameters
  • Update cast_fp8_2D_kernel to use the input and act_input tensors correctly: input is the gradient and act_input is the original activation input from the forward pass
  • Add unit tests for DAct with FP8 output to exercise these kernels. These include a test where the input gradient is all zeros, a case that failed without this change because the output gradient was not all zeros
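
The zero-gradient test in the last item can be illustrated with a small standalone sketch. This is not Transformer Engine code: GELU is chosen here only as an example activation, the names dgelu, dact and dact_swapped are hypothetical, and FP8 quantization and dbias are left out. It assumes the backward pass computes grad * act'(act_input), and shows that exchanging the two operands gives a nonzero result even when the incoming gradient is all zeros.

```python
import math


def dgelu(x: float) -> float:
    """Derivative of exact (erf-based) GELU, chosen here only as an example."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def dact(grad, act_input):
    """Intended ordering: scale the incoming gradient by act'(act_input)."""
    return [g * dgelu(x) for g, x in zip(grad, act_input)]


def dact_swapped(grad, act_input):
    """Hypothetical buggy ordering: the two operands trade roles."""
    return [x * dgelu(g) for g, x in zip(grad, act_input)]


if __name__ == "__main__":
    act_input = [-1.0, 0.5, 2.0]  # activation input saved from the forward pass
    grad = [0.0, 0.0, 0.0]        # all-zero incoming gradient
    print(dact(grad, act_input))          # every element is zero
    print(dact_swapped(grad, act_input))  # nonzero: x * dgelu(0) = 0.5 * x
```

With a zero gradient the correct ordering must return zeros regardless of the activation input, which is why an all-zero gradient makes a cheap regression check for this kind of operand mix-up.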

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia changed the base branch from main to release_v2.0 February 6, 2025 03:12

jberchtold-nvidia (Collaborator, Author) commented

/te-ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
jberchtold-nvidia force-pushed the dev/jberchtold/fix-dact-with-zero-grad-input branch from 652e986 to 42f3560 February 6, 2025 17:37

jberchtold-nvidia (Collaborator, Author) commented

/te-ci

ptrendx (Member) commented Feb 7, 2025

Fixed in #1462, which was just merged. Closing this PR.

ptrendx closed this Feb 7, 2025