
Improve and fix bugs about fused softmax layer#133

Closed
hyunwoongko wants to merge 5 commits into NVIDIA:main from hyunwoongko:main

Conversation

hyunwoongko (Contributor) commented Aug 12, 2021

  1. Fix bugs with ELEMENTS_PER_LDG_STG (reported in "Error in fused softmax kernel result" #132)
  2. Add test code for all fused CUDA kernels, using Hugging Face Transformers
  3. Add the constraint 0 <= length_key <= 2048 (originally in the header file as a TORCH_INTERNAL_ASSERT)
  4. Add a constraint on batch_per_block (originally in the header file as a TORCH_INTERNAL_ASSERT)
  5. Refactor the Python fused scale mask softmax layer code
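Taken together, points 3–5 amount to an availability check on the Python side. A minimal sketch of such a check (the function and argument names here are illustrative, not the PR's actual code; it folds in the corrected key-length and divisibility constraints discussed later in this thread):

```python
def fused_softmax_available(input_in_float16, mask, sq, sk, attn_batches):
    """Illustrative availability check for the fused scale-mask-softmax
    kernel, combining the constraints discussed in this PR."""
    return (
        input_in_float16           # kernel supports fp16 input only
        and mask is not None       # a mask tensor must be provided
        and 16 < sk <= 2048        # key sequence length constraint
        and sq % 4 == 0            # query sequence length: multiple of 4
        and attn_batches % 4 == 0  # batch_per_block constraint
    )
```

When any condition fails, the layer is expected to fall back to the unfused PyTorch softmax path rather than raise.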

hyunwoongko (Contributor Author):

[Screenshot: 2021-08-13, 6:45:18 AM]

Everything works well.
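The added tests compare the fused kernels' output against an unfused reference. A pure-Python sketch of such a reference scale-mask-softmax for a single attention row (illustrative only; the PR's actual tests use PyTorch and Hugging Face Transformers):

```python
import math

def scale_mask_softmax_reference(scores, mask, scale):
    """Unfused reference: scale the scores, set masked positions to
    -inf, then take a numerically stable softmax.
    scores: list of floats; mask: list of bools (True = masked out)."""
    scaled = [-math.inf if m else s * scale for s, m in zip(scores, mask)]
    peak = max(scaled)  # subtract the max before exp for stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Masked positions come out exactly zero and the remaining probabilities sum to one, which is the behavior a fused kernel's output can be checked against.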

int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
Contributor:

I think we should retain these asserts in case someone wants to use the CUDA code directly.

Contributor Author:

OK, I see. I agree with you.

Comment thread on megatron/model/fused_softmax.py (outdated):
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sq <= 2048 # sq must be 16 ~ 2048
Contributor:

It should be 16 < sk <= 2048 and sq % 4 == 0.

Contributor Author:

custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

Yes, you are right. It was my mistake; I will change these things.

hyunwoongko (Contributor Author):

@kvareddy I fixed the code. :)

jaredcasper added a commit that referenced this pull request Sep 1, 2021
Fused softmax checks and additions from Github (#133)

See merge request ADLR/megatron-lm!312
jaredcasper (Contributor):

These changes should all be merged in now. Thanks again for the PR!

@jaredcasper jaredcasper closed this Sep 1, 2021
itlamp pushed a commit to itlamp/Megatron-LM-comms that referenced this pull request Apr 7, 2025
* [SW-212054] W/A for dtype mismatch with CAG

* Narrow wa for CAG enabled only

* add TODO remove comment

* Reduced impact of workaround to mixtral failing case only

* Revert not needed changes

* Remove empty spaces

* Remove one more empty space

* More generic local usage of get_args inside LinearWithGradAccumulationAndAsyncCommunication's backward

* Fix local import of get_args

* Remove usage of get_args in megatron/core

* Remove not needed empty line

* Style fixes

* Reorder import in layers.py

* Add/modify headers regarding 2025 year

* Reorder copyright headers, link todo with jira ticket
andresnowak pushed a commit to andresnowak/SwissAi-Megatron-LM that referenced this pull request Oct 28, 2025
* [feat] add yi & llama sparse upcycling model

* add run script

* tokenizer may be empty for megatron

3 participants