
[AMD] Add MI350 (gfx950 / HIP) support for TileKernels operators #13

Open
zhangnju wants to merge 2 commits into deepseek-ai:main from zhangnju:rocm_gfx950_support

Conversation


@zhangnju zhangnju commented Apr 25, 2026

This PR adds AMD MI350 (gfx950) support to the TileKernels library by fixing several HIP incompatibilities while maintaining compatibility with CUDA.
After this PR and TileLang PR #2099, the operators supported on MI350 are:

engram:    engram_fused_weight, engram_gate_fwd, engram_grad_w_reduce, engram_hash
mhc:       mhc_expand, mhc_head_compute_mix, mhc_pre_norm_fn, mhc_pre_big_fuse,
           mhc_pre_split_mixes, mhc_sinkhorn
moe:       moe_aux_fi, group_count, inplace_unique_group_indices, mask_indices_by_tp,
           normalize_weight, top2_sum_gate, topk_gate, topk_sum_and_topk_group_idx
quant:     cast_back (e4m3 only), cast_back_e5m6 (partial Scale Factor),
           per_block_cast (e4m3 only), per_channel_cast, per_channel_cast_and_transpose,
           per_token_cast (e4m3 only), per_token_cast_to_e5m6 (partial Scale Factor),
           swiglu_fwd_per_channel_cast_transpose
transpose: batched_transpose

Operators still not supported on MI350 are listed in the table below; the remaining gaps are caused by fundamental CUDA-only features.

| Operator(s) | Reason |
| --- | --- |
| get_fused_mapping, expand_to_fused, reduce_fused, per_channel_cast_fused, swiglu_bwd_per_token_cast, swiglu_fwd_per_token_cast | __match_any_sync — CUDA warp match intrinsic, no AMD equivalent |
| mhc_post | PDL (Programmatic Dependent Launch) — NV SM90+ only |
| mhc_pre_apply_mix, mhc_multilayer_recompute | NV SM90+ architecture dependency |
| engram_gate_bwd | T.get_warp_idx() — no AMD equivalent |
| per_block_cast_lossless | FP4 / e2m1 not supported on AMD hardware |

Core Changes

  1. tile_kernels/config.py — Add get_warp_size()
  • Runtime detection of the current target (HIP/CUDA)
  • Returns 64 for the CDNA family (gfx9xx / wave64), 32 otherwise
  • lru_cache-decorated, so there is zero overhead after the first call (a hedged sketch follows this list)
  2. MoE kernels — Parameterize warp shuffle width
  • top2_sum_gate_kernel.py: all T.shfl_xor / T.shfl_sync calls gain width=warp_size; the warp_reduce_sum macro changed from a hardcoded 5 steps to log2(warp_size) steps (see the reduction sketch after this list)
  • topk_sum_and_topk_group_idx_kernel.py: num_threads now uses get_warp_size(); num_tokens_per_block fixed to 1 (one wavefront per token)
  • moe/common.py: get_topk_group_idx accepts a warp_size parameter; fixes token_idx/lane_idx computation and the shfl_sync width
  • get_fused_mapping_kernel.py: warp_size replaced with get_warp_size()
  3. normalize_weight_kernel.py — Two HIP backend bug fixes
  • T.alloc_var(init=float_literal) silently skips initialization on HIP → replaced with T.alloc_local((1,), ...) + explicit assignment
  • T.vectorized produces NaN outputs on HIP → switched to T.unroll when running on HIP
  4. mhc/pre_big_fuse_kernel.py — Shared memory sync fixes
  • Added T.sync_threads() before the thread_idx < 32 branch to prevent shared memory races on wave64
  • Replaced T.Pipelined(num_stages=2) with T.serial + manual T.sync_threads() to avoid AMD LDS double-buffer overflow
  • Inlined the pre_mix computation (previously read from pre_mix_shared) to eliminate alloc_shared sync issues
  5. engram/engram_gate_kernel.py
  • Import get_warp_size for downstream use
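
As an aside to item 1, here is a minimal sketch of what such a warp-size helper could look like. This is not the PR's actual implementation: the HIP detection via torch.version.hip and the gcnArchName check are assumptions used only for illustration.

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def get_warp_size() -> int:
    """Return the warp/wavefront width of the current device (hypothetical sketch)."""
    if torch.version.hip is not None:
        # Assumed heuristic: CDNA GPUs (gfx9xx, e.g. gfx950 / MI350) run in wave64 mode.
        arch = torch.cuda.get_device_properties(0).gcnArchName
        if arch.startswith("gfx9"):
            return 64
    return 32
```

Because the result is cached by lru_cache, the device query happens only on the first call, matching the zero-overhead note above.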
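
The log2(warp_size) step count in item 2 comes from the xor-shuffle ("butterfly") reduction pattern: each step folds in lanes whose index differs by one bit, so 5 steps cover wave32 and 6 steps cover wave64. The snippet below is a plain-Python emulation of that pattern, not TileLang code; it only illustrates why the step count must track the warp size.

```python
def warp_reduce_sum(lane_values: list[float]) -> list[float]:
    """Emulate an xor-shuffle tree reduction over one warp/wavefront."""
    warp_size = len(lane_values)  # 32 on NVIDIA warps, 64 on CDNA wavefronts
    vals = list(lane_values)
    offset = 1
    while offset < warp_size:  # log2(warp_size) iterations
        # Each lane adds the value held by the lane whose index differs in one
        # bit, mirroring T.shfl_xor(value, offset, width=warp_size).
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(warp_size)]
        offset *= 2
    return vals  # every lane now holds the full sum


# A 64-lane wavefront needs 6 steps instead of the 5 hardcoded for wave32.
assert warp_reduce_sum([float(i) for i in range(64)])[0] == sum(range(64))
```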

Test Changes (tests/)

  • All test files: unified IS_HIP detection (a minimal sketch of the pattern follows this list)
  • Skip kernels with no AMD equivalent (__match_any_sync dependents: expand_to_fused, get_fused_mapping, reduce_fused)
  • test_pre_big_fuse.py: On HIP, different thread layout (64 vs 128 threads) causes minor bfloat16 accumulation differences → use assert_close(atol=2e-2)
  • test_engram_gate_fwd.py: HIP FMA contraction differs from CUDA → slightly relaxed output threshold
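
A hedged sketch of the test-side pattern described above; the helper names and any tolerance other than the HIP atol=2e-2 are illustrative placeholders, while torch.version.hip, pytest.mark.skipif, and torch.testing.assert_close are the real APIs being leaned on.

```python
import pytest
import torch

# Unified HIP detection: torch.version.hip is None on CUDA builds of PyTorch.
IS_HIP = torch.version.hip is not None

# Kernels built on __match_any_sync have no AMD equivalent, so their tests are
# skipped entirely when running on HIP.
requires_match_any_sync = pytest.mark.skipif(
    IS_HIP, reason="__match_any_sync has no AMD equivalent"
)


def assert_outputs_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
    # On HIP the 64-vs-128 thread layout changes the bfloat16 accumulation
    # order slightly, so a looser absolute tolerance is used there.
    atol = 2e-2 if IS_HIP else 1e-3  # the CUDA tolerance is a placeholder
    torch.testing.assert_close(actual, expected, atol=atol, rtol=1e-2)
```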
