[Feat] Adapt gemm v2 for cutedsl backend #1544
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
The changes refactor CuTeDSL GEMM support by removing environment variable-based GEMM v1 requirement handling and introducing direct CuTeDSL backend integration. This includes adding CuTeDSL as a GEMM implementation option, introducing target normalization logic for CuTeDSL, and simplifying execution backend resolution.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/tileop/gemm/__init__.py (1)
192-197: Unreachable code: duplicate `is_tcgen5mma()` check.
Lines 196-197 are unreachable because `is_tcgen5mma()` is already handled at lines 192-193. This appears to be pre-existing dead code, but worth cleaning up.
🔎 Proposed fix
     elif gemm_inst.is_tcgen5mma():
         return GemmTCGEN5
     elif gemm_inst.is_mfma():
         return GemmMFMA
-    elif gemm_inst.is_tcgen5mma():
-        raise NotImplementedError("TCGEN5MMA is not implemented")
     else:
         raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
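The dead branch can be reproduced in isolation. The following is a minimal sketch, not TileLang code: `FakeInst` and `select_impl` are hypothetical stand-ins that only mirror the shape of the `if`/`elif` chain quoted above.

```python
class FakeInst:
    """Hypothetical stand-in for a GEMM instruction descriptor."""

    def __init__(self, kind: str):
        self.kind = kind

    def is_tcgen5mma(self) -> bool:
        return self.kind == "tcgen5mma"

    def is_mfma(self) -> bool:
        return self.kind == "mfma"


def select_impl(inst: FakeInst) -> str:
    # In an if/elif chain, the first branch whose condition holds wins.
    if inst.is_tcgen5mma():
        return "GemmTCGEN5"
    elif inst.is_mfma():
        return "GemmMFMA"
    elif inst.is_tcgen5mma():
        # Same predicate as the first branch, so control never gets here.
        raise NotImplementedError("TCGEN5MMA is not implemented")
    else:
        raise ValueError(f"Unsupported GEMM instruction: {inst.kind}")
```

Calling `select_impl(FakeInst("tcgen5mma"))` returns from the first branch, so the `NotImplementedError` can never be raised and removing that branch does not change behavior.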
🧹 Nitpick comments (1)
tilelang/utils/target.py (1)
130-130: Consider using `%s` or f-string without `str()` for exception message.
The static analyzer flags `str(e)` as unnecessary since `{e}` in an f-string already converts to string. However, using `str(e)` explicitly is a common defensive pattern and not incorrect.
🔎 Optional cleanup
-        raise AssertionError(f"CuTeDSL backend is not available. Please install tilelang-cutedsl package. {str(e)}") from e
+        raise AssertionError(f"CuTeDSL backend is not available. Please install tilelang-cutedsl package. {e}") from e
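The equivalence behind this nitpick is easy to check. Below is a small sketch in which `load_backend` and `check_backend` are hypothetical names, not functions from this PR; it shows that `{e}`, `{str(e)}`, and `{e!s}` produce the same text for an exception, and that `from e` keeps the original error attached as `__cause__`.

```python
def load_backend() -> None:
    # Hypothetical loader that always fails, to exercise the error path.
    raise ImportError("No module named 'some_backend'")


def check_backend() -> None:
    try:
        load_backend()
    except ImportError as e:
        # For exceptions, {e} formats to the same text as str(e).
        raise AssertionError(f"Backend is not available. {e}") from e


def same_rendering(err: Exception) -> bool:
    # Compare the three spellings the linter treats as interchangeable.
    return f"{err}" == f"{str(err)}" == f"{err!s}"
```

The `!s` form is the "explicit conversion flag" that RUF010 refers to; which spelling to use is a style decision.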
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- testing/python/cache/test_tilelang_kernel_cache.py
- testing/python/jit/test_tilelang_jit_cutedsl.py
- tilelang/jit/execution_backend.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_cutedsl.py
- tilelang/utils/target.py
💤 Files with no reviewable changes (3)
- testing/python/cache/test_tilelang_kernel_cache.py
- testing/python/jit/test_tilelang_jit_cutedsl.py
- tilelang/jit/execution_backend.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:47.669Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.
📚 Learning: 2025-12-26T06:45:47.669Z
Applied to files:
- tilelang/tileop/gemm/gemm_cutedsl.py
- tilelang/tileop/gemm/__init__.py
🧬 Code graph analysis (2)
tilelang/tileop/gemm/gemm_cutedsl.py (7)
- tilelang/tileop/gemm/gemm_base.py (2): GemmBase (13-169), mbar (128-129)
- tilelang/language/ast/ir.py (1): target (1677-1707)
- tilelang/tileop/gemm/__init__.py (12): infer_layout (138-142), GemmInst (35-54), is_wgmma (44-45), lower (144-148), A (67-68), B (71-72), C (75-76), trans_A (103-104), trans_B (107-108), clear_accum (127-128), k_pack (131-132), wg_wait (135-136)
- tilelang/tileop/gemm/gemm_wgmma.py (1): GemmWGMMA (14-136)
- tilelang/tileop/gemm/gemm_mma.py (1): GemmMMA (14-222)
- tilelang/language/gemm_op.py (1): gemm_v1 (130-155)
- tilelang/transform/simplify.py (1): _Simplify (31-49)
tilelang/utils/target.py (3)
- tilelang/language/ast/ir.py (1): target (1677-1707)
- tilelang/jit/__init__.py (3): jit (437-437), jit (441-451), jit (454-533)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available (37-88)
🪛 Ruff (0.14.10)
tilelang/tileop/gemm/gemm_cutedsl.py
38-38: Unused method argument: layout_map
(ARG002)
38-38: Unused method argument: target
(ARG002)
38-38: Unused method argument: thread_nums
(ARG002)
38-38: Unused method argument: thread_var
(ARG002)
tilelang/utils/target.py
79-79: Do not catch blind exception: Exception
(BLE001)
130-130: Avoid specifying long messages outside the exception class
(TRY003)
130-130: Use explicit conversion flag
Replace with conversion flag
(RUF010)
140-140: Avoid specifying long messages outside the exception class
(TRY003)
145-148: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
tilelang/tileop/gemm/__init__.py (2)
14-17: LGTM on new imports.
The imports for `GemmCuTeDSL` and `is_cutedsl_target` are correctly added to support the new CuTeDSL backend integration.
182-185: LGTM on CuTeDSL dispatch logic.
The early return for CuTeDSL targets is correctly placed before the instruction-based dispatch, ensuring the CuTeDSL backend bypasses the standard lowering path.
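The ordering described here can be sketched as follows. This is an illustration under stated assumptions, not the code from `tilelang/tileop/gemm/__init__.py`: `FakeTarget`, `is_cutedsl`, and `select_gemm_impl` are hypothetical, and implementations are represented by their class names as strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTarget:
    """Hypothetical target: a kind plus a tuple of keys."""

    kind: str
    keys: tuple = ()


def is_cutedsl(target: FakeTarget) -> bool:
    # Assumption: a CuTeDSL target is recognized by a "cutedsl" key.
    return "cutedsl" in target.keys


def select_gemm_impl(target: FakeTarget, inst_kind: str) -> str:
    # Backend check first: a CuTeDSL target never reaches the
    # instruction-based table below.
    if is_cutedsl(target):
        return "GemmCuTeDSL"
    by_instruction = {
        "mma": "GemmMMA",
        "wgmma": "GemmWGMMA",
        "tcgen5mma": "GemmTCGEN5",
        "mfma": "GemmMFMA",
    }
    if inst_kind not in by_instruction:
        raise ValueError(f"Unsupported GEMM instruction: {inst_kind}")
    return by_instruction[inst_kind]
```

With this ordering, a target carrying the `cutedsl` key gets `GemmCuTeDSL` even when the selected instruction is `wgmma`, which is the bypass the comment refers to.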
tilelang/utils/target.py (3)
63-82: LGTM on `normalize_cutedsl_target` implementation.
The function correctly handles both `Target` objects and string inputs. The string replacement approach (`cutedsl` → `cuda`) with `cutedsl` added to keys is a clean way to create a CUDA-compatible target that's identifiable as CuTeDSL.
The broad `except Exception` at line 79 is acceptable here since `Target()` construction can fail for various reasons (invalid syntax, unsupported options, etc.), and returning `None` is the appropriate fallback for normalization failure.
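The string-rewrite idea can be illustrated without TVM. The sketch below is an assumption-laden stand-in, not the PR's `normalize_cutedsl_target`: `NormalizedTarget` and `normalize_cutedsl_spec` are hypothetical, the spec is treated as a whitespace-separated string whose first token is the target kind, and returning `None` for non-CuTeDSL input is a choice made for this example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class NormalizedTarget:
    """Hypothetical result: a CUDA spec string plus identifying keys."""

    spec: str
    keys: tuple


def normalize_cutedsl_spec(spec: str) -> Optional[NormalizedTarget]:
    tokens = spec.split()
    # Only specs whose kind token is "cutedsl" are rewritten.
    if not tokens or tokens[0] != "cutedsl":
        return None
    # Swap the kind to "cuda" so the spec is CUDA-compatible...
    tokens[0] = "cuda"
    # ...and record "cutedsl" in the keys so the origin stays detectable.
    return NormalizedTarget(" ".join(tokens), ("cuda", "cutedsl"))
```

For example, `normalize_cutedsl_spec("cutedsl -arch=sm_90")` yields a spec of `"cuda -arch=sm_90"` with keys `("cuda", "cutedsl")`. Rewriting only the leading token, rather than every occurrence of the substring, avoids touching option values that happen to contain `cutedsl`; whether the real function makes the same choice is not shown in this review.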
123-132: LGTM on CuTeDSL availability check with lazy import.
The lazy import pattern correctly defers the CuTeDSL dependency check until a CuTeDSL target is actually requested, avoiding unnecessary import overhead for other backends.
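A generic version of the lazy-import pattern looks like the sketch below. It uses `importlib.import_module` as a swapped-in mechanism rather than the PR's `check_cutedsl_available`, and `require_backend` is a hypothetical helper name.

```python
import importlib
from types import ModuleType


def require_backend(module_name: str) -> ModuleType:
    """Import an optional dependency only when it is actually needed."""
    try:
        # The import cost is paid here, at first use, not at package import.
        return importlib.import_module(module_name)
    except ImportError as e:
        raise AssertionError(
            f"Backend module {module_name!r} is not available. {e}"
        ) from e
```

Callers that never request the optional backend never trigger the import, so a missing package only surfaces as an error on the code path that needs it.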
133-151: LGTM on fallback target validation.
The validation logic correctly handles `Target` objects, non-empty strings, and provides helpful error messages with supported target examples.
tilelang/tileop/gemm/gemm_cutedsl.py (3)
1-16: LGTM on module structure and class docstring.
The module docstring and class docstring clearly explain the purpose: bypassing complex MMA/WGMMA lowering in favor of direct `tl::gemm` intrinsic calls for the CuTeDSL backend.
17-36: LGTM on layout inference delegation.
The `infer_layout` method correctly determines the underlying GEMM instruction type and delegates to the appropriate implementation (GemmWGMMA or GemmMMA) for layout inference. This ensures CuTeDSL gets compatible layouts despite using a different lowering path.
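The delegation described above can be sketched with plain classes. Everything here is hypothetical (`MmaLayouts`, `WgmmaLayouts`, `PassThroughGemm`, and the dictionary standing in for a layout map); it only shows the shape of "pick a delegate by instruction type, reuse its layouts, keep a separate lowering path".

```python
class MmaLayouts:
    """Hypothetical stand-in for an MMA-based implementation."""

    def infer_layout(self, shape: tuple) -> dict:
        return {"C": ("mma_fragment", shape)}


class WgmmaLayouts:
    """Hypothetical stand-in for a WGMMA-based implementation."""

    def infer_layout(self, shape: tuple) -> dict:
        return {"C": ("wgmma_fragment", shape)}


class PassThroughGemm:
    """Reuses another implementation's layouts, lowers on its own."""

    def __init__(self, use_wgmma: bool):
        self.use_wgmma = use_wgmma

    def infer_layout(self, shape: tuple) -> dict:
        # Delegate so the layouts match what the instruction expects.
        delegate = WgmmaLayouts() if self.use_wgmma else MmaLayouts()
        return delegate.infer_layout(shape)

    def lower(self) -> str:
        # Placeholder for a direct intrinsic call instead of MMA lowering.
        return "direct_intrinsic_call"
```

Delegating rather than duplicating the layout rules means any later change to the MMA or WGMMA layouts is picked up automatically by the pass-through class.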
38-63: Verify intentional omission of `layout_map`, `target`, `thread_nums`, and `thread_var` parameters.
The `lower` method doesn't use `layout_map`, `target`, `thread_nums`, or `thread_var`, which are used by other implementations (e.g., `GemmMMA`, `GemmWGMMA`) for fragment loading, thread indexing, and instruction selection. Please confirm this is intentional for the CuTeDSL path.
If these are intentionally unused because `gemm_v1` handles them internally, consider adding a brief comment or using `_` prefixes to suppress the linter warnings:
🔎 Optional: Prefix unused parameters
-    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
+    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):  # noqa: ARG002
Or add underscore prefixes if the base class signature allows it.
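The caveat "if the base class signature allows it" matters because renaming a parameter changes the keyword interface. The sketch below uses hypothetical classes (`BaseOp`, `NoqaOp`, `UnderscoreOp`) to show the difference; it assumes, as is the default Ruff configuration to the best of my knowledge, that underscore-prefixed arguments are treated as intentionally unused.

```python
class BaseOp:
    """Hypothetical base class fixing the lower() signature."""

    def lower(self, layout_map, target, thread_nums, thread_var):
        raise NotImplementedError


class NoqaOp(BaseOp):
    # Keeps the parameter names, so keyword callers still work.
    def lower(self, layout_map, target, thread_nums, thread_var):  # noqa: ARG002
        return "lowered"


class UnderscoreOp(BaseOp):
    # Renamed parameters: positional calls work, keyword calls break.
    def lower(self, _layout_map, _target, _thread_nums, _thread_var):
        return "lowered"


def call_with_keywords(op: BaseOp) -> str:
    return op.lower(layout_map={}, target=None, thread_nums=128, thread_var=None)
```

`call_with_keywords(NoqaOp())` succeeds, while `call_with_keywords(UnderscoreOp())` raises `TypeError` because the keyword names no longer exist. If any caller passes these arguments by keyword, the `# noqa` comment is the safer of the two suggestions.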