@benenzhu
Contributor

@benenzhu benenzhu commented Jan 7, 2026

Fix the index_map selection so that the MFMA layout is correct for float32 MFMA with transposed A. The problem was only found for k_dim == 4.
Also re-enable the corresponding CI tests previously disabled by #1443.

Previously generated code for
gemm_sr(128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128) is:

    A_local[0] = ((float *)buf_dyn_shmem)[(
        ki * 64 + 
        (threadIdx.x & 15) * 16 + 
        ((threadIdx.x & 3) >> 1) * 4 +
        ((((threadIdx.x & 15) >> 2) + ki) & 1) * 8 +
        (threadIdx.x >> 4) + 512)];

The correct code should be:

    A_local[0] = ((float *)buf_dyn_shmem)[(
        ki * 64 + 
        (threadIdx.x & 3) +
        ((((threadIdx.x & 7) >> 2) + ((threadIdx.x & 63) >> 5)) & 1) * 4 + 
        ((((threadIdx.x & 15) >> 3) + (ki & 1)) & 1) * 8 + 
        (threadIdx.x >> 4) * 16) + 512];
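
To see how the two index expressions differ, here is a minimal Python sketch that transcribes both of them and evaluates them for every lane. The function names `old_offset` and `new_offset`, the choice of a single 64-lane wavefront, and `ki = 0` are assumptions made for illustration; only the arithmetic is taken from the snippets above.

```python
def old_offset(tx: int, ki: int) -> int:
    """Shared-memory offset from the previously generated code."""
    return (ki * 64
            + (tx & 15) * 16
            + ((tx & 3) >> 1) * 4
            + ((((tx & 15) >> 2) + ki) & 1) * 8
            + (tx >> 4) + 512)


def new_offset(tx: int, ki: int) -> int:
    """Shared-memory offset from the corrected code."""
    return (ki * 64
            + (tx & 3)
            + ((((tx & 7) >> 2) + ((tx & 63) >> 5)) & 1) * 4
            + ((((tx & 15) >> 3) + (ki & 1)) & 1) * 8
            + (tx >> 4) * 16) + 512


# Assumption: one wavefront of 64 lanes, first k-iteration.
LANES = range(64)
old = [old_offset(tx, 0) for tx in LANES]
new = [new_offset(tx, 0) for tx in LANES]

# Does the corrected mapping cover one contiguous 64-element window?
print(sorted(new) == list(range(512, 576)))
# How far does the old mapping reach?
print(min(old), max(old))
# How many lanes read a different element after the fix?
print(sum(o != n for o, n in zip(old, new)))
```

In the corrected expression, `threadIdx.x >> 4` carries the stride of 16 and the 64 lanes together cover one contiguous 64-element window past the 512-element base. In the old expression the stride of 16 is applied to `threadIdx.x & 15` instead, so lanes reach well beyond that window.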

Summary by CodeRabbit

  • Tests

    • Expanded GEMM coverage by enabling additional single-root test cases for both int8 and float inputs, including transposed input scenarios.
  • Bug Fixes

    • Improved handling of transposed matrix layouts in layout-mapping logic to ensure correct mapping for affected matrix shapes.


…A when k_dim=4

Previously, when k_dim=4, the index_map always used non-transposed layout
for matrix A regardless of the transposed flag. This caused precision
issues for transposed GEMM operations on ROCm.

Re-enable the previously skipped test cases for trans_A=True with float dtype.
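
The selection logic described in this commit message can be sketched roughly as follows. This is an illustrative sketch only: `layout_16x4`, `layout_4x16`, and `select_index_map` are hypothetical stand-ins, not the functions in `mfma_macro_generator.py`, and the lane formulas are assumed for a 64-lane wavefront.

```python
# Hypothetical stand-ins for illustration; not TileLang's actual layout functions.
def layout_16x4(i: int, j: int) -> tuple[int, int]:
    """Map element (i, j) of a 16x4 tile to (lane, local_id)."""
    return j * 16 + i, 0


def layout_4x16(i: int, j: int) -> tuple[int, int]:
    """Map element (i, j) of a 4x16 tile (transposed storage) to (lane, local_id)."""
    return i * 16 + j, 0


def select_index_map(k_dim: int, transposed: bool):
    """Pick the layout for matrix A; only the k_dim == 4 branch is sketched."""
    if k_dim != 4:
        raise NotImplementedError("other k_dim values are outside this sketch")
    # The bug corresponds to returning layout_16x4 here regardless of `transposed`.
    return layout_4x16 if transposed else layout_16x4


# Element (m, k) of A must land on the same lane whether A is stored as
# 16x4 (non-transposed) or as 4x16 (transposed, indexed as (k, m)).
plain = select_index_map(4, transposed=False)
trans = select_index_map(4, transposed=True)
consistent = all(plain(m, k) == trans(k, m) for m in range(16) for k in range(4))
print(consistent)
```

With the unconditional choice, a transposed tile would be indexed as if it were 16x4, so lanes would pick up the wrong elements, which is consistent with the precision failures reported for trans_A=True.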
@github-actions

github-actions bot commented Jan 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 7, 2026

📝 Walkthrough

Walkthrough

Enables additional GEMM SR test cases (int8 and float variants) and makes MFMA ldmatrix index-map selection transposed-aware for k_dim == 4, adjusting layout choices based on the transposed state.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Test Coverage Expansion: testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py | Adds/enables four GEMM SR parameterized test cases (two int8 configurations and two float configurations) by expanding the test parameter set; no runtime logic changed. |
| MFMA Macro Generation Logic: tilelang/intrinsics/mfma_macro_generator.py | get_ldmatrix_index_map now selects index_map and reverse_index_map based on the transposed flag when k_dim == 4, using transposed-specific layout mappings (previously used A layouts regardless of transpose). |

Sequence Diagram(s)

(omitted — changes are limited to tests and internal mapping selection without a multi-component sequential flow)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • #1136: Modifies MFMA/layout selection and introduces transposed-aware layout/thread-binding logic in mfma_macro_generator.py.

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through tests, restored the play,

Int8 and float now join the fray.
Maps learned to flip when matrices turn,
MFMA hums with lessons to learn.
A tiny hop — code's bright new day. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main bug fix (correcting index_map for transposed A matrix in MFMA with k_dim==4) and the secondary objective (reopening ROCm CI for gemmsr tests), both reflected in the changeset. |

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c89ef66 and 2c4bce4.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)
tilelang/language/v2/dtypes.py (3)
  • int8 (243-243)
  • int32 (245-245)
  • float32 (300-300)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)

432-433: LGTM! Int8 test cases complete the test matrix.

These additions re-enable int8 GEMM SR test cases with trans_A=False, complementing the existing trans_A=True cases at lines 434-435. Together, they provide comprehensive coverage of all transpose combinations for int8.


438-439: LGTM! Critical test cases for the transposed A bug fix.

These additions directly test the corrected MFMA layout index_map selection for float32 with trans_A=True and k_dim==4. Line 438 exactly matches the configuration example mentioned in the PR description: gemm_sr(128,128,128, True, False, T.float, T.float, T.float32, 128,128,32,2,128).

Together with the existing trans_A=False cases at lines 436-437, these complete the float32 test matrix and validate that the fix restores correct indexing for A_local when A is transposed.



@benenzhu benenzhu changed the title from "Fix(rocm-ci): correct index_map selection for transposed A matrix in MFMA Layout with k_dim==4" to "[BugFix] Correct index_map selection for transposed A matrix in MFMA Layout with k_dim==4 and open rocm-ci for gemmsr" Jan 7, 2026
@LeiWang1999
Member

cc @Gongen-Ali

Contributor

Copilot AI left a comment

Pull request overview

This PR fixes a bug in the MFMA layout index map selection for transposed A matrices when k_dim==4 (used for float32 types), and re-enables previously disabled ROCm CI tests that were affected by this bug.

Key Changes:

  • Adds proper transpose handling for A matrix when k_dim==4 in the get_ldmatrix_index_map method
  • Re-enables 4 test cases for int8 and float32 GEMM operations that were previously disabled due to precision issues on ROCm

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
|---|---|
| tilelang/intrinsics/mfma_macro_generator.py | Fixes index_map selection to use appropriate layout functions (4x16 vs 16x4) based on transpose flag for A matrix when k_dim==4, mirroring the existing logic for k_dim==16/32/64 |
| testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py | Re-enables 4 test cases (2 int8, 2 float32) with various transpose combinations that were previously commented out |


@LeiWang1999 LeiWang1999 merged commit b914318 into tile-ai:main Jan 8, 2026
12 checks passed