diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index ac69081d30..65034d3d8e 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -39,7 +39,8 @@ set_property( ) set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) -set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) +set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) +set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend) set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen) set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify) set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify) @@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha.cu TEST_COMMAND_OPTIONS TEST_BASIC - TEST_CAUSAL + TEST_CAUSAL_00 + TEST_CAUSAL_01 TEST_VARLEN TEST_HDIM64 TEST_GQA @@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla_fwd.cu TEST_COMMAND_OPTIONS TEST_BASIC - TEST_CAUSAL + TEST_CAUSAL_00 TEST_VARLEN TEST_HDIM64 TEST_GQA diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index dbdff428b7..a33ce2d2ce 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -225,8 +225,8 @@ struct CausalMask : NoMask { if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; - return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); + const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); } }