
[Perf] Add async GMEM->LDS infra; switch KV tile load to inverse-swizzle addressing#212

Open
diptorupd wants to merge 3 commits intoROCm:amd-integrationfrom
diptorupd:perf/async-pipeline

Conversation

@diptorupd
Collaborator

  • Adds buffer_load_dword_lds primitives to memory_ops_hip.h: srsrc_t (V# buffer descriptor), make_srsrc(), to_sgpr_u32(), and async_load_dword/tile64_to_lds(). The design follows the HipKittens global_to_shared pattern. wait_group<N>() is present but the s_waitcnt asm is commented out — async pipelining is not active yet and will be a follow-up.

  • produce_kv and load_q_global_smem are rewritten to use inverse-swizzle addressing: global memory is accessed via the swizzled index so LDS writes use a simple linear offset. Each call is split into a Phase A (all global loads issued) and Phase B (all LDS stores), giving the hardware room to pipeline the two without a double-buffer. This restructuring is the prerequisite for wiring in the async path.

  • permuted_smem.cuh extracts all swizzle arithmetic into a standalone SwizzleLayout<mode> policy type. smem_t is refactored to take it as a template parameter (composition over inheritance), with the smem_t::Layout alias serving as an explicit bijection handle between the global and LDS index spaces.

  • Also fixes __umulhi to use v_mul_hi_u32 on gfx942 in fastdiv.cuh.
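The inverse-swizzle idea above can be sketched in plain C++. The names (`col_swizzle_xor`, the 3-bit bank mask) are illustrative stand-ins, not the actual flashinfer API: because the swizzle is a pure XOR, the same expression maps logical to physical and back, which is what lets the KV loader read global memory through the swizzled index while LDS writes stay linear.

```cpp
#include <cstdint>

// Illustrative XOR swizzle in the spirit of SwizzleLayout<mode>; the
// mask width and function names are stand-ins, not the real API.
constexpr uint32_t col_swizzle_xor(uint32_t row) { return row & 0x7u; }

// Forward map: logical (row, col) -> physical LDS column.
constexpr uint32_t to_physical(uint32_t row, uint32_t col) {
  return col ^ col_swizzle_xor(row);
}

// Inverse map: XOR is self-inverse, so the identical expression
// recovers the logical column, making the layout a bijection.
constexpr uint32_t to_logical(uint32_t row, uint32_t phys_col) {
  return phys_col ^ col_swizzle_xor(row);
}
```

Composing the two maps for any row/column returns the original column, which is the bijection property the `smem_t::Layout` alias is meant to expose.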

Note: This PR is primarily plumbing for a follow-up that will enable the async pipeline, so the async path itself contributes nothing yet. The inverse-swizzle and two-phase restructuring alone, however, already improves performance by roughly 5 to 10%:

| Sequence Length | Type | amd-integration (TFLOPS) | perf/async-pipeline (TFLOPS) |
| --- | --- | --- | --- |
| 512 | causal | 32.95 | 34.12 |
| 1024 | causal | 43.55 | 45.99 |
| 2048 | causal | 50.12 | 52.17 |
| 4096 | causal | 73.69 | 76.37 |
| 8192 | causal | 84.48 | 86.86 |
| 512 | non-causal | 65.65 | 68.20 |
| 1024 | non-causal | 83.69 | 88.53 |
| 2048 | non-causal | 88.99 | 93.97 |
| 4096 | non-causal | 91.97 | 102.99 |

Testing results:
```
======================= 29424 passed, 2876 skipped in 445.63s (0:07:25) =========================
```

@diptorupd diptorupd changed the title [ROCm] Add async GMEM->LDS infra; switch KV tile load to inverse-swizzle addressing [Perf] Add async GMEM->LDS infra; switch KV tile load to inverse-swizzle addressing Apr 1, 2026
Copilot AI left a comment
Pull request overview

This PR extends the ROCm/HIP backend infrastructure for upcoming async GMEM→LDS pipelining and refactors the generic prefill path to use inverse-swizzle addressing, while also improving the shared-memory swizzle arithmetic organization.

Changes:

  • Add HIP-only async GMEM→LDS descriptor + load primitives (V#/SRD, SGPR pinning helper, tile load helpers) and expose selected wrappers in the public memory interface.
  • Rewrite generic prefill K/V production and Q global→smem loading to use inverse-swizzle addressing and a two-phase (global load, then LDS store) structure.
  • Refactor swizzle arithmetic into a SwizzleLayout<mode> policy and update smem_t to use composition; also adjust __umulhi emulation in fastdiv.cuh.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
include/gpu_iface/memory_ops.hpp Exposes HIP-only async GMEM→LDS wrapper APIs (make_srsrc, async_load_64b_to_lds).
include/gpu_iface/fastdiv.cuh Updates high-multiply emulation to use unsigned uint64_t math on non-CUDA paths.
include/gpu_iface/backend/hip/memory_ops_hip.h Adds HIP async GMEM→LDS intrinsics, SRD builder, SGPR pinning, and documents wait/commit semantics.
include/flashinfer/attention/generic/prefill.cuh Switches Q/K/V tile loads to inverse-swizzle addressing and updates smem_t usage to the new layout-policy form.
include/flashinfer/attention/generic/permuted_smem.cuh Introduces SwizzleLayout policy, adds inverse mapping helper, and refactors smem_t to take a layout policy type.


Comment on lines +349 to +356
```cpp
for (uint32_t i = 0; i < NITERS; ++i) {
  const uint32_t row = row_base + i * ROW_STEP;  // compile-time offset per unrolled step
  const uint32_t col = col_sw ^ SmemTy::Layout::template col_swizzle_xor<UPCAST_STRIDE>(row);
  if ((kv_idx_base + row) < kv_len) {
    loaded[i] = gptr_u2[(kv_idx_base + row) * stride_n_u2 + col];
  } else if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
    loaded[i] = SmemCell{0u, 0u};
  }
```
Copilot AI Apr 1, 2026
produce_kv can write uninitialized data to LDS when fill_mode == kNoFill and (kv_idx_base + row) >= kv_len: loaded[i] is left uninitialized in Phase A but always stored in Phase B. This is undefined behavior and can lead to incorrect results or miscompilations even if later masked. Consider zero-initializing loaded[] (e.g., default-initialize it) or predicating the Phase-B store on the same bounds check so no uninitialized value is written.
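A host-side sketch of the zero-initialization fix the comment suggests; `SmemCell` and `NITERS` are illustrative stand-ins for the kernel's own types, not the actual code:

```cpp
#include <cstdint>

struct SmemCell { uint32_t x, y; };

// Simulate Phase A with a value-initialized staging array so that
// Phase B's unconditional store always writes a defined value, even
// for rows past kv_len under kNoFill semantics.
constexpr bool out_of_bounds_cells_are_zero(uint32_t valid_rows) {
  constexpr uint32_t NITERS = 4;
  SmemCell loaded[NITERS] = {};  // default state is zero, never indeterminate
  for (uint32_t i = 0; i < NITERS; ++i) {
    if (i < valid_rows) loaded[i] = SmemCell{i + 1u, i + 1u};  // in-bounds "global load"
  }
  // Phase B analogue: every element is defined, so storing all is safe.
  for (uint32_t i = valid_rows; i < NITERS; ++i) {
    if (loaded[i].x != 0u || loaded[i].y != 0u) return false;
  }
  return true;
}
```

The alternative fix, predicating the Phase-B store on the same bounds check, avoids the extra register writes but reintroduces divergence in the store loop.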

Comment thread include/flashinfer/attention/generic/prefill.cuh Outdated
```diff
 const uint32_t qo_packed_idx_base =
     (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<KTraits>(tid.y)) * NUM_MMA_Q * 16;
-smem_t<SWIZZLE_MODE_KV, typename KTraits::SmemBasePtrTy> qo_smem(smem_storage.q_smem);
+smem_t<SwizzleLayout<SWIZZLE_MODE_KV>, typename KTraits::SmemBasePtrTy> qo_smem(
```
Copilot AI Apr 1, 2026

qo_smem is a Q shared-memory tile, but it’s parameterized with SWIZZLE_MODE_KV. Today SWIZZLE_MODE_Q and SWIZZLE_MODE_KV happen to be equal in KernelTraits, but using the Q constant here would make the intent explicit and prevent accidental divergence if the two swizzle modes are later decoupled.

Suggested change

```diff
-smem_t<SwizzleLayout<SWIZZLE_MODE_KV>, typename KTraits::SmemBasePtrTy> qo_smem(
+smem_t<SwizzleLayout<SWIZZLE_MODE_Q>, typename KTraits::SmemBasePtrTy> qo_smem(
```

Comment thread include/flashinfer/attention/generic/permuted_smem.cuh Outdated
Comment on lines +144 to +156
```cpp
 * Issues two buffer_load_dword lds instructions (vmcnt += 2).
 * The wavefront does not stall; the caller must later call
 * wait_group<N>() + __syncthreads() before reading lds_dst.
 *
 * @param lds_dst LDS destination (2 consecutive uint32 slots)
 * @param rsrc Buffer resource from make_srsrc()
 * @param global_byte_offset Per-thread byte offset from rsrc base
 */
__device__ __forceinline__ void async_load_64b_to_lds(mem_detail::lds_ptr_t lds_dst,
                                                      mem_detail::srsrc_t rsrc,
                                                      uint32_t global_byte_offset) {
  mem_detail::async_load_64b_to_lds(lds_dst, rsrc, global_byte_offset);
}
```
Copilot AI Apr 1, 2026
memory::async_load_64b_to_lds wraps mem_detail::async_load_64b_to_lds, but the HIP backend marks that overload as deprecated because a per-lane LDS pointer is semantically incorrect for buffer_load_dword_lds (lds_ptr must be uniform). Exposing this wrapper risks propagating a known-bad API and will also trigger deprecation warnings; prefer exposing the uniform-base API (e.g., async_load_dword_to_lds/async_load_tile64_to_lds) or enforce/pin uniformity in the wrapper signature/documentation.

Suggested change

```diff
- * Issues two buffer_load_dword lds instructions (vmcnt += 2).
- * The wavefront does not stall; the caller must later call
- * wait_group<N>() + __syncthreads() before reading lds_dst.
- *
- * @param lds_dst LDS destination (2 consecutive uint32 slots)
- * @param rsrc Buffer resource from make_srsrc()
- * @param global_byte_offset Per-thread byte offset from rsrc base
- */
-__device__ __forceinline__ void async_load_64b_to_lds(mem_detail::lds_ptr_t lds_dst,
-                                                      mem_detail::srsrc_t rsrc,
-                                                      uint32_t global_byte_offset) {
-  mem_detail::async_load_64b_to_lds(lds_dst, rsrc, global_byte_offset);
-}
+ * This per-lane LDS-pointer variant is intentionally not exposed here because
+ * the underlying HIP backend overload is deprecated: buffer_load_dword_lds
+ * requires a uniform LDS base pointer. Callers should instead use the
+ * uniform-base async load primitives provided in mem_detail (e.g.,
+ * async_load_dword_to_lds / async_load_tile64_to_lds) via safer wrappers.
+ */
```

Comment thread include/gpu_iface/backend/hip/memory_ops_hip.h Outdated
Comment on lines +58 to +66
```cpp
/// Wavefront width for GFX9 / CDNA3.
static constexpr uint32_t kWarpSize = 64u;

/**
 * @brief LLVM intrinsic that emits a buffer_load_dword_lds instruction.
 *
 * @param rsrc Buffer resource descriptor (V#).
 * @param lds_ptr Uniform LDS destination pointer.
 * @param size Element size in bytes (1, 2, or 4 on GFX9).
```
Copilot AI Apr 1, 2026
kWarpSize is defined (64u) but not used by any code path (only referenced in a comment), while async_load_tile64_to_lds hard-codes 0x3f/64. Consider using kWarpSize in the implementation (or removing it) to avoid unused-constant warnings and keep the lane math self-documenting.
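A sketch of the "use the constant" alternative: spelling the mask through `kWarpSize` keeps the lane math self-documenting while producing the same bits as the `0x3f` literal (names here are illustrative, not the backend's actual code):

```cpp
#include <cstdint>

// Wavefront width for GFX9 / CDNA3; the hard-coded 0x3f mask in
// async_load_tile64_to_lds is exactly kWarpSize - 1.
constexpr uint32_t kWarpSize = 64u;

constexpr uint32_t lane_id(uint32_t flat_tid) {
  return flat_tid & (kWarpSize - 1u);  // same bits as `flat_tid & 0x3f`
}
```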

Suggested change

```diff
-/// Wavefront width for GFX9 / CDNA3.
-static constexpr uint32_t kWarpSize = 64u;
 /**
  * @brief LLVM intrinsic that emits a buffer_load_dword_lds instruction.
  *
  * @param rsrc Buffer resource descriptor (V#).
  * @param lds_ptr Uniform LDS destination pointer.
  * @param size Element size in bytes (1, 2, or 4 on GFX9).
```
Copilot AI review requested due to automatic review settings April 1, 2026 18:22
Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment on lines +359 to 363
```cpp
// Phase B: drain loads and write to LDS.
#pragma unroll
for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
#pragma unroll
  for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
    smem.template load_vector_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
    *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j);
    *gptr += 16 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
  }
  kv_idx += NUM_WARPS * 4;
  *smem_offset = smem.template advance_offset_by_row<NUM_WARPS * 4, UPCAST_STRIDE>(*smem_offset) -
                 (sizeof(DTypeKV) * NUM_MMA_D * 2);
  *gptr += NUM_WARPS * 4 * stride_n -
           sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
for (uint32_t i = 0; i < NITERS; ++i) {
  smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
}
```
Copilot AI Apr 1, 2026
Phase B stores to smem.base unconditionally. When (kv_idx_base + row) >= kv_len and fill_mode == kNoFill, this will still overwrite the LDS tile (with whatever happens to be in registers). If the intended semantics are “no fill”, the store itself needs to be predicated (or the loaded value must be defined) for the out-of-bounds lanes.

```cpp
 */
__device__ __forceinline__ void async_load_dword_to_lds(uint32_t lds_base_uniform, srsrc_t rsrc,
                                                        uint32_t voffset) {
  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, static_cast<int>(voffset),
```
Copilot AI Apr 1, 2026
static_cast<int>(voffset) passes an unsigned byte offset through a signed conversion. If voffset exceeds INT_MAX, the unsigned→signed conversion is implementation-defined and can change the bit pattern, which would break addressing. Prefer passing an int32_t created via bit-preserving cast (e.g., __builtin_bit_cast(int32_t, voffset)) or change the intrinsic wrapper signature to take int32_t/uint32_t and only bit-cast at the call boundary so the i32 bit pattern is preserved.

Suggested change

```diff
-  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, static_cast<int>(voffset),
+  int32_t voffset_i32 = __builtin_bit_cast(int32_t, voffset);
+  _fi_async_load_to_lds(rsrc, (lds_ptr_t)(uintptr_t)lds_base_uniform, 4, voffset_i32,
```

Comment on lines +97 to +102
```cpp
 * \brief Inverse of get_permuted_offset: recover (row, col) from a physical LDS cell index.
 *
 * XOR is self-inverse ((x ^ mask) ^ mask == x), so the same col_swizzle_xor
 * expression serves both the forward and inverse direction in all supported
 * XOR-based swizzle modes.
 *
```
Copilot AI Apr 1, 2026
The Doxygen comment here has a formatting error: the “XOR-based swizzle modes.” line is missing the leading *, so it won’t be included correctly in generated docs.

@diptorupd diptorupd marked this pull request as draft April 2, 2026 13:33
@diptorupd diptorupd force-pushed the perf/async-pipeline branch 2 times, most recently from 9a8a408 to accd0b2 Compare April 16, 2026 13:51
@diptorupd diptorupd force-pushed the perf/async-pipeline branch from accd0b2 to d48d5d4 Compare April 16, 2026 13:51
@diptorupd diptorupd marked this pull request as ready for review April 16, 2026 14:07
Copilot AI review requested due to automatic review settings April 16, 2026 14:07
Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.



```cpp
*gptr += NUM_WARPS * 4 * stride_n -
         sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
for (uint32_t i = 0; i < NITERS; ++i) {
  smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
```
Comment thread include/gpu_iface/memory_ops.hpp Outdated
diptorupd and others added 2 commits April 18, 2026 18:10
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Diptorup Deb <diptorup@cs.unc.edu>
Copilot AI review requested due to automatic review settings April 18, 2026 19:45
Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.



```cpp
*gptr += NUM_WARPS * 4 * stride_n -
         sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size<DTypeKV, VECTOR_BIT_WIDTH>();
for (uint32_t i = 0; i < NITERS; ++i) {
  smem.base[tid_linear + i * NUM_THREADS] = loaded[i];
```
@yiakwy-xpu-ml-framework-team

Hi @diptorupd, buffer_load does not support JIT kernels. As we move toward more and more JIT kernels for quick development, is there a workaround for async loads in JIT kernels?
