Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/77_blackwell_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ set_property(
)

set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
-set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
+set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
+set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
Expand Down Expand Up @@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
-TEST_CAUSAL
+TEST_CAUSAL_00
+TEST_CAUSAL_01
TEST_VARLEN
TEST_HDIM64
TEST_GQA
Expand Down Expand Up @@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_fwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
-TEST_CAUSAL
+TEST_CAUSAL_00
TEST_VARLEN
TEST_HDIM64
TEST_GQA
Expand Down
4 changes: 2 additions & 2 deletions examples/77_blackwell_fmha/collective/fmha_fusion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
-const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
-return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
+const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
+return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
}
}

Expand Down