From 722781aef34a94f701641d6c86d18d356bbd29fe Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Fri, 18 Jul 2025 01:34:00 -0700 Subject: [PATCH 01/10] Rebase to latest --- .../collective/fmha_fusion.hpp | 52 ++++++------------- 1 file changed, 17 insertions(+), 35 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 000c0a0a5f..e7e57beaac 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -203,16 +203,10 @@ struct CausalMask : NoMask { // See note below on different ways to think about causal attention // Again, we'd add the offset_q into the max_blocks_q calculation - if constexpr (IsQBegin) { - int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); - int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); - return std::min(max_blocks_k, max_blocks_q); - } else { - const int offset_q = get<1>(problem_size) - get<0>(problem_size); - int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); - int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); - return std::min(max_blocks_k, max_blocks_q); - } + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; + int max_blocks_q = ceil_div(offset_q + (get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); } template @@ -221,14 +215,13 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - - if constexpr (IsQBegin) { - int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); - return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); - } else { - const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); - return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)); - } + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); + int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; + int first_masked_tile_k = int(offset_q / get<1>(tile_shape)); + int last_masked_tile_k = int((offset_q + q_tile - 1) / get<1>(tile_shape)); + int masked_blocks = last_masked_tile_k - first_masked_tile_k + 1; + return std::min(masked_blocks, trip_count); } template @@ -255,23 +248,12 @@ struct CausalMask : NoMask { // - this is usually what we want for inference settings // where we only compute the next row and use cache for the rest // - if you'd like this, you only need to set kIsQBegin=false - - if constexpr (IsQBegin) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(acc_qk); i++) { - auto pos = index_qk(i); - if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { - acc_qk(i) = -INFINITY; - } - } - } else { - const auto offset_q = get<1>(problem_size) - get<0>(problem_size); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(acc_qk); i++) { - auto pos = index_qk(i); - if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { - acc_qk(i) = -INFINITY; - } + int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) ||(get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; } } } From 08556c9bc6ac84426e087e487b4a8906db9d77e0 Mon Sep 17 00:00:00 2001 From: ayaoibrahim1123 Date: Fri, 18 Jul 2025 12:10:45 -0700 Subject: [PATCH 02/10] update --- .../77_blackwell_fmha/collective/fmha_fusion.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index e7e57beaac..cc479ef2f3 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -215,13 +215,14 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); - int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); - int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; - int first_masked_tile_k = int(offset_q / get<1>(tile_shape)); - int last_masked_tile_k = int((offset_q + q_tile - 1) / get<1>(tile_shape)); - int masked_blocks = last_masked_tile_k - first_masked_tile_k + 1; - return std::min(masked_blocks, trip_count); + + if constexpr (IsQBegin) { + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); + return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)); + } } template From 81c475f5bac56de949f4abc09b23e6157394c5c5 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Fri, 18 Jul 2025 14:33:16 -0700 Subject: [PATCH 03/10] upd Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .../77_blackwell_fmha/collective/fmha_fusion.hpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index cc479ef2f3..e3c907b082 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -215,14 +215,12 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - - if constexpr (IsQBegin) { - int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); - return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); - } else { - const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); - return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)); - } + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); + int offset_tile_q = IsQBegin ? int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape)) : 0; + + int masked_blocks = ceil_div(q_tile + offset_tile_q, get<1>tile_shape); + return std::min(masked_blocks, trip_count); } template From 3119f3c436865cba3d12514a62646a0e2cdf54ee Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Fri, 18 Jul 2025 15:25:50 -0700 Subject: [PATCH 04/10] Update fmha_fusion.hpp --- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index e3c907b082..f19b57e395 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -219,7 +219,7 @@ struct CausalMask : NoMask { int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); int offset_tile_q = IsQBegin ? int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape)) : 0; - int masked_blocks = ceil_div(q_tile + offset_tile_q, get<1>tile_shape); + int masked_blocks = ceil_div(q_tile + offset_tile_q, get<1>(tile_shape)); return std::min(masked_blocks, trip_count); } From 61e6a2299d45cfc776e3c7b801744b8e8d987d41 Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Mon, 28 Jul 2025 12:36:22 -0700 Subject: [PATCH 05/10] Update fmha_fusion.hpp fixed flipped logic for isQBegin --- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 39ae750381..bdf3c1f2ed 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -204,7 +204,7 @@ struct CausalMask : NoMask { // See note below on different ways to think about causal attention // Again, we'd add the offset_q into the max_blocks_q calculation int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); - int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; + int offset_q = IsQBegin ? 0: int(get<1>(problem_size)) - int(get<0>(problem_size)); int max_blocks_q = ceil_div(offset_q + (get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); return std::min(max_blocks_k, max_blocks_q); @@ -219,7 +219,7 @@ struct CausalMask : NoMask { int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); - int offset_tile_q = IsQBegin ? int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape)) : 0; + int offset_tile_q = IsQBegin ? 0: int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape)); int masked_blocks = ceil_div(q_tile + offset_tile_q, get<1>(tile_shape)); return std::min(masked_blocks, trip_count); @@ -250,7 +250,7 @@ struct CausalMask : NoMask { // - this is usually what we want for inference settings // where we only compute the next row and use cache for the rest // - if you'd like this, you only need to set kIsQBegin=false - int offset_q = IsQBegin ? int(get<1>(problem_size)) - int(get<0>(problem_size)) : 0; + int offset_q = IsQBegin ? 0: int(get<1>(problem_size)) - int(get<0>(problem_size)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(acc_qk); i++) { auto pos = index_qk(i); From 97bf993f6d11d697ffafd5e6ef695b1141f2474d Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:46:03 -0700 Subject: [PATCH 06/10] Update fmha_fusion.hpp --- .../collective/fmha_fusion.hpp | 50 ++++++++++++------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index bdf3c1f2ed..7bc7935c21 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -204,10 +204,14 @@ struct CausalMask : NoMask { // See note below on different ways to think about causal attention // Again, we'd add the offset_q into the max_blocks_q calculation int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); - int offset_q = IsQBegin ? 0: int(get<1>(problem_size)) - int(get<0>(problem_size)); - int max_blocks_q = ceil_div(offset_q + (get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); - return std::min(max_blocks_k, max_blocks_q); - + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } } template @@ -216,14 +220,14 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); - int q_tile = min(get<0>(tile_shape), get<0>(problem_size) - get<0>(blk_coord) * get<0>(tile_shape)); - int offset_tile_q = IsQBegin ? 0: int((get<1>(problem_size)) - int(get<0>(problem_size))% get<1>(tile_shape)); - - int masked_blocks = ceil_div(q_tile + offset_tile_q, get<1>(tile_shape)); - return std::min(masked_blocks, trip_count); - + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; + return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); + } } template @@ -250,17 +254,27 @@ struct CausalMask : NoMask { // - this is usually what we want for inference settings // where we only compute the next row and use cache for the rest // - if you'd like this, you only need to set kIsQBegin=false - int offset_q = IsQBegin ? 0: int(get<1>(problem_size)) - int(get<0>(problem_size)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(acc_qk); i++) { - auto pos = index_qk(i); - if ((get<0>(pos) + offset_q < get<1>(pos)) ||(get<1>(pos) >= get<1>(problem_size))) { - acc_qk(i) = -INFINITY; + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } } } } }; - template struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { From 5576a6c04b3abef5c6c255b0b2034b98b2b28049 Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:10:03 -0700 Subject: [PATCH 07/10] Avoid use of booleans The current expression is confusing --- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 7bc7935c21..6b0d78c1fe 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,8 +225,8 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; - return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); + int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape)); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); } } From bf593cba659ea46c27ddf7e969187735d6ceac3e Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:12:27 -0700 Subject: [PATCH 08/10] fmt --- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 6b0d78c1fe..a6e7f3f104 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,7 +225,7 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape)); + const int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape)); return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); } } From ff829fd8a39ba33a5cdc2d1f5787576985c45667 Mon Sep 17 00:00:00 2001 From: "Aya Z. Ibrahim" <56703597+Aya-ZIbra@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:19:28 -0700 Subject: [PATCH 09/10] Update fmha_fusion.hpp Reproduce error/fix with: ./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend --- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index a6e7f3f104..409dda5b93 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,7 +225,7 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int offset_tile_q = (get<1>(problem_size) % get<1>(tile_shape)) - (get<0>(problem_size) % get<1>(tile_shape)); + const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size) ) % get<1>(tile_shape); return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); } } From bba78db6871c2dc5d70e71fe0bee6084b0f9eebd Mon Sep 17 00:00:00 2001 From: Richard Cai Date: Wed, 10 Sep 2025 15:41:08 -0700 Subject: [PATCH 10/10] add test, format --- examples/77_blackwell_fmha/CMakeLists.txt | 8 +++++--- examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index edaf76f9e0..b7b0b21d8d 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -39,7 +39,8 @@ set_property( ) set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) -set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) +set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) +set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend) set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen) set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify) set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify) @@ -115,7 +116,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha.cu TEST_COMMAND_OPTIONS TEST_BASIC - TEST_CAUSAL + TEST_CAUSAL_00 + TEST_CAUSAL_01 TEST_VARLEN TEST_HDIM64 TEST_GQA @@ -216,7 +218,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla_fwd.cu TEST_COMMAND_OPTIONS TEST_BASIC - TEST_CAUSAL + TEST_CAUSAL_00 TEST_VARLEN TEST_HDIM64 TEST_GQA diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 409dda5b93..a33ce2d2ce 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,7 +225,7 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size) ) % get<1>(tile_shape); + const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape); return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); } } @@ -275,6 +275,7 @@ struct CausalMask : NoMask { } } }; + template struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward {