Skip to content

Feature/add bottom causal mask#2480

Merged
hwu36 merged 12 commits intoNVIDIA:mainfrom
Aya-ZIbra:feature/add-bottom-causal-mask
Sep 18, 2025
Merged

Feature/add bottom causal mask#2480
hwu36 merged 12 commits intoNVIDIA:mainfrom
Aya-ZIbra:feature/add-bottom-causal-mask

Conversation

@Aya-ZIbra
Copy link
Copy Markdown
Contributor

This commit adds support for bottom-right causal masking in the FMHA implementation by parameterizing the CausalMask struct with a boolean template parameter kIsQBegin.

Changes

  • Modified CausalMask to add template for BottomCausalMask, when kIsQBegin is false.
  • Added logic to calculate the appropriate offset based on the kIsQBegin parameter.
  • Updated formula for calculation of get_masked_trip_count.

Use Case

This enhancement is particularly needed for partial prefill inference scenarios. By setting kIsQBegin=false, users can now efficiently implement inference pipelines that use bottom-right causal masking.

Aya-ZIbra and others added 4 commits July 18, 2025 02:15
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@TejashShah
Copy link
Copy Markdown

Thanks @Aya-ZIbra for contributing into CUTLASS repo.

@v0i0 @hwu36 @thakkarV for code review.

@v0i0
Copy link
Copy Markdown
Contributor

v0i0 commented Jul 21, 2025

lgtm; no permission to formally approve

@richardmcai
Copy link
Copy Markdown
Contributor

This appears to be functionally the same as what's on master now. Does this still need to be merged @Aya-ZIbra ?

@Aya-ZIbra
Copy link
Copy Markdown
Contributor Author

Aya-ZIbra commented Jul 26, 2025

@richardmcai I see now that the masked trip_count calculation is updated on master. We still need to merge this PR because:

  1. The current calculation of masked_trip_count fails for some of my local tests
    For example: batch_size: 2, causal: True, is_mqa: False, cu_seqlens_q: [ 0, 1013, 2021], cu_seqlens_k: [ 0, 1024, 2035]
  2. The code is more concise and avoids repetition.
  3. Calculation in this PR avoids over-estimation of masked trip_count by calculating the q_tile shape as function of blk_coord.

Comment thread examples/77_blackwell_fmha/collective/fmha_fusion.hpp Outdated
@richardmcai
Copy link
Copy Markdown
Contributor

The current calculation of masked_trip_count fails for some of my local tests
For example: batch_size: 2, causal: True, is_mqa: False, cu_seqlens_q: [ 0, 1013, 2021], cu_seqlens_k: [ 0, 1024, 2035]

can you add this to the CMake testlist?

fixed flipped logic for isQBegin
Comment on lines +221 to +222
int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape));
int offset_tile_q = IsQBegin ? 0: int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dianzhangchen does this seem fine to you?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current calculation of masked_trip_count fails for some of my local tests

Hi @Aya-ZIbra , thanks for contributing to the CUTLASS repo and for pointing out this bug.

Based on your tip, I found what’s causing the bug. However, the current PR causes a performance drop when kIsQBegin is false. I’d suggest using the logic below instead—it only changes a few lines of code and won’t affect performance. What do you think?

if constexpr (IsQBegin) {
    return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
    const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
    return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I have applied the changes. local tests are now passing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dianzhangchen @richardmcai : Let me know if I can get this stamped, please.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Aya-ZIbra, it LGTM.

@Aya-ZIbra
Copy link
Copy Markdown
Contributor Author

Merged in deepseek-ai/FlashMLA#85

@Aya-ZIbra Aya-ZIbra closed this Aug 25, 2025
@Aya-ZIbra Aya-ZIbra reopened this Sep 5, 2025
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the || doing here? These are int values not bools?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was proposed by dianzhangchen for performance, I guess.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevertheless it's pretty confusing

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use naive logic here: if problem_size0 (or problem_size1) isn’t a multiple of tile_size0 (or tile_size1), we add one extra mask tile. In some corner cases, this may add a redundant mask tile (when problem_size0 % tile_size0 = problem_size1 % tile_size1 != 0).

const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
Copy link
Copy Markdown

@ngimel ngimel Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look correct, suppose tile_shape0 = 2, tile_shape1 = 4, and get<1>(problem_size) % get<1>(tile_shape) = 2 and get<0>(problem_size) % get<0>(tile_shape) == 0 (so previous calculation should be correct, because in the 0th dimension problem is evenly divisible. Previously we would have ceil_div(2+2, 4) = 1, now we are getting corner_count=1, ceil_div(2, 4) + corner_count = 1 + 1 = 2, which differs from previous behavior, and it shouldn't given that in dim 0 proclem_size is divisible by tile_shape

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you looked carefully, the older calculation ( for problem_size divisible by tile_shape) does over-estimate the masked trips a bit ( +1) . But that is not an issue as long as we do min(trip_count, masked_trips).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm talking about a case where new calculation produces a value that is larger than the old calculation did. For problem size divisible by tile shape both calculations produce the same value.

The current expression is confusing
@Aya-ZIbra
Copy link
Copy Markdown
Contributor Author

Aya-ZIbra commented Sep 6, 2025

@ngimel @richardmcai
To clarify my proposal, here is my original implementation which is simpler to explain.

// my original implementation (1)
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) -(get<0>(problem_size) % get<1>(tile_shape));
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));

The number of k tiles that the q segment spans can be expressed as:

number_of_splits=ceil_div(q_tile+x,k_tile)
where x is a value that accounts for the misalignment caused by the offset.

The point is that current implementation doesn't take in account that misalignment can come from qlen%k_tile != 0 .

// current
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape))

The offset is calculated as the difference between the total lengths of the K and q dimensions:

offset = klen - qlen = (n * k_tile + k_rem) - (m * k_tile + q_rem)

Simplifying:
offset = (n - m) * k_tile + (k_rem - q_rem)
The term qrem = (get<0>(problem_size) % get<1>(tile_shape)); is important for correctness.

Alternatively, we can account for this as follows, if it makes a difference for performance:

const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<1>(tile_shape))) ; 
      return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);

Finally, here is a picture for the edge case I would like to fix. Currently, the masked_trips would wrongly evaluate to 2. Let me know your feedback please.
Scratch drawing

} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this dividing 0 problem size by 1 tile shape? Are there any implicitinvariants, like get<1>(tile_shape) is a multiple of get<0>(tile_shape)? If so, they should be explicitly mentioned in the comments otherwise this is very confusing.

Copy link
Copy Markdown
Contributor Author

@Aya-ZIbra Aya-ZIbra Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngimel I don't fully get your question. For Qbotom = True, the offset is calculated as (<1>problem_size <0>problem_size) so that is why <0>problem_size divisiblity by <1> tile shape needs to be considered. Our tests are failing without this fix.
image

@Aya-ZIbra Aya-ZIbra requested a review from ngimel September 8, 2025 18:53
Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend
@hwu36 hwu36 merged commit 6457918 into NVIDIA:main Sep 18, 2025
hwu36 added a commit that referenced this pull request Sep 23, 2025
* Rebase to latest

* update

* upd

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

fixed flipped logic for isQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Albresky pushed a commit to Albresky/cutlass that referenced this pull request Oct 11, 2025
* Rebase to latest

* update

* upd

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

fixed flipped logic for isQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
guocuimi pushed a commit to vectorch-ai/cutlass that referenced this pull request Nov 6, 2025
* Rebase to latest

* update

* upd

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

fixed flipped logic for isQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants