Feature/add bottom causal mask#2480
Thanks @Aya-ZIbra for contributing to the CUTLASS repo.
lgtm; no permission to formally approve
This appears to be functionally the same as what's on master now. Does this still need to be merged, @Aya-ZIbra?
@richardmcai I see now that the masked trip_count calculation is updated on master. We still need to merge this PR because:
Can you add this to the CMake testlist?
fixed flipped logic for isQBegin
int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape));
int offset_tile_q = IsQBegin ? 0: int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape));
The current calculation of masked_trip_count fails for some of my local tests
Hi @Aya-ZIbra, thanks for contributing to the CUTLASS repo and for pointing out this bug.
Based on your tip, I found what's causing the bug. However, the current PR causes a performance drop when kIsQBegin is false. I'd suggest using the logic below instead: it only changes a few lines of code and won't affect performance. What do you think?
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
}
Thank you! I have applied the changes. Local tests are now passing.
@dianzhangchen @richardmcai: Let me know if I can get this stamped, please.
Merged in deepseek-ai/FlashMLA#85
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
What is the || doing here? These are int values not bools?
It was proposed by dianzhangchen, for performance I guess.
We use naive logic here: if problem_size0 (or problem_size1) isn’t a multiple of tile_size0 (or tile_size1), we add one extra mask tile. In some corner cases, this may add a redundant mask tile (when problem_size0 % tile_size0 = problem_size1 % tile_size1 != 0).
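On the `||` question above: in C++ each integer operand of `||` is converted to bool (nonzero becomes true), the result is a bool, and `int(bool)` is 0 or 1. A minimal sketch of that behavior follows. The function name and plain-int parameters are hypothetical stand-ins for `get<0>/get<1>(problem_size)` and `get<0>/get<1>(tile_shape)`, and the sizes are made up for illustration.

```cpp
// Hypothetical stand-ins: problem_q/problem_k play the role of
// get<0>/get<1>(problem_size); tile_q/tile_k of get<0>/get<1>(tile_shape).
constexpr int corner_count(int problem_q, int problem_k, int tile_q, int tile_k) {
    // Each remainder is converted to bool (nonzero -> true), `||` yields a
    // bool, and int(bool) is 0 or 1: "is either dimension ragged?"
    return int((problem_k % tile_k) || (problem_q % tile_q));
}

// Illustrative sizes only (not taken from the PR).
static_assert(corner_count(256, 1024, 128, 128) == 0, "both dimensions tile-aligned");
static_assert(corner_count(256, 1000, 128, 128) == 1, "only K is ragged");
static_assert(corner_count(200, 1024, 128, 128) == 1, "only Q is ragged");
// Equal non-zero remainders (200 % 128 == 968 % 128 == 72) still give 1,
// which is the case the comment above describes as possibly redundant.
static_assert(corner_count(200, 968, 128, 128) == 1, "equal remainders");
```

The result never exceeds 1, so the expression adds at most one extra mask tile regardless of how large the remainders are.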
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
This doesn't look correct. Suppose tile_shape0 = 2, tile_shape1 = 4, get<1>(problem_size) % get<1>(tile_shape) = 2 and get<0>(problem_size) % get<0>(tile_shape) == 0 (so the previous calculation should be correct, because in the 0th dimension the problem is evenly divisible). Previously we would have ceil_div(2+2, 4) = 1; now we are getting corner_count = 1, so ceil_div(2, 4) + corner_count = 1 + 1 = 2, which differs from the previous behavior, and it shouldn't, given that in dim 0 problem_size is divisible by tile_shape.
If you look carefully, the older calculation (for problem_size divisible by tile_shape) does over-estimate the masked trips a bit (+1). But that is not an issue as long as we do min(trip_count, masked_trips).
I'm talking about a case where new calculation produces a value that is larger than the old calculation did. For problem size divisible by tile shape both calculations produce the same value.
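The numbers in this exchange can be checked with a small standalone sketch. It is not the CUTLASS code: `ceil_div`, `min_int`, and both `masked_trips_*` functions are local stand-ins written with plain ints, and the problem sizes (4, 6, 8) and the trip count of 8 are assumptions chosen only to produce the remainders quoted above.

```cpp
// Local stand-ins, not the CuTe/CUTLASS helpers.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
constexpr int min_int(int a, int b) { return a < b ? a : b; }

// Earlier expression from the diff: offset taken from the K remainder only.
constexpr int masked_trips_old(int trip_count, int problem_k, int tile_q, int tile_k) {
    return min_int(trip_count, ceil_div(tile_q + problem_k % tile_k, tile_k));
}

// Suggested expression: one extra tile if either dimension is ragged.
constexpr int masked_trips_new(int trip_count, int problem_q, int problem_k,
                               int tile_q, int tile_k) {
    return min_int(trip_count,
                   ceil_div(tile_q, tile_k) +
                       int((problem_k % tile_k) || (problem_q % tile_q)));
}

// tile_shape0 = 2, tile_shape1 = 4, K remainder 2 (6 % 4), Q remainder 0 (4 % 2).
static_assert(masked_trips_old(8, 6, 2, 4) == 1, "ceil_div(2 + 2, 4) = 1");
static_assert(masked_trips_new(8, 4, 6, 2, 4) == 2, "ceil_div(2, 4) + 1 = 2");
// Both dimensions divisible (Q = 4, K = 8): the two expressions agree.
static_assert(masked_trips_old(8, 8, 2, 4) == masked_trips_new(8, 4, 8, 2, 4), "");
```

With these inputs the suggested expression returns 2 where the earlier one returned 1, and the two agree once both problem sizes are multiples of their tile extents.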
The current expression is confusing
@ngimel @richardmcai
// my original implementation (1)
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape));
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
The number of k tiles that the q segment spans can be expressed as:
The point is that the current implementation doesn't take into account that misalignment can come from qlen % k_tile != 0.
// current
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape))
The offset is calculated as the difference between the total lengths of the K and q dimensions:
Simplifying:
Alternatively, we can account for this as follows, if it makes a difference for performance:
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<1>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
Finally, here is a picture for the edge case I would like to fix. Currently, the masked_trips would wrongly evaluate to 2. Let me know your feedback please.
} else {
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape));
Why is this dividing the dim-0 problem size by the dim-1 tile shape? Are there any implicit invariants, e.g. that get<1>(tile_shape) is a multiple of get<0>(tile_shape)? If so, they should be stated explicitly in the comments; otherwise this is very confusing.
@ngimel I don't fully get your question. For Qbottom = True, the offset is calculated as (get<1>(problem_size) - get<0>(problem_size)), which is why the divisibility of get<0>(problem_size) by get<1>(tile_shape) needs to be considered. Our tests fail without this fix.
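To put numbers on this, here is a sketch that evaluates the offset expressions for the sizes used in the repro command in this PR (q=1013, k=1024). The K tile extent of 128 is an assumption for illustration only; the thread does not state the kernel's actual tile shape, and the constant names are hypothetical.

```cpp
constexpr int kQ = 1013;     // --q from the repro command
constexpr int kK = 1024;     // --k from the repro command
constexpr int kTileK = 128;  // ASSUMPTION: illustrative K tile extent

// Expression that looks only at the K remainder.
constexpr int offset_k_only = kK % kTileK;
// Expression that also subtracts the Q remainder taken against the K tile.
constexpr int offset_with_q = (kK % kTileK) - (kQ % kTileK);
// Shift of the bottom-right diagonal: difference of the two sequence lengths.
constexpr int diagonal_shift = kK - kQ;

static_assert(offset_k_only == 0, "1024 is a multiple of 128");
static_assert(offset_with_q == -117, "0 - (1013 % 128) = -117");
static_assert(diagonal_shift == 11, "1024 - 1013");
// -117 and 11 are the same residue modulo the K tile extent.
static_assert((offset_with_q + kTileK) % kTileK == diagonal_shift % kTileK, "");
```

Under this assumption the K-only expression evaluates to 0 even though the diagonal is shifted by 11 columns, which is the misalignment from `get<0>(problem_size)` that the reply refers to.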

Reproduce error/fix with: ./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend
* Rebase to latest
* update
* upd
  Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* Update fmha_fusion.hpp
* Update fmha_fusion.hpp
  fixed flipped logic for isQBegin
* Update fmha_fusion.hpp
* Avoid use of booleans
  The current expression is confusing
* fmt
* Update fmha_fusion.hpp
  Reproduce error/fix with: ./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend
* add test, format
---------
Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>

This commit adds support for bottom-right causal masking in the FMHA implementation by parameterizing the CausalMask struct with a boolean template parameter kIsQBegin.
Changes
* CausalMask to add template for BottomCausalMask, when kIsQBegin is false.
* kIsQBegin parameter.
* get_masked_trip_count.
Use Case
This enhancement is particularly needed for partial prefill inference scenarios. By setting kIsQBegin=false, users can now efficiently implement inference pipelines that use bottom-right causal masking.
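As an illustration of the two alignments, the sketch below spells out which key positions a query row may attend to. It assumes the usual convention that bottom-right alignment shifts the diagonal by the difference between the K and Q lengths. `is_visible` and its parameters are hypothetical names for this example and not part of the CausalMask interface.

```cpp
// Hypothetical helper, not the CUTLASS CausalMask API.
// is_q_begin == true  : diagonal starts at the top-left corner.
// is_q_begin == false : diagonal ends at the bottom-right corner.
constexpr bool is_visible(int q_idx, int k_idx, int seqlen_q, int seqlen_k,
                          bool is_q_begin) {
    return is_q_begin ? (k_idx <= q_idx)
                      : (k_idx <= q_idx + (seqlen_k - seqlen_q));
}

// Partial prefill: 2 new query tokens against 5 keys (3 already cached).
static_assert(is_visible(0, 3, 2, 5, false), "row 0 sees keys 0..3");
static_assert(!is_visible(0, 4, 2, 5, false), "row 0 does not see key 4");
static_assert(is_visible(1, 4, 2, 5, false), "last row sees every key");
// Top-left alignment on the same shapes hides the cached keys from row 0.
static_assert(!is_visible(0, 1, 2, 5, true), "row 0 sees only key 0");
```

In the partial prefill case the new tokens need to see all cached keys, which bottom-right alignment allows and top-left alignment does not.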