[Perf] Add async GMEM->LDS infra; switch KV tile load to inverse-swizzle addressing #212
diptorupd wants to merge 3 commits into ROCm:amd-integration from
Conversation
Pull request overview
This PR extends the ROCm/HIP backend infrastructure for upcoming async GMEM→LDS pipelining and refactors the generic prefill path to use inverse-swizzle addressing, while also improving the shared-memory swizzle arithmetic organization.
Changes:
- Add HIP-only async GMEM→LDS descriptor + load primitives (V#/SRD, SGPR pinning helper, tile load helpers) and expose selected wrappers in the public memory interface.
- Rewrite generic prefill K/V production and Q global→smem loading to use inverse-swizzle addressing and a two-phase (global load, then LDS store) structure.
- Refactor swizzle arithmetic into a `SwizzleLayout<mode>` policy and update `smem_t` to use composition; also adjust `__umulhi` emulation in `fastdiv.cuh`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| include/gpu_iface/memory_ops.hpp | Exposes HIP-only async GMEM→LDS wrapper APIs (make_srsrc, async_load_64b_to_lds). |
| include/gpu_iface/fastdiv.cuh | Updates high-multiply emulation to use unsigned uint64_t math on non-CUDA paths. |
| include/gpu_iface/backend/hip/memory_ops_hip.h | Adds HIP async GMEM→LDS intrinsics, SRD builder, SGPR pinning, and documents wait/commit semantics. |
| include/flashinfer/attention/generic/prefill.cuh | Switches Q/K/V tile loads to inverse-swizzle addressing and updates smem_t usage to the new layout-policy form. |
| include/flashinfer/attention/generic/permuted_smem.cuh | Introduces SwizzleLayout policy, adds inverse mapping helper, and refactors smem_t to take a layout policy type. |
```cpp
for (uint32_t i = 0; i < NITERS; ++i) {
  const uint32_t row = row_base + i * ROW_STEP;  // compile-time offset per unrolled step
  const uint32_t col = col_sw ^ SmemTy::Layout::template col_swizzle_xor<UPCAST_STRIDE>(row);
  if ((kv_idx_base + row) < kv_len) {
    loaded[i] = gptr_u2[(kv_idx_base + row) * stride_n_u2 + col];
  } else if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
    loaded[i] = SmemCell{0u, 0u};
  }
```
produce_kv can write uninitialized data to LDS when fill_mode == kNoFill and (kv_idx_base + row) >= kv_len: loaded[i] is left uninitialized in Phase A but always stored in Phase B. This is undefined behavior and can lead to incorrect results or miscompilations even if later masked. Consider zero-initializing loaded[] (e.g., default-initialize it) or predicating the Phase-B store on the same bounds check so no uninitialized value is written.
```diff
   const uint32_t qo_packed_idx_base =
       (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<KTraits>(tid.y)) * NUM_MMA_Q * 16;
-  smem_t<SWIZZLE_MODE_KV, typename KTraits::SmemBasePtrTy> qo_smem(smem_storage.q_smem);
+  smem_t<SwizzleLayout<SWIZZLE_MODE_KV>, typename KTraits::SmemBasePtrTy> qo_smem(
```
qo_smem is a Q shared-memory tile, but it’s parameterized with SWIZZLE_MODE_KV. Today SWIZZLE_MODE_Q and SWIZZLE_MODE_KV happen to be equal in KernelTraits, but using the Q constant here would make the intent explicit and prevent accidental divergence if the two swizzle modes are later decoupled.
Suggested change:
```diff
-  smem_t<SwizzleLayout<SWIZZLE_MODE_KV>, typename KTraits::SmemBasePtrTy> qo_smem(
+  smem_t<SwizzleLayout<SWIZZLE_MODE_Q>, typename KTraits::SmemBasePtrTy> qo_smem(
```
```cpp
 * Issues two buffer_load_dword lds instructions (vmcnt += 2).
 * The wavefront does not stall; the caller must later call
 * wait_group<N>() + __syncthreads() before reading lds_dst.
 *
 * @param lds_dst LDS destination (2 consecutive uint32 slots)
 * @param rsrc Buffer resource from make_srsrc()
 * @param global_byte_offset Per-thread byte offset from rsrc base
 */
__device__ __forceinline__ void async_load_64b_to_lds(mem_detail::lds_ptr_t lds_dst,
                                                      mem_detail::srsrc_t rsrc,
                                                      uint32_t global_byte_offset) {
  mem_detail::async_load_64b_to_lds(lds_dst, rsrc, global_byte_offset);
}
```
memory::async_load_64b_to_lds wraps mem_detail::async_load_64b_to_lds, but the HIP backend marks that overload as deprecated because a per-lane LDS pointer is semantically incorrect for buffer_load_dword_lds (lds_ptr must be uniform). Exposing this wrapper risks propagating a known-bad API and will also trigger deprecation warnings; prefer exposing the uniform-base API (e.g., async_load_dword_to_lds/async_load_tile64_to_lds) or enforce/pin uniformity in the wrapper signature/documentation.
Suggested change:
```diff
- * Issues two buffer_load_dword lds instructions (vmcnt += 2).
- * The wavefront does not stall; the caller must later call
- * wait_group<N>() + __syncthreads() before reading lds_dst.
- *
- * @param lds_dst LDS destination (2 consecutive uint32 slots)
- * @param rsrc Buffer resource from make_srsrc()
- * @param global_byte_offset Per-thread byte offset from rsrc base
- */
-__device__ __forceinline__ void async_load_64b_to_lds(mem_detail::lds_ptr_t lds_dst,
-                                                      mem_detail::srsrc_t rsrc,
-                                                      uint32_t global_byte_offset) {
-  mem_detail::async_load_64b_to_lds(lds_dst, rsrc, global_byte_offset);
-}
+ * This per-lane LDS-pointer variant is intentionally not exposed here because
+ * the underlying HIP backend overload is deprecated: buffer_load_dword_lds
+ * requires a uniform LDS base pointer. Callers should instead use the
+ * uniform-base async load primitives provided in mem_detail (e.g.,
+ * async_load_dword_to_lds / async_load_tile64_to_lds) via safer wrappers.
+ */
```
```cpp
/// Wavefront width for GFX9 / CDNA3.
static constexpr uint32_t kWarpSize = 64u;

/**
 * @brief LLVM intrinsic that emits a buffer_load_dword_lds instruction.
 *
 * @param rsrc Buffer resource descriptor (V#).
 * @param lds_ptr Uniform LDS destination pointer.
 * @param size Element size in bytes (1, 2, or 4 on GFX9).
```
kWarpSize is defined (64u) but not used by any code path (only referenced in a comment), while async_load_tile64_to_lds hard-codes 0x3f/64. Consider using kWarpSize in the implementation (or removing it) to avoid unused-constant warnings and keep the lane math self-documenting.
Suggested change:
```diff
-/// Wavefront width for GFX9 / CDNA3.
-static constexpr uint32_t kWarpSize = 64u;
 /**
  * @brief LLVM intrinsic that emits a buffer_load_dword_lds instruction.
  *
  * @param rsrc Buffer resource descriptor (V#).
  * @param lds_ptr Uniform LDS destination pointer.
  * @param size Element size in bytes (1, 2, or 4 on GFX9).
```
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
```cpp
  // Phase B: drain loads and write to LDS.
#pragma unroll
  for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
#pragma unroll
    for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
      smem.template load_vector_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
      *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j);
      *gptr += 16 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
    }
    kv_idx += NUM_WARPS * 4;
    *smem_offset = smem.template advance_offset_by_row<NUM_WARPS * 4, UPCAST_STRIDE>(*smem_offset) -
                   (sizeof(DTypeKV) * NUM_MMA_D * 2);
    *gptr += NUM_WARPS * 4 * stride_n -
             sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
  for (uint32_t i = 0; i < NITERS; ++i) {
    smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
  }
```
Phase B stores to smem.base unconditionally. When (kv_idx_base + row) >= kv_len and fill_mode == kNoFill, this will still overwrite the LDS tile (with whatever happens to be in registers). If the intended semantics are “no fill”, the store itself needs to be predicated (or the loaded value must be defined) for the out-of-bounds lanes.
```cpp
 */
__device__ __forceinline__ void async_load_dword_to_lds(uint32_t lds_base_uniform, srsrc_t rsrc,
                                                        uint32_t voffset) {
  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, static_cast<int>(voffset),
```
static_cast<int>(voffset) passes an unsigned byte offset through a signed conversion. If voffset exceeds INT_MAX, the unsigned→signed conversion is implementation-defined and can change the bit pattern, which would break addressing. Prefer passing an int32_t created via bit-preserving cast (e.g., __builtin_bit_cast(int32_t, voffset)) or change the intrinsic wrapper signature to take int32_t/uint32_t and only bit-cast at the call boundary so the i32 bit pattern is preserved.
Suggested change:
```diff
-  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, static_cast<int>(voffset),
+  int32_t voffset_i32 = __builtin_bit_cast(int32_t, voffset);
+  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, voffset_i32,
```
```cpp
 * \brief Inverse of get_permuted_offset: recover (row, col) from a physical LDS cell index.
 *
 * XOR is self-inverse ((x ^ mask) ^ mask == x), so the same col_swizzle_xor
 * expression serves both the forward and inverse direction in all supported
 XOR-based swizzle modes.
 *
```
The Doxygen comment here has a formatting error: the “XOR-based swizzle modes.” line is missing the leading *, so it won’t be included correctly in generated docs.
Force-pushed 9a8a408 to accd0b2.
…zle addressing

Adds buffer_load_dword_lds primitives to memory_ops_hip.h: srsrc_t (V# buffer descriptor), make_srsrc(), to_sgpr_u32(), and async_load_dword/tile64_to_lds(). The design follows the HipKittens global_to_shared pattern. wait_group<N>() is present but the s_waitcnt asm is commented out — async pipelining is not active yet and will be a follow-up.

produce_kv and load_q_global_smem are rewritten to use inverse-swizzle addressing: global memory is accessed via the swizzled index so LDS writes use a simple linear offset. Each call is split into a Phase A (all global loads issued) and a Phase B (all LDS stores), giving the hardware room to pipeline the two without a double-buffer. This restructuring is the prerequisite for wiring in the async path.

permuted_smem.cuh extracts all swizzle arithmetic into a standalone SwizzleLayout<mode> policy type. smem_t is refactored to take it as a template parameter (composition over inheritance), with the smem_t::Layout alias serving as an explicit bijection handle between the global and LDS index spaces.

Also fixes __umulhi to use v_mul_hi_u32 on gfx942 in fastdiv.cuh.
Force-pushed accd0b2 to d48d5d4.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
```cpp
    *gptr += NUM_WARPS * 4 * stride_n -
             sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
  for (uint32_t i = 0; i < NITERS; ++i) {
    smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
```
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb <diptorup@cs.unc.edu>
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
```cpp
    *gptr += NUM_WARPS * 4 * stride_n -
             sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
  for (uint32_t i = 0; i < NITERS; ++i) {
    smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
```
Hi @diptorupd, buffer loads do not support JIT kernels. As we move toward more and more JIT kernels for quick development, is there any workaround for async loads in JIT kernels?
Adds `buffer_load_dword_lds` primitives to `memory_ops_hip.h`: `srsrc_t` (V# buffer descriptor), `make_srsrc()`, `to_sgpr_u32()`, and `async_load_dword/tile64_to_lds()`. The design follows the HipKittens `global_to_shared` pattern. `wait_group<N>()` is present but the `s_waitcnt` asm is commented out — async pipelining is not active yet and will be a follow-up.

`produce_kv` and `load_q_global_smem` are rewritten to use inverse-swizzle addressing: global memory is accessed via the swizzled index so LDS writes use a simple linear offset. Each call is split into a Phase A (all global loads issued) and a Phase B (all LDS stores), giving the hardware room to pipeline the two without a double-buffer. This restructuring is the prerequisite for wiring in the async path.

`permuted_smem.cuh` extracts all swizzle arithmetic into a standalone `SwizzleLayout<mode>` policy type. `smem_t` is refactored to take it as a template parameter (composition over inheritance), with the `smem_t::Layout` alias serving as an explicit bijection handle between the global and LDS index spaces.

Also fixes `__umulhi` to use `v_mul_hi_u32` on gfx942 in `fastdiv.cuh`.

Note: The PR was expected to be performance neutral by itself, since it only adds the plumbing for a follow-up that enables the async pipeline. However, it does improve performance by around 5 to 10%.
Testing results:
======================= 29424 passed, 2876 skipped in 445.63s (0:07:25) =========================