Skip to content

Conversation

@lucifer1004
Copy link
Collaborator

@lucifer1004 lucifer1004 commented Dec 13, 2025

PR Description

Summary

This PR introduces CuTeDSL as a new TileLang target, providing an end-to-end path to lower → generate CuTeDSL Python → build/load runtime module → run via JIT, plus a dedicated test suite and CI wiring.

User-facing changes

  • Add a new target string: target="cutedsl" (implemented as a CUDA target with an extra cutedsl key).
  • execution_backend="auto" now resolves to cutedsl when the target is CuTeDSL; existing CUDA targets keep their previous defaults.

What’s included

  • Target / lowering glue
    • tilelang.utils.target.determine_target() recognizes cutedsl* targets and injects the cutedsl key.
    • tilelang.engine.lower.device_codegen*() dispatches to target.build.tilelang_cutedsl* when the target has the cutedsl key.
  • Codegen + runtime
    • New CuTeDSL codegen (src/target/codegen_cutedsl.*) built on top of the Python codegen infrastructure (src/target/codegen_py.*).
    • New runtime module builder for CuTeDSL without compile (src/target/rt_mod_cutedsl.cc).
  • JIT / execution backend
    • New CuTeDSL adapter under tilelang/jit/adapter/cutedsl/ and integration into tilelang.jit.
    • Execution backend resolver enforces CuTeDSL prerequisites and fails fast with actionable errors.
  • Kernel cache
    • Extend kernel cache to support CuTeDSL artifacts (Python module + launcher library/source). CuTeDSL can also populate additional artifacts after first execution via the adapter cache path hook.
  • Tests + CI
    • Add testing/python/jit/test_tilelang_jit_cutedsl.py.
    • CI and local CI script run CuTeDSL tests in a dedicated step with TILELANG_USE_GEMM_V1=1 (kept isolated to avoid changing default GEMM selection for other CUDA tests).

Requirements / constraints

  • CuTeDSL backend currently requires GEMM v1. Please set:
    • TILELANG_USE_GEMM_V1=1 before importing tilelang.

Backward compatibility

  • Intended to be fully backward compatible:
    • No behavior changes unless users explicitly select target="cutedsl" (or request execution_backend="cutedsl").
    • CI keeps non-CuTeDSL CUDA tests on the existing default GEMM path; CuTeDSL tests run separately with GEMM v1 env enabled.

Testing

  • Added unit/integration coverage:
    • testing/python/jit/test_tilelang_jit_cutedsl.py (correctness vs PyTorch matmul, profiler do_bench, multi-stream execution, dynamic shapes)
  • CI:
    • Runs CuTeDSL tests in a dedicated CUDA-only step with TILELANG_USE_GEMM_V1=1 and -n 1 to avoid xdist contention on a single GPU.

Notes for reviewers

  • Key integration points to review:
    • tilelang/utils/target.py (target parsing + key injection)
    • tilelang/engine/lower.py (codegen dispatch)
    • tilelang/jit/execution_backend.py (auto backend selection + GEMM v1 requirement)
    • tilelang/cache/kernel_cache.py (CuTeDSL artifact persistence)

Follow-ups (not in this PR)

  • Remove/relax the GEMM v1 requirement once CuTeDSL supports the default GEMM path.
  • Expand operator coverage and add more performance-focused benchmarks/examples.

Summary by CodeRabbit

  • New Features

    • Added CuTeDSL as a CUDA execution backend with runtime availability checks and GEMM‑v1 gating.
  • Kernel / Packaging

    • Cache/packaging now saves/loads CuTeDSL artifacts (launcher, cubin, Python host module) and supports optional launcher compilation.
  • JIT / Adapters

    • Full CuTeDSL JIT integration: adapter, library generator, source wrapper, kernel adapter, and execution paths.
  • Codegen

    • New TileLang Python and CuTeDSL device code generators to emit host and device sources.
  • Tests

    • Added comprehensive CuTeDSL GEMM tests (JIT, dynamic shapes, multi‑stream, benchmarks).
  • Chores

    • CI and local test runner updated to isolate/run CuTeDSL GEMM‑v1 tests.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a CuTeDSL backend and integration across build, C++/Python codegens, runtime builder, JIT adapter/wrapper/lib generator, Cutlass‑DSL contrib helpers, kernel caching/launcher artifact support, target mapping and JIT routing for a "cutedsl" execution backend, plus tests and CI adjustments to isolate GEMM v1 tests.

Changes

Cohort / File(s) Summary
Build / CMake
CMakeLists.txt, src/tl_templates/cuda/nvrtc_std.h
Add CuTeDSL CUDA sources to the CUDA build (codegen_py.cc, codegen_utils.cc, codegen_cutedsl.cc, rt_mod_cutedsl.cc) and annotate header endif.
C++ CuTeDSL codegen & utils
src/target/codegen_cutedsl.h, src/target/codegen_cutedsl.cc, src/target/codegen_utils.h, src/target/codegen_utils.cc, src/target/rt_mod_cutedsl.cc
New CodeGenTileLangCuTeDSL declaration/implementation, parentheses utilities, runtime builder BuildTileLangCuTeDSLWithoutCompile, ExtractFuncInfo, and FFI registration.
Python codegen base
src/target/codegen_py.h, src/target/codegen_py.cc
New CodeGenTileLangPY header/implementation providing Python-target codegen (AddFunction, Finish, many visitors and helpers).
JIT adapter / wrapper / libgen (CuTeDSL)
tilelang/jit/adapter/cutedsl/*
New CuTeDSLKernelAdapter, TLCuTeDSLSourceWrapper (C++ launcher & TMA templates), CuTeDSLLibraryGenerator (nvcc launcher path), runtime availability checks, exports and package init.
JIT integration & utils
tilelang/jit/__init__.py, tilelang/jit/adapter/__init__.py, tilelang/jit/adapter/cutedsl/__init__.py, tilelang/jit/adapter/nvrtc/adapter.py, tilelang/jit/adapter/nvrtc/wrapper.py, tilelang/jit/adapter/utils.py, tilelang/jit/execution_backend.py, tilelang/jit/kernel.py, tilelang/jit/adapter/wrapper.py
Wire "cutedsl" into public APIs and types, add is_cutedsl_target and extraction helpers, change wrapper.wrap contract to return a dict with tma/launcher fields, adapt NVRTC caller to dict result, and add GEMM‑v1 gating/validation for cutedsl.
Engine & target mapping
tilelang/engine/lower.py, tilelang/utils/target.py
Add cutedsl target alias (maps to cuda and appends "cutedsl" to target keys) and route lowering/build calls to cutedsl vs cuda.
Kernel cache & artifacts
tilelang/cache/kernel_cache.py
Add CuTeDSL artifact paths and constants (launcher_lib.so, launcher.cpp, kernel.py, kernel.cubin), extend execution_backend Literal to include "cutedsl", and add CuTeDSL-specific save/load/cache-path handling.
CuTeDSL contrib modules
tilelang/contrib/cutedsl/*
New Cutlass‑DSL helper modules: cp.async/TMA ops, GEMM V1 (SM80/SM90), ldmatrix/stmatrix, math, mbar, reduce, threadblock swizzle, pack/atomics, tensor helpers, and module exports.
Wrapper / NVRTC adjustments
tilelang/jit/adapter/nvrtc/adapter.py, tilelang/jit/adapter/wrapper.py, tilelang/jit/adapter/utils.py
NVRTC adapter now reads wrapper dict result; TLPyWrapper.wrap accepts py_source and returns dict; utils added extractors and floor_div_op support; wrapper selects TLCuTeDSLSourceWrapper for cutedsl targets.
Tests & CI / local test script
.github/workflows/ci.yml, maint/scripts/run_local_ci_test.sh, testing/python/jit/test_tilelang_jit_cutedsl.py, requirements-test-cuda.txt
Add comprehensive CuTeDSL JIT/GEMM tests (bench, multi-stream, dynamic shapes), isolate GEMM v1 tests in CI (TILELANG_USE_GEMM_V1=1) and local test script, and add nvidia-cutlass-dsl>=4.3.1 to CUDA test requirements.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant JIT as tilelang.jit
    participant Lower as tilelang.engine.lower
    participant CodeGen as CodeGenTileLangCuTeDSL
    participant LibGen as CuTeDSLLibraryGenerator
    participant Wrapper as TLCuTeDSLSourceWrapper
    participant Adapter as CuTeDSLKernelAdapter
    participant CUDA as CUDA Runtime

    User->>JIT: compile(func, target="cutedsl")
    JIT->>Lower: lower(IRModule, target)
    Lower->>CodeGen: BuildTileLangCuTeDSLWithoutCompile(mod, target)
    CodeGen-->>Lower: device source + metadata
    Lower->>LibGen: request compile_lib (kernel.py, kernel.cubin, launcher)
    LibGen->>LibGen: write kernel.py, optionally compile launcher.cpp → launcher_lib.so
    LibGen-->>Wrapper: provide artifacts (py host func, cubin, launcher)
    JIT->>Adapter: construct CuTeDSLKernelAdapter(..., artifacts)
    User->>Adapter: call converted torch function
    Adapter->>CUDA: launch via launcher_lib / cubin (stream, args)
    CUDA-->>User: kernel results
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • src/target/codegen_cutedsl.{h,cc} — dense lowering, intrinsics, vectorization and dtype handling.
    • tilelang/jit/adapter/cutedsl/* — adapter, wrapper, libgen, launcher templates, nvcc invocation and FFI interplay.
    • tilelang/cache/kernel_cache.py — CuTeDSL artifact save/load symmetry and cache path propagation.
    • tilelang/jit/adapter/wrapper.py & tilelang/jit/adapter/nvrtc/adapter.py — changed wrap contract (tuple→dict) and all call sites.
    • tilelang/utils/target.py & tilelang/engine/lower.py — target aliasing and routing implications for other targets and tooling.
    • Tests/CI (testing/python/jit/test_tilelang_jit_cutedsl.py, .github/workflows/ci.yml, maint/scripts/run_local_ci_test.sh) — environment gating, dependency handling and GEMM v1 isolation.

Possibly related PRs

Suggested labels

enhancement

Suggested reviewers

  • oraluben
  • XuehaiPan

Poem

🐇 I hopped through code with nimble paws,

CuTeDSL sprung up with kernels and laws,
Launchers sing and GEMMs align,
Threads and tensors dance in line,
The rabbit cheers: "Build fast — now paws applause!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.68% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'feat(cutedsl): add CuTeDSL backend' accurately summarizes the main objective—introducing CuTeDSL as a new code generation target for TileLang.
✨ Finishing touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32fe9b4 and 043ba95.

📒 Files selected for processing (1)
  • .github/workflows/ci.yml (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • .github/workflows/ci.yml
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
🔇 Additional comments (2)
.github/workflows/ci.yml (2)

373-373: LGTM! Clean isolation of CuTeDSL tests.

The ignore flag properly excludes CuTeDSL tests from the main CUDA test run, preventing GEMM version conflicts. This coordinates well with the dedicated step below that runs these tests with GEMM v1.


376-392: Excellent isolation strategy for GEMM v1 tests!

The dedicated step cleanly separates CuTeDSL tests that require GEMM v1 from the main CUDA test suite. Key strengths:

  • Environment variable set at step level ensures proper isolation
  • Single worker (--numprocesses=1) prevents GPU contention
  • Clear comments explain the rationale and constraints
  • Consistent pytest command structure with other test steps

This approach effectively avoids GEMM version conflicts while maintaining comprehensive test coverage.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)

160-160: Missing "cutedsl" in from_database execution_backend type hint.

The from_database class method's execution_backend parameter type hint on line 160 doesn't include "cutedsl", but the __init__ method on line 66 does. This inconsistency could cause type checking issues.

-        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"],
+        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"],
🟡 Minor comments (14)
tilelang/contrib/cutedsl/mbar.py-17-19 (1)

17-19: Unused timeout_ns parameter.

The timeout_ns parameter is declared with a default value but never passed to arch.mbarrier_wait. If the underlying implementation supports timeouts, this should be forwarded; otherwise, remove the parameter to avoid confusion.

-def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000):
+def mbarrier_wait(mbar_ptr: Pointer, phase: Int):
     """Waits on a mbarrier with a specified phase."""
     arch.mbarrier_wait(mbar_ptr, phase)

Alternatively, if arch.mbarrier_wait accepts a timeout parameter, forward it:

 def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000):
     """Waits on a mbarrier with a specified phase."""
-    arch.mbarrier_wait(mbar_ptr, phase)
+    arch.mbarrier_wait(mbar_ptr, phase, timeout_ns)
src/target/rt_mod_cutedsl.cc-41-60 (1)

41-60: Confirm: CUDAModuleCreate parameters are semantically inconsistent with the actual content type.

In BuildTileLangCuTeDSLWithoutCompile, the function generates Python source code via CodeGenTileLangCuTeDSL, yet passes "ptx", "ptx" to CUDAModuleCreate. This mirrors the same pattern in BuildTileLangCUDAWithoutCompile (line 102 of rt_mod_cuda.cc), which also generates source code but uses "ptx" as a placeholder format.

However, this design is problematic: "ptx" is a real GPU compilation format (Parallel Thread Execution), not an appropriate identifier for uncompiled source code. In the compiled path (BuildTileLangCUDA), the first parameter is an actual binary and the second is the format identifier. In the uncompiled paths, using "ptx" as both a placeholder and format identifier for Python source is misleading and inconsistent with NVIDIA's standard format definitions (as used in tilelang/contrib/nvrtc.py and nvcc.py, which validate that formats must be "ptx", "cubin", or "fatbin").

Consider either: (1) using distinct format identifiers for source code (e.g., "python" or "cuda_src"), (2) clarifying the design with a code comment explaining why "ptx" is used as a placeholder, or (3) investigating if this is unintended and the correct format should reflect the actual content type.

src/target/codegen_cutedsl.cc-1105-1105 (1)

1105-1105: Typo: "ommited" should be "omitted".

Minor typo in comment.

tilelang/contrib/cutedsl/gemm_V1.py-179-218 (1)

179-218: gemm_rr also ignores use_wgmma and wg_wait parameters.

Same issue as gemm_sr — the parameters are accepted but unused.

tilelang/contrib/cutedsl/gemm_V1.py-232-247 (1)

232-247: Instance caching with __new__ may cause issues with mutable state.

The singleton pattern caches instances by args, but __init__ is still called on every __new__ return. The hasattr(self, "initialized") check prevents re-initialization, but initialized is never actually set to True.

         if not hasattr(self, "initialized"):
             self.cta_tiler = (M, N, K)
             ...
             self.clear_accum = clear_accum
+            self.initialized = True
tilelang/jit/adapter/cutedsl/libgen.py-78-78 (1)

78-78: Unused timeout parameter.

The timeout parameter is declared but never used in compile_lib. Either implement timeout support for subprocess calls or remove the parameter to match the actual behavior.

-    def compile_lib(self, timeout: float = None):
+    def compile_lib(self, timeout: float | None = None):
         target = self.target
         if is_cutedsl_target(target):
             ...
+            # Pass timeout to subprocess.run calls if provided
+            subprocess.run(..., timeout=timeout)

Committable suggestion skipped: line range outside the PR's diff.

tilelang/contrib/cutedsl/cpasync.py-170-197 (1)

170-197: Docstring parameter size vs actual parameter cp_size mismatch.

The docstring documents a size parameter (lines 181-182) but the actual parameter is cp_size. Also, the docstring for cp_size (lines 185-186) duplicates the description.

     """
     Asynchronously copy data from global memory to shared memory.

     :param dst: Destination pointer in shared memory
     :type dst: Pointer
     :param src: Source pointer in global memory
     :type src: Pointer
-    :param size: Size of the copy in bytes
-    :type size: Int
+    :param cp_size: Size of the copy in bytes
+    :type cp_size: Int
     :param modifier: Cache modifier
-    :type modifier: Int
-    :param cp_size: Optional copy size override
-    :type cp_size: Int
+    :type modifier: nvvm.LoadCacheModifierKind
+    :param src_size: Optional source size override (defaults to cp_size)
+    :type src_size: Int
     """
tilelang/contrib/cutedsl/gemm_V1.py-138-176 (1)

138-176: gemm_sr ignores use_wgmma and wg_wait parameters.

The function signature includes use_wgmma and wg_wait but the comment states "wgmma doesn't support gemm_sr" and always uses SM80. These parameters should either be removed or a warning/error raised when use_wgmma=True.

 def gemm_sr(
     ...
     use_wgmma=None,
     wg_wait=0,
     ...
 ):
     """GEMM with A from shared memory and B from register/fragment"""
-    # wgmma doesn't support gemm_sr, only use SM80
+    if use_wgmma:
+        raise ValueError("wgmma does not support gemm_sr (A from shared, B from register)")
     gemm = Gemm_SM80(...)
tilelang/contrib/cutedsl/reduce.py-140-159 (1)

140-159: Potential infinite recursion if scale is never reached.

The run method recurses until offset == self.scale. If scale is set incorrectly (e.g., larger than initial threads/2), this could recurse indefinitely. Consider adding a guard.

         return (
             x
             if offset == self.scale
-            else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
+            else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
+            if offset > self.scale
+            else x  # Reached minimum offset
         )

Alternatively, document the invariant that scale must be a power-of-2 divisor of threads/2.

tilelang/contrib/cutedsl/gemm_V1.py-380-411 (1)

380-411: Same issues in Gemm_SM90: _instances annotation and missing initialized = True.

Applying the same fixes as for Gemm_SM80.

+from typing import ClassVar
+
 class Gemm_SM90:
-    _instances = {}  # cache instances for the same arguments
+    _instances: ClassVar[dict] = {}

     ...
     def __init__(self, ...):
         if not hasattr(self, "initialized"):
             ...
             self.clear_accum = clear_accum
+            self.initialized = True
src/target/codegen_py.h-162-162 (1)

162-162: Typo in comment.

Minor typo: "statment" should be "statement".

-  // statment
+  // statement
tilelang/jit/adapter/cutedsl/adapter.py-93-93 (1)

93-93: File handle not properly closed.

Using open().read() without a context manager or explicit close can leak file handles.

-        self.device_kernel_source = open(self.libpath).read()
+        with open(self.libpath) as f:
+            self.device_kernel_source = f.read()
tilelang/jit/adapter/cutedsl/wrapper.py-1224-1224 (1)

1224-1224: Incorrect type annotation.

list[str] = None suggests the variable holds a list, but it's initialized to None. Use Optional[list[str]] or list[str] | None.

-            function_params: list[str] = None
+            function_params: list[str] | None = None
tilelang/jit/adapter/cutedsl/adapter.py-145-145 (1)

145-145: Same file handle issue in from_database.

-        adapter.kernel_global_source = open(kernel_lib_path).read()
+        with open(kernel_lib_path) as f:
+            adapter.kernel_global_source = f.read()
🧹 Nitpick comments (42)
src/tl_templates/cuda/nvrtc_std.h (1)

176-176: Good practice: documenting closing preprocessor directive.

Adding a comment to the closing #endif improves readability, especially when the corresponding #ifdef __CUDACC_RTC__ (line 20) is far above. This clearly documents intent and aids future maintenance.

maint/scripts/run_local_ci_test.sh (1)

21-21: Consider error handling for cd command.

While unlikely to fail in practice, adding error handling improves script robustness.

Apply this diff:

-cd testing/python
+cd testing/python || exit 1

Based on static analysis hints.

tilelang/jit/adapter/utils.py (1)

41-54: Unused annotation parameter.

The annotation parameter is declared but never used in the function body. Unlike match_declare_kernel and match_declare_kernel_cpu which use their annotation parameter, this function always matches @cute.kernel.

Either remove the unused parameter or incorporate it into the pattern:

-def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
+def match_declare_kernel_cutedsl(source: str) -> int:
     # Match decorator followed by function definition across lines
     # \s+ allows any whitespace including newlines between decorator and def
     pattern = r"@cute\.kernel\s+def\s+(\w+)"
tilelang/utils/target.py (2)

99-106: LGTM - CuTeDSL target creation logic.

The approach of creating a CUDA target and augmenting its keys with "cutedsl" correctly establishes CuTeDSL as a CUDA-based backend. This enables downstream dispatch logic to identify and route to CuTeDSL-specific codegen paths.

Minor style improvement per Ruff suggestion - use unpacking syntax:

-        target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
+        target_dict["keys"] = [*list(target_dict["keys"]), "cutedsl"]

127-132: Redundant Target instance check.

Lines 127-128 early-return if return_var is a Target, but lines 130-131 perform the same check immediately after. The early return is only reachable when return_var is a Target (from line 106 or 110), making the second check at line 130 dead code in that path.

Consider consolidating:

     if isinstance(return_var, Target):
         return return_var
-    if return_object:
-        if isinstance(return_var, Target):
-            return return_var
+    if return_object:
         return Target(return_var)
     return return_var
tilelang/contrib/cutedsl/mbar.py (1)

6-12: Remove unnecessary noqa directives.

The # noqa: F401 comments are unnecessary since these imports are intentional re-exports for the module's public API, and the F401 rule isn't enabled in your configuration.

-from cutlass.cute.typing import Pointer, Int, Boolean  # noqa: F401
-from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op  # noqa: F401
+from cutlass.cute.typing import Pointer, Int, Boolean
+from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op

-from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive  # noqa: F401
-from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx  # noqa: F401
-from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc  # noqa: F401
+from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive
+from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx
+from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc
tilelang/contrib/cutedsl/__init__.py (4)

23-29: Consider defining __all__ in submodules for better traceability.

The star imports make it difficult to trace which symbols are exposed. While this pattern is common for re-export modules, consider adding explicit __all__ definitions in each submodule (.mbar, .cpasync, etc.) to document the intended public API and help static analysis tools.


53-60: Verify warp size assumption in shuffle_elect.

The function assumes a warp size of 32 (hardcoded divisor). While this is correct for current NVIDIA GPUs, consider adding a comment or using a named constant for clarity:

 def shuffle_elect(thread_extent):
     # thread_extent is the number of threads of a warpgroup
+    WARP_SIZE = 32
     warp_idx = cute.arch.warp_idx()
     warp_idx = cute.arch.make_warp_uniform(warp_idx)
     if thread_extent == 0:
         return warp_idx == 0
     else:
-        return (warp_idx % (thread_extent // 32)) == 0
+        return (warp_idx % (thread_extent // WARP_SIZE)) == 0

88-98: Consider removing has_side_effects=True for pure data packing.

The mov.b32 instruction is a pure data movement with no side effects—it only packs two 16-bit values into a 32-bit register. Setting has_side_effects=True may prevent beneficial compiler optimizations like dead code elimination or instruction reordering.

         packed_xy = llvm.inline_asm(
             Int32.mlir_type,
             [x_i16, y_i16],
             "mov.b32 $0, {$1, $2};",
             "=r,h,h",
-            has_side_effects=True,
+            has_side_effects=False,
             is_align_stack=False,
             asm_dialect=llvm.AsmDialect.AD_ATT,
             loc=loc,
             ip=ip,
         )

105-130: Consider expanding AtomicAdd to support additional dtypes.

The function currently supports only Float32 and Int32. Common use cases might also require Float64 (double precision) or Int64 support. Consider documenting this limitation or expanding support if needed.

tilelang/contrib/cutedsl/math.py (1)

11-44: Add return type annotations for consistency.

The divf function has a return type annotation, but the other math functions are missing them. For API consistency and better IDE support:

-def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True):
+def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True) -> Union[TensorSSA, Numeric]:
     return cute.math.exp(x, fastmath=fastmath)

Apply similar changes to exp2, log, log2, log10, tan, cos, sin, and sqrt.

src/target/codegen_cutedsl.cc (4)

22-39: Duplicate utility function - consider extracting to shared header.

CheckOutermostParenthesesMatch is also implemented in src/target/codegen_py.cc (per relevant snippets). Consider extracting to a shared utility header to avoid code duplication.


91-109: Consider explicit error messages for unsupported Float8 variants.

Several Float8 variants (e3m4, e4m3b11fnuz, e4m3fnuz, e5m2fnuz) silently fall through to the LOG(FATAL) at line 134, but the error message won't indicate which specific variant was attempted. Adding explicit error messages would improve debuggability.

     } else if (t.is_float8_e3m4()) {
-      // unsupported
+      LOG(FATAL) << "Float8 e3m4 is not supported in CuTeDSL";
     } else if (t.is_float8_e4m3()) {

265-629: Consider refactoring the large CallNode visitor.

This visitor function spans ~365 lines with a long if-else chain. While it follows the established TVM codegen pattern, consider extracting groups of related operations (e.g., TMA ops, mbarrier ops, math intrinsics) into separate helper methods for improved maintainability.


1072-1073: Variable shadowing: args shadows outer parameter.

The inner args variable shadows the function parameter args from line 1033. Consider renaming to template_args for clarity.

       auto pos_right = global_symbol_str.find('>', pos_left + 1);
       if (pos_right != std::string::npos) {
-        auto args =
+        auto template_str =
             global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
-        ReplaceAll(args, "true", "True");
-        ReplaceAll(args, "false", "False");
-        global_symbol_str.replace(pos_left, args.size() + 2, "(" + args + ")");
+        ReplaceAll(template_str, "true", "True");
+        ReplaceAll(template_str, "false", "False");
+        global_symbol_str.replace(pos_left, template_str.size() + 2, "(" + template_str + ")");
       }
tilelang/contrib/cutedsl/ldsm.py (2)

11-12: Remove unused noqa directives.

Per static analysis hints, the # noqa: F401 directives on lines 11-12 are unused. The imports ir from line 11 and Pointer, Int32 from line 12 are actually used in the code (Pointer in type hints, Int32 implicitly via cute.Int32).

-from cutlass._mlir import ir  # noqa: F401
-from cutlass.cute.typing import Pointer, Int32  # noqa: F401
+from cutlass._mlir import ir
+from cutlass.cute.typing import Pointer, Int32

25-33: Consider supporting num=1 in _ldmatrix helper.

The assertion assert num in [2, 4] excludes the single-matrix case, requiring separate handling in ptx_ldmatrix_x1. Consider extending the helper to support num=1 to reduce code duplication and ensure consistency between transposed and non-transposed variants.

 def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None):
     """Internal helper for ldmatrix operations"""
     layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
-    assert num in [2, 4]
-    ret_type = llvm.StructType.get_literal([T.i32()] * num)
-    out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
+    assert num in [1, 2, 4], f"ldmatrix supports num=1, 2, or 4, got {num}"
+    if num == 1:
+        out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=layout, loc=loc, ip=ip)
+        out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
+        out[0] = cute.Int32(out_i32)
+        return
+    ret_type = llvm.StructType.get_literal([T.i32()] * num)
+    out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
     out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num)
     for i in range(num):
         out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip))
tilelang/cache/kernel_cache.py (2)

319-339: Consider adding null check for lib_generator before accessing launcher_libpath.

The code checks lib_gen is truthy, but if lib_gen exists but launcher_libpath is an empty string (falsy but not None), the condition passes but _load_binary will fail. The current check is reasonable but could be more explicit.

-                lib_gen = getattr(kernel.adapter, "lib_generator", None)
-                if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath:
+                lib_gen = getattr(kernel.adapter, "lib_generator", None)
+                if lib_gen is not None and getattr(lib_gen, "launcher_libpath", None):

439-453: Static analysis: Consider using logging.exception for better stack traces.

Per static analysis hints, using self.logger.exception() instead of self.logger.error() in the except blocks would automatically include the stack trace, improving debuggability.

             except Exception as e:
-                self.logger.error(f"Error loading kernel source code from disk: {e}")
+                self.logger.exception(f"Error loading kernel source code from disk: {e}")
tilelang/contrib/cutedsl/reduce.py (1)

161-184: Hopper path uses hardcoded barrier IDs 1 and 2.

Using fixed barrier IDs (1, 2) in run_hopper could conflict with other barriers in the kernel. Consider making these configurable or documenting this constraint.

tilelang/contrib/cutedsl/cpasync.py (3)

2-15: Remove unnecessary noqa directives.

Per static analysis, the # noqa: F401 comments are unnecessary since F401 is not enabled. These can be safely removed for cleaner code.

-from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op  # noqa: F401
+from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op

-from cutlass._mlir.dialects import nvvm, cute_nvgpu  # noqa: F401
+from cutlass._mlir.dialects import nvvm, cute_nvgpu

21-36: Use TypeError for invalid type errors.

Per static analysis hints, TypeError is more appropriate than ValueError when the issue is an incorrect type.

     else:
-        raise ValueError(f"Invalid source type: {type(src)}")
+        raise TypeError(f"Expected cute.Tensor or cute.Pointer for src, got {type(src).__name__}")
     else:
-        raise ValueError(f"Invalid destination type: {type(dst)}")
+        raise TypeError(f"Expected cute.Tensor or cute.Pointer for dst, got {type(dst).__name__}")

72-106: The tma_load function overloads based on crd type — document this behavior.

When crd is not a tuple, the function reinterprets parameters differently (treating tma_desc as smem_ptr, mbar as gmem_ptr, etc.). This is confusing and error-prone. Consider splitting into two functions or adding clearer documentation.

The comment on line 76 mentions this is a BUG related to API differences. Consider either:

  1. Creating a separate function for the non-tuple case
  2. Adding explicit documentation in the docstring about this dual behavior
tilelang/jit/adapter/cutedsl/libgen.py (3)

17-36: Module-level subprocess calls block import and swallow errors silently.

Running tvm-ffi-config at module import time can slow down imports and the empty fallback silently hides configuration issues. Consider lazy initialization or logging a warning when falling back.

+import logging
+
+_logger = logging.getLogger(__name__)
+
 try:
     tvm_cxxflags = (
         subprocess.check_output(
             ["tvm-ffi-config", "--cxxflags"],
             text=True,
         )
         .strip()
         .split()
     )
     tvm_ldflags = (
         subprocess.check_output(
             ["tvm-ffi-config", "--ldflags"],
             text=True,
         )
         .strip()
         .split()
     )
 except (subprocess.CalledProcessError, FileNotFoundError):
+    _logger.debug("tvm-ffi-config not available, C++ launcher compilation may fail")
     tvm_cxxflags = []
     tvm_ldflags = []

91-112: Temporary files not cleaned up on failure.

tma_src is created but never explicitly cleaned up if compilation fails. The file will persist until system cleanup.

Consider using a context manager or explicit cleanup:

-            tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
             if self.tma_cpp_init_code is not None:
+                tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
+                try:
                     with open(tma_src.name, "w") as f:
                         f.write(self.tma_cpp_init_code)
                     ...
+                finally:
+                    os.unlink(tma_src.name)

136-138: Include stdout in error message for better diagnostics.

When compilation fails, result.stdout may also contain useful information (warnings, notes). Consider including both.

                 result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True)
                 if result.returncode != 0:
-                    raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}")
+                    raise RuntimeError(
+                        f"Failed to compile C++ launcher:\nstdout: {result.stdout}\nstderr: {result.stderr}"
+                    )
tilelang/contrib/cutedsl/gemm_V1.py (2)

221-228: Mutable class attribute _instances should use ClassVar.

Per static analysis, mutable class attributes like _instances = {} should be annotated with typing.ClassVar to clarify they're shared across instances.

+from typing import ClassVar
+
 class Gemm_SM80:
-    _instances = {}  # cache instances for the same arguments
+    _instances: ClassVar[dict] = {}  # cache instances for the same arguments

424-424: Unused variable tma_tensor from tuple unpacking.

The static analysis correctly identifies that tma_tensor is never used. Prefix with underscore to indicate intentional discard.

-        tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
+        tma_atom, _tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
testing/python/jit/test_tilelang_jit_cutedsl.py (4)

9-52: Consider consolidating duplicated matmul and matmul_jit_kernel functions.

These two functions are nearly identical, with the only difference being their names. This duplication increases maintenance burden. Consider using a single function or a factory pattern.

-def matmul_jit_kernel(
-    M,
-    N,
-    K,
-    block_M,
-    block_N,
-    block_K,
-    trans_A,
-    trans_B,
-    in_dtype,
-    out_dtype,
-    accum_dtype,
-    num_stages,
-    threads,
-):
-    A_shape = (K, M) if trans_A else (M, K)
-    B_shape = (N, K) if trans_B else (K, N)
-    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
-    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-
-    @T.prim_func
-    def main(
-        A: T.Tensor(A_shape, in_dtype),
-        B: T.Tensor(B_shape, in_dtype),
-        C: T.Tensor((M, N), out_dtype),
-    ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                if trans_A:
-                    T.copy(A[k * block_K, by * block_M], A_shared)
-                else:
-                    T.copy(A[by * block_M, k * block_K], A_shared)
-                if trans_B:
-                    T.copy(B[bx * block_N, k * block_K], B_shared)
-                else:
-                    T.copy(B[k * block_K, bx * block_N], B_shared)
-                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
-            T.copy(C_local, C[by * block_M, bx * block_N])
-
-    return main
+# Reuse `matmul` for JIT kernel tests
+matmul_jit_kernel = matmul

Also applies to: 55-98


137-143: Inefficient tensor creation pattern.

Creating a tensor and then transposing it allocates memory twice. Consider creating the tensor with the correct shape directly when transposed.

-    A = torch.randn(M, K, dtype=in_dtype).cuda()
-    B = torch.randn(K, N, dtype=in_dtype).cuda()
-
-    if trans_A:
-        A = A.T
-    if trans_B:
-        B = B.T
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A = torch.randn(*A_shape, dtype=in_dtype).cuda()
+    B = torch.randn(*B_shape, dtype=in_dtype).cuda()

244-248: Multi-stream test lacks correctness verification.

The test launches kernels on multiple streams but doesn't verify the results are correct. This only tests that the kernel doesn't crash, not that it produces correct results with concurrent execution.

     num_streams = 4
+    streams = []
+    outputs = []
     for _ in range(num_streams):
         stream = torch.cuda.Stream()
+        streams.append(stream)
+        tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
+        outputs.append(tensor_c)
         with torch.cuda.stream(stream):
-            matmul_kernel(tensor_a, tensor_b, tensor_c)
+            matmul_kernel(tensor_a, tensor_b, tensor_c)
+
+    # Wait for all streams and verify results
+    for stream in streams:
+        stream.synchronize()
+
+    tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
+    for tensor_c in outputs:
+        tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

310-315: Unused check_hopper function.

This utility function is defined but never called in the test file. If it's intended to skip tests on non-Hopper hardware, it should be used as a pytest skip decorator or guard.

+import pytest
+
 def check_hopper():
     if not torch.cuda.is_available():
         return False
     props = torch.cuda.get_device_properties(0)
     compute_capability = props.major, props.minor
     return compute_capability == (9, 0)
+
+requires_hopper = pytest.mark.skipif(
+    not check_hopper(),
+    reason="Requires Hopper GPU (SM 9.0)"
+)

Then apply @requires_hopper decorator to tests that need Hopper hardware.

tilelang/jit/adapter/cutedsl/adapter.py (4)

50-63: Duplicated parameter shape processing logic.

The parameter shape conversion logic is duplicated between __init__ and from_database. Consider extracting to a helper method.

+    @staticmethod
+    def _convert_param_shapes(params: list[KernelParam]) -> list[list]:
+        """Convert parameter shapes to native Python lists."""
+        param_shapes = []
+        for param in params:
+            native_shape = []
+            for dim in param.shape:
+                if isinstance(dim, tir.IntImm):
+                    native_shape.append(int(dim))
+                elif isinstance(dim, tir.Var):
+                    native_shape.append(dim)
+                else:
+                    native_shape.append(dim)
+            param_shapes.append(native_shape)
+        return param_shapes

Then use self.param_shapes = self._convert_param_shapes(params) in both __init__ and from_database.

Also applies to: 124-136


250-254: Broad exception catch may mask bugs.

Catching generic Exception can hide unexpected errors. Consider catching more specific exceptions like OSError or IOError.

         try:
             shutil.copy2(src_cubin_path, dst_cubin_path)
             logger.debug(f"Saved CuTeDSL cubin to cache: {dst_cubin_path}")
-        except Exception as e:
+        except OSError as e:
             logger.warning(f"Failed to save cubin to cache: {e}")

231-232: Move imports to module level.

Importing os and shutil inside a method adds overhead on each call. These are standard library modules and should be imported at module level.

 from __future__ import annotations
 import logging
+import os
+import shutil
 from typing import Any, Callable

Then remove the local imports at lines 231-232.


21-22: Class attribute pymodule may be unintentionally shared.

Mutable or stateful class attributes can be shared across instances. While None itself isn't mutable, this pattern can cause confusion if the attribute is meant to be instance-specific.

Consider initializing pymodule in __init__ or from_database only:

 class CuTeDSLKernelAdapter(BaseKernelAdapter):
-    pymodule = None
+    # pymodule is initialized in __init__ and from_database
src/target/codegen_cutedsl.h (2)

25-25: Global constant in header may cause ODR violations.

const int64_t LOOP_UNROLL_THRESHOLD defined in a header file could cause one-definition-rule (ODR) issues if this header is included in multiple translation units.

-const int64_t LOOP_UNROLL_THRESHOLD = 64;
+inline constexpr int64_t LOOP_UNROLL_THRESHOLD = 64;

Or move the definition to the .cc file and declare extern const here.


78-82: Documentation comment format issue.

The Doxygen \param comment format is incorrect - it should specify the parameter name without the type.

   /*!
    * \brief Print expr representing the thread tag
-   * \param IterVar iv The thread index to be binded;
+   * \param iv The thread index to be bound
    */
   virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*)
tilelang/jit/adapter/cutedsl/wrapper.py (2)

483-529: Class-level mutable attributes should use ClassVar.

While these dictionaries are used as read-only lookups, annotating them with ClassVar documents the intent and satisfies static analysis.

+from typing import Any, ClassVar
+
 class TLCuTeDSLSourceWrapper(TLCUDASourceWrapper):
-    _TYPE_MAP = {
+    _TYPE_MAP: ClassVar[dict[str, str]] = {
         "float32": "cutlass.Float32",
         ...
     }
-    _CXX_TYPE_MAP = {
+    _CXX_TYPE_MAP: ClassVar[dict[str, str]] = {
         "float32": "float",
         ...
     }
-    _CTYPES_MAP = {
+    _CTYPES_MAP: ClassVar[dict[str, str]] = {
         "buffer": "ctypes.c_uint64",
         ...
     }

1067-1067: Prefix unused unpacked variables with underscore.

The variables dtype and globalAddress are unpacked but never used, which triggers linter warnings.

-            _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+            _, _dtype, tensor_rank, _globalAddress, *remaining_args = args[1:]
src/target/codegen_py.cc (2)

471-487: Remove unused variable buffer_var.

The variable buffer_var on line 479 is declared but never used.

   DataType value_dtype = op->value.dtype();
   DataType element_dtype = op->buffer->dtype;
   PrimExpr index_expr = op->indices[0];
-  Var buffer_var = op->buffer->data;

453-469: Remove unused variable buffer_var.

The variable buffer_var on line 462 is declared but never used.

   DataType value_dtype = op->dtype;
   PrimExpr index = op->indices[0];
-  Var buffer_var = op->buffer->data;
   DataType element_dtype = op->buffer->dtype;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2905143 and 3f1def2.

📒 Files selected for processing (37)
  • .github/workflows/ci.yml (1 hunks)
  • 3rdparty/tvm (1 hunks)
  • CMakeLists.txt (1 hunks)
  • examples/gemm_fp8/example_tilelang_gemm_fp8.py (1 hunks)
  • maint/scripts/run_local_ci_test.sh (1 hunks)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_cutedsl.h (1 hunks)
  • src/target/codegen_py.cc (1 hunks)
  • src/target/codegen_py.h (1 hunks)
  • src/target/rt_mod_cutedsl.cc (1 hunks)
  • src/tl_templates/cuda/nvrtc_std.h (1 hunks)
  • testing/python/jit/test_tilelang_jit_cutedsl.py (1 hunks)
  • tilelang/cache/kernel_cache.py (12 hunks)
  • tilelang/contrib/cutedsl/.gitignore (1 hunks)
  • tilelang/contrib/cutedsl/__init__.py (1 hunks)
  • tilelang/contrib/cutedsl/cpasync.py (1 hunks)
  • tilelang/contrib/cutedsl/gemm_V1.py (1 hunks)
  • tilelang/contrib/cutedsl/ldsm.py (1 hunks)
  • tilelang/contrib/cutedsl/math.py (1 hunks)
  • tilelang/contrib/cutedsl/mbar.py (1 hunks)
  • tilelang/contrib/cutedsl/reduce.py (1 hunks)
  • tilelang/contrib/cutedsl/threadblock_swizzle.py (1 hunks)
  • tilelang/engine/lower.py (2 hunks)
  • tilelang/jit/__init__.py (7 hunks)
  • tilelang/jit/adapter/__init__.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/__init__.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/checks.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/libgen.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/wrapper.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/wrapper.py (1 hunks)
  • tilelang/jit/adapter/utils.py (4 hunks)
  • tilelang/jit/adapter/wrapper.py (7 hunks)
  • tilelang/jit/execution_backend.py (4 hunks)
  • tilelang/jit/kernel.py (7 hunks)
  • tilelang/utils/target.py (3 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_fp8.py
  • src/target/rt_mod_cutedsl.cc
  • .github/workflows/ci.yml
  • testing/python/jit/test_tilelang_jit_cutedsl.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/contrib/cutedsl/__init__.py
🧬 Code graph analysis (19)
examples/gemm_fp8/example_tilelang_gemm_fp8.py (2)
tilelang/jit/__init__.py (3)
  • jit (434-434)
  • jit (438-448)
  • jit (451-521)
tilelang/jit/kernel.py (1)
  • out_idx (629-630)
tilelang/jit/adapter/utils.py (2)
src/target/codegen_cutedsl.h (1)
  • tvm (18-97)
src/target/codegen_py.h (1)
  • tvm (25-247)
tilelang/contrib/cutedsl/threadblock_swizzle.py (4)
tilelang/carver/template/base.py (1)
  • arch (156-163)
tilelang/jit/__init__.py (3)
  • jit (434-434)
  • jit (438-448)
  • jit (451-521)
tilelang/carver/roller/rasterization.py (1)
  • panel_width (14-16)
tilelang/carver/roller/hint.py (1)
  • stride (45-46)
tilelang/jit/adapter/nvrtc/wrapper.py (1)
tilelang/jit/adapter/utils.py (1)
  • pythonic_expr (156-286)
tilelang/jit/adapter/__init__.py (1)
tilelang/jit/adapter/cutedsl/adapter.py (1)
  • CuTeDSLKernelAdapter (21-341)
tilelang/jit/execution_backend.py (3)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/env.py (1)
  • use_gemm_v1 (284-290)
tilelang/jit/adapter/cutedsl/checks.py (1)
  • check_cutedsl_available (33-79)
src/target/codegen_cutedsl.cc (2)
src/target/codegen_py.cc (2)
  • CheckOutermostParenthesesMatch (22-38)
  • CheckOutermostParenthesesMatch (22-22)
src/target/codegen_py.h (1)
  • PrintExpr_ (119-121)
tilelang/contrib/cutedsl/mbar.py (2)
tilelang/carver/template/base.py (1)
  • arch (156-163)
tilelang/language/builtin.py (2)
  • mbarrier_expect_tx (280-289)
  • mbarrier_arrive (262-277)
tilelang/contrib/cutedsl/ldsm.py (1)
tilelang/language/proxy.py (1)
  • make_tensor (278-279)
tilelang/jit/kernel.py (2)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/jit/adapter/cutedsl/adapter.py (3)
  • CuTeDSLKernelAdapter (21-341)
  • from_database (99-149)
  • get_kernel_source (197-205)
tilelang/utils/target.py (1)
tilelang/language/ast/ir.py (1)
  • target (1677-1707)
tilelang/contrib/cutedsl/reduce.py (3)
src/tl_templates/cuda/reduce.h (1)
  • T (178-250)
src/tl_templates/cuda/nvrtc_std.h (1)
  • min (124-124)
tilelang/language/proxy.py (1)
  • make_tensor (278-279)
tilelang/engine/lower.py (1)
tilelang/language/ast/ir.py (1)
  • target (1677-1707)
tilelang/contrib/cutedsl/cpasync.py (1)
src/tl_templates/cuda/copy.h (1)
  • cp_async_wait (20-26)
src/target/codegen_cutedsl.h (1)
src/target/codegen_py.h (1)
  • codegen (26-246)
tilelang/cache/kernel_cache.py (1)
tilelang/autotuner/param.py (4)
  • _save_kernel_to_disk (176-260)
  • _safe_write_file (158-166)
  • _load_binary (152-155)
  • _safe_write_executable (169-174)
tilelang/jit/adapter/cutedsl/libgen.py (3)
tilelang/jit/adapter/libgen.py (1)
  • LibraryGenerator (23-172)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/jit/adapter/cutedsl/wrapper.py (2)
  • host_func (550-554)
  • host_func (557-559)
tilelang/jit/adapter/wrapper.py (1)
tilelang/jit/adapter/utils.py (4)
  • is_cutedsl_target (114-115)
  • pythonic_expr (156-286)
  • parse_tma_descriptor_args (389-494)
  • is_cuda_target (98-99)
src/target/codegen_py.cc (1)
src/target/codegen_py.h (2)
  • PrintStmt_ (113-113)
  • PrintExpr_ (119-121)
🪛 Cppcheck (2.18.0)
src/target/rt_mod_cutedsl.cc

[error] 62-62: syntax error

(syntaxError)

🪛 Ruff (0.14.8)
tilelang/jit/adapter/utils.py

41-41: Unused function argument: annotation

(ARG001)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


85-85: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/contrib/cutedsl/__init__.py

9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


18-18: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


23-23: from .mbar import * used; unable to detect undefined names

(F403)


24-24: from .cpasync import * used; unable to detect undefined names

(F403)


25-25: from .gemm_V1 import * used; unable to detect undefined names

(F403)


26-26: from .reduce import * used; unable to detect undefined names

(F403)


27-27: from .ldsm import * used; unable to detect undefined names

(F403)


28-28: from .math import * used; unable to detect undefined names

(F403)


29-29: from .threadblock_swizzle import * used; unable to detect undefined names

(F403)


64-64: bar_sync_ptx may be undefined, or defined from star imports

(F405)


129-129: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/execution_backend.py

79-81: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/contrib/cutedsl/mbar.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused function argument: timeout_ns

(ARG001)

tilelang/jit/adapter/cutedsl/__init__.py

6-11: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


16-16: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/contrib/cutedsl/ldsm.py

11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/utils/target.py

104-104: Consider [*list(target_dict["keys"]), "cutedsl"] instead of concatenation

Replace with [*list(target_dict["keys"]), "cutedsl"]

(RUF005)

tilelang/contrib/cutedsl/cpasync.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Prefer TypeError exception for invalid type

(TRY004)


29-29: Avoid specifying long messages outside the exception class

(TRY003)


35-35: Prefer TypeError exception for invalid type

(TRY004)


35-35: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/cutedsl/adapter.py

109-109: Unused class method argument: pass_configs

(ARG003)


197-197: Unused method argument: kernel_only

(ARG002)


253-253: Do not catch blind exception: Exception

(BLE001)


273-275: Avoid specifying long messages outside the exception class

(TRY003)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/cutedsl/wrapper.py

483-499: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


503-514: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


516-529: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


594-594: Avoid specifying long messages outside the exception class

(TRY003)


893-893: Unused method argument: code

(ARG002)


978-978: Consider [*inner_args, "stream: CUstream"] instead of concatenation

Replace with [*inner_args, "stream: CUstream"]

(RUF005)


989-989: Unused method argument: function_name

(ARG002)


1041-1041: Unused method argument: device_index

(ARG002)


1067-1067: Unpacked variable dtype is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1067-1067: Unpacked variable globalAddress is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1235-1235: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/cutedsl/checks.py

47-47: Do not catch blind exception: Exception

(BLE001)


55-57: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/contrib/cutedsl/gemm_V1.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


151-151: Unused function argument: use_wgmma

(ARG001)


152-152: Unused function argument: wg_wait

(ARG001)


192-192: Unused function argument: use_wgmma

(ARG001)


193-193: Unused function argument: wg_wait

(ARG001)


222-222: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


233-233: Unused method argument: stride_A

(ARG002)


233-233: Unused method argument: stride_B

(ARG002)


233-233: Unused method argument: offset_A

(ARG002)


233-233: Unused method argument: offset_B

(ARG002)


381-381: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


392-392: Unused method argument: stride_A

(ARG002)


392-392: Unused method argument: stride_B

(ARG002)


392-392: Unused method argument: offset_A

(ARG002)


392-392: Unused method argument: offset_B

(ARG002)


424-424: Unpacked variable tma_tensor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tilelang/cache/kernel_cache.py

301-301: Do not catch blind exception: Exception

(BLE001)


302-302: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


445-445: Do not catch blind exception: Exception

(BLE001)


446-446: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


452-452: Do not catch blind exception: Exception

(BLE001)


453-453: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

tilelang/jit/adapter/cutedsl/libgen.py

20-20: Starting a process with a partial executable path

(S607)


28-28: Starting a process with a partial executable path

(S607)


78-78: Unused method argument: timeout

(ARG002)


78-78: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


92-92: Unused noqa directive (non-enabled: SIM115)

Remove unused noqa directive

(RUF100)


99-99: subprocess call: check for execution of untrusted input

(S603)


100-108: Starting a process with a partial executable path

(S607)


115-115: Unused noqa directive (non-enabled: SIM115)

Remove unused noqa directive

(RUF100)


136-136: subprocess call: check for execution of untrusted input

(S603)


138-138: Avoid specifying long messages outside the exception class

(TRY003)


146-146: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/wrapper.py

896-896: Avoid specifying long messages outside the exception class

(TRY003)

🪛 Shellcheck (0.11.0)
maint/scripts/run_local_ci_test.sh

[warning] 21-21: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 18

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)

149-163: Missing "cutedsl" in from_database type hint.

The execution_backend parameter type hint is missing "cutedsl", but lines 405-416 handle the cutedsl case in _create_adapter_from_database.

     def from_database(
         cls,
         func: PrimFunc,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         params: list[KernelParam],
         target: str | Target,
         target_host: str | Target,
         out_idx: list[int] | int,
-        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"],
+        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"],
         pass_configs: dict[str, Any] | None = None,
         compile_flags: list[str] | None = None,
     ):
♻️ Duplicate comments (9)
tilelang/contrib/cutedsl/threadblock_swizzle.py (2)

25-38: Zero stride risk already flagged.

The stride calculation on line 35 can result in zero when the final panel has fewer than gridDim.x blocks, causing division by zero on line 36. This was flagged in a previous review.


41-54: Zero stride risk already flagged.

Same issue applies to rasterization2DColumn on line 51.

tilelang/contrib/cutedsl/reduce.py (2)

16-27: Unsupported third parameter for nvvm.fmin already flagged.

The NVVM fmin intrinsic only accepts two operands. Passing c as a third parameter will fail. This was flagged in a previous review.


30-41: Same issue applies to nvvm.fmax.

The max function has the same unsupported third parameter issue.

tilelang/contrib/cutedsl/ldsm.py (1)

69-72: Bug: ptx_ldmatrix_x1_trans will fail due to assertion in _ldmatrix.

This issue was already identified in a previous review. The _ldmatrix helper asserts num in [2, 4] at line 28, but ptx_ldmatrix_x1_trans calls _ldmatrix with num=1. This will cause an AssertionError at runtime.

The fix should mirror the inline handling used in ptx_ldmatrix_x1 (lines 51-54):

 @dsl_user_op
 def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
     """Load 1 matrix (8x8) with transpose from shared memory"""
-    _ldmatrix(smem_ptr, local_ptr, 1, True, loc, ip)
+    out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip)
+    out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
+    out[0] = cute.Int32(out_i32)
src/target/codegen_py.h (1)

28-28: Avoid using namespace in header files.

This was flagged in a previous review. using namespace tir; in a header file pollutes the namespace for all translation units that include this header, potentially causing name collisions.

Instead, use explicit tir:: qualification in declarations or move the using directive inside method implementations in the .cc file.

src/target/codegen_py.cc (3)

81-102: Add collision check for gvar->name_hint before reserving it (duplicate).
This matches the earlier review: the name_hint path can silently generate duplicate Python defs.

     } else {
-      func_name_supply_->ReserveName(gvar->name_hint);
-      return gvar->name_hint;
+      auto name = gvar->name_hint;
+      ICHECK(!func_name_supply_->ContainsName(name))
+          << "Function " << gvar << " must use name hint " << name
+          << ", but this name has already been used.";
+      func_name_supply_->ReserveName(name);
+      return name;
     }

188-217: PrintType switch fallthrough prints incorrect types (duplicate); also float16 path is inconsistent.
Missing break causes intbool-style output. Past review already flagged this.

 void CodeGenTileLangPY::PrintType(DataType type,
                                   std::ostream &os) { // NOLINT(*)
   if (type.is_float()) {
-    if (type.bits() == 32 || type.bits() == 64) {
+    if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) {
       os << "float";
+    } else {
+      LOG(FATAL) << "Unsupported float bit-width: " << type.bits();
     }
   } else if (type.is_uint()) {
     switch (type.bits()) {
     case 8:
     case 16:
     case 32:
     case 64: {
       os << "int";
+      break;
     }
     case 1:
       os << "bool";
+      break;
+    default:
+      LOG(FATAL) << "Unsupported uint bit-width: " << type.bits();
     }
   } else if (type.is_int()) {
     switch (type.bits()) {
     case 8:
     case 16:
     case 32:
     case 64: {
       os << "int";
+      break;
     }
+    default:
+      LOG(FATAL) << "Unsupported int bit-width: " << type.bits();
     }
   } else {
     LOG(FATAL) << "Cannot convert type " << type << " to Python type";
   }
 }

522-536: For-loop bounds stream TVM IR directly instead of Python expr printing (duplicate).
Past review already flagged this; should use PrintExpr_.

   if (is_zero(op->min)) {
     PrintExpr_(op->extent, stream);
   } else {
-    stream << op->min << ", "
-           << arith::Analyzer().Simplify(op->extent + op->min);
+    PrintExpr_(op->min, stream);
+    stream << ", ";
+    PrimExpr upper = arith::Analyzer().Simplify(op->extent + op->min);
+    PrintExpr_(upper, stream);
   }
🧹 Nitpick comments (28)
tilelang/jit/adapter/utils.py (2)

41-54: Remove unused annotation parameter or use it in the pattern.

The annotation parameter defaults to "@cute.kernel" but is never used in the function body. The regex pattern is hardcoded to @cute\.kernel. Either remove the parameter or incorporate it into the pattern for consistency with sibling functions like match_declare_kernel.

-def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
+def match_declare_kernel_cutedsl(source: str) -> int:
     # Match decorator followed by function definition across lines
     # \s+ allows any whitespace including newlines between decorator and def
     pattern = r"@cute\.kernel\s+def\s+(\w+)"

Or, to use the parameter:

-    pattern = r"@cute\.kernel\s+def\s+(\w+)"
+    pattern = rf"{re.escape(annotation)}\s+def\s+(\w+)"

78-85: Regex may fail on type hints with nested parentheses.

The pattern \([^)]*\) matches until the first ), which breaks for signatures like def kernel(callback: Callable[[int], int]). If such signatures are possible, consider a more robust approach (e.g., iterative parenthesis matching or a simpler heuristic).

For current use cases with CuTeDSL kernels, this is likely acceptable.

tilelang/contrib/cutedsl/math.py (1)

11-44: Consider adding return type annotations for consistency.

The divf function has a return type annotation but the other functions (exp, exp2, log, etc.) do not. Adding consistent type hints improves maintainability.

-def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True):
+def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True) -> Union[TensorSSA, Numeric]:
     return cute.math.exp(x, fastmath=fastmath)
tilelang/contrib/cutedsl/reduce.py (1)

155-159: Recursive AllReduce may hit stack limits for large thread counts.

The recursive call pattern AllReduce(...).run(...) builds up stack frames. For very large threads values (e.g., 1024), this results in ~10 recursive calls which is fine, but the pattern could be refactored to iterative if needed for deeper recursion.

Also applies to: 180-184

tilelang/contrib/cutedsl/gemm_V1.py (3)

222-222: Annotate mutable class attribute with ClassVar.

The _instances dict is a mutable class attribute shared across all instances. Annotate it with typing.ClassVar for clarity and type checker compatibility.

+from typing import ClassVar

 class Gemm_SM80:
-    _instances = {}  # cache instances for the same arguments
+    _instances: ClassVar[dict] = {}  # cache instances for the same arguments

Apply the same change to Gemm_SM90._instances.

Also applies to: 381-381


424-424: Prefix unused variable with underscore.

The unpacked tma_tensor is never used. Prefix it with _ to indicate it's intentionally unused.

-        tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
+        tma_atom, _tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(

3-3: Remove unused noqa directive.

The # noqa: F401 comment on line 3 is flagged as unused by Ruff.

-import cutlass.utils as utils  # noqa: F401
+import cutlass.utils as utils
src/target/codegen_cutedsl.h (1)

86-94: Consider documenting the eviction policy enum mapping.

The eviction_policy_names_ vector maps indices to policy strings. A brief comment explaining the index-to-policy correspondence would improve maintainability.

+  // Maps eviction policy indices to their string representations
+  // Index 0: EVICT_NORMAL, Index 1: EVICT_FIRST, Index 2: EVICT_LAST
   std::vector<std::string> eviction_policy_names_ = {
       "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
tilelang/jit/adapter/cutedsl/__init__.py (1)

6-16: Consider aligning __all__ order with import order for consistency.

The __all__ list places check_cutedsl_available last, while imports have it first. Consider either sorting __all__ alphabetically or matching the import order for easier maintenance.

 __all__ = [
+    "check_cutedsl_available",
     "CuTeDSLKernelAdapter",
-    "TLCuTeDSLSourceWrapper",
     "CuTeDSLLibraryGenerator",
-    "check_cutedsl_available",
+    "TLCuTeDSLSourceWrapper",
 ]
tilelang/utils/target.py (2)

99-106: Consider using unpacking for cleaner list construction.

The logic correctly transforms cutedsl* target strings to CUDA-based Targets with the cutedsl key appended. Per static analysis suggestion, consider using unpacking for the list concatenation.

-        target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
+        target_dict["keys"] = [*target_dict["keys"], "cutedsl"]

127-132: Dead code: lines 130-131 are now unreachable.

The early return at lines 127-128 ensures that if return_var is a Target, the function returns immediately. This makes the isinstance(return_var, Target) check at line 130 unreachable.

Either remove the dead code or consolidate the logic:

     if isinstance(return_var, Target):
         return return_var
     if return_object:
-        if isinstance(return_var, Target):
-            return return_var
         return Target(return_var)
     return return_var
tilelang/jit/adapter/cutedsl/checks.py (1)

43-49: Consider narrowing the exception type.

The bare Exception catch is documented as "best-effort" which is reasonable, but consider catching a more specific base like (ValueError, TypeError, AttributeError) to avoid accidentally swallowing unexpected errors (e.g., KeyboardInterrupt is not caught, but other programming errors might be).

     except _importlib_metadata.PackageNotFoundError:
         dist_version = None
-    except Exception:
+    except (ValueError, TypeError, AttributeError, OSError):
         # Metadata is best-effort; don't block internal/nonstandard installs here.
         dist_version = None
tilelang/contrib/cutedsl/__init__.py (3)

20-26: Consider using explicit imports instead of star imports.

Star imports make it harder to track what symbols are available and can lead to namespace pollution. Consider defining an __all__ in each submodule and/or using explicit imports in this init file.

That said, for a facade/convenience module like this, star imports are acceptable if the submodules are well-maintained.


85-95: Consider setting has_side_effects=False for the pack operation.

The mov.b32 instruction is a pure bitwise packing operation with no side effects. Setting has_side_effects=True may prevent useful compiler optimizations like dead code elimination or instruction reordering.

         packed_xy = llvm.inline_asm(
             Int32.mlir_type,
             [x_i16, y_i16],
             "mov.b32 $0, {$1, $2};",
             "=r,h,h",
-            has_side_effects=True,
+            has_side_effects=False,
             is_align_stack=False,
             asm_dialect=llvm.AsmDialect.AD_ATT,
             loc=loc,
             ip=ip,
         )

102-127: Consider extending AtomicAdd to support more data types.

The current implementation only supports Float32 and Int32. For broader utility, consider adding support for Float16 (atomicAdd on fp16 is available on SM_60+) and Int64.

Would you like me to help extend this to support additional data types?

src/target/codegen_cutedsl.cc (2)

22-47: Consider extracting duplicated utility functions.

CheckOutermostParenthesesMatch and RemoveOutermostParentheses are duplicated from src/target/codegen_py.cc (lines 21-37 in the external context). Consider extracting these into a shared utility header to reduce code duplication.

// Could create a shared header like codegen_utils.h
// and move these functions there for reuse across codegen backends

91-109: Silent fallthrough for unsupported float8 variants.

Several float8 variants (e3m4, e4m3b11fnuz, e4m3fnuz, e5m2fnuz) have empty branches that silently fall through to the error at line 134. While this works, adding explicit LOG(WARNING) or comments would clarify intent.

     } else if (t.is_float8_e3m4()) {
-      // unsupported
+      // Float8E3M4 is unsupported in CuTeDSL
     } else if (t.is_float8_e4m3()) {
tilelang/jit/adapter/cutedsl/libgen.py (2)

78-78: Unused timeout parameter.

The timeout parameter is declared but never used in the method body. Either implement timeout handling for the subprocess calls or remove the parameter.

-    def compile_lib(self, timeout: float = None):
+    def compile_lib(self, timeout: float | None = None):
+        # Note: timeout is reserved for future use with subprocess timeouts

Or if timeout support is needed:

-                result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True)
+                result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout)

92-112: Temporary files not cleaned up on success.

The NamedTemporaryFile objects created with delete=False are never explicitly cleaned up after successful compilation. While this may be acceptable for debugging, consider cleaning up temp files in production paths.

             tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
             if self.tma_cpp_init_code is not None:
                 with open(tma_src.name, "w") as f:
                     f.write(self.tma_cpp_init_code)

                 # ... compilation code ...
+                
+                # Clean up temp source file
+                os.unlink(tma_src.name)
tilelang/contrib/cutedsl/ldsm.py (1)

11-12: Remove unused noqa directives.

Static analysis indicates the # noqa: F401 comments on lines 11-12 are unnecessary since the rule is not enabled. These can be safely removed.

-from cutlass._mlir import ir  # noqa: F401
-from cutlass.cute.typing import Pointer, Int32  # noqa: F401
+from cutlass._mlir import ir
+from cutlass.cute.typing import Pointer, Int32
tilelang/cache/kernel_cache.py (2)

423-431: Required files validation includes launcher for CuTeDSL.

Good validation that ensures the launcher library exists before attempting to load a cached CuTeDSL kernel. This prevents confusing errors later in the loading process.

Minor: Consider using a generator expression instead of list comprehension in the all() call on line 431:

-        if not all([os.path.exists(file) for file in required_files]):
+        if not all(os.path.exists(file) for file in required_files):

301-302: Consider using logging.exception for better stack traces.

Per static analysis hints, logging.exception provides better debugging context than logging.error when catching exceptions, as it automatically includes the stack trace.

         except Exception as e:
-            self.logger.error(f"Error saving kernel source code to disk: {e}")
+            self.logger.exception("Error saving kernel source code to disk")

This pattern applies to similar locations at lines 314, 370, 380, 446, 453, and 462.

tilelang/contrib/cutedsl/cpasync.py (3)

1-19: Remove ineffective noqa: F401 directives (RUF100) and consider dropping unused constants.
Ruff reports these noqa comments are unused, so they add noise. Also BYTES_PER_TENSORMAP / BYTES_PER_POINTER aren’t referenced in this module.


21-37: Don’t use assert for user input validation; raise ValueError and prefer TypeError for bad operand types.
assert size in [16, 8, 4] can be optimized away and produces weaker errors than explicit exceptions; also the invalid-type branches should be TypeError (per Ruff TRY004).

 def cp_async_gs(size, dst, dst_offset, src, src_offset):
-    assert size in [16, 8, 4]
+    if size not in (4, 8, 16):
+        raise ValueError("cp_async_gs size must be one of {4, 8, 16} bytes")
     mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
     if isinstance(src, cute.Tensor):
         src_ptr = src.iterator
     elif isinstance(src, cute.Pointer):
         src_ptr = src
     else:
-        raise ValueError(f"Invalid source type: {type(src)}")
+        raise TypeError("cp_async_gs src must be cute.Tensor or cute.Pointer")
     if isinstance(dst, cute.Tensor):
         dst_ptr = dst.iterator
     elif isinstance(dst, cute.Pointer):
         dst_ptr = dst
     else:
-        raise ValueError(f"Invalid destination type: {type(dst)}")
+        raise TypeError("cp_async_gs dst must be cute.Tensor or cute.Pointer")
     cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)

170-198: Ensure size is a plain int before building IntegerAttr (defensive).
If cp_size/src_size come in as non-int “Int” wrappers, ir.IntegerAttr.get(..., size) can misbehave.

-    size = src_size if src_size else cp_size
+    size = int(src_size) if src_size is not None else int(cp_size)
     nvvm.cp_async_shared_global(
         dst=dst.llvm_ptr,
         src=src.llvm_ptr,
         size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size),
tilelang/jit/adapter/cutedsl/adapter.py (2)

197-206: kernel_only is unused (ARG002); either honor it or drop it.


216-255: Create cache dir and avoid catching Exception broadly.
If _cache_path exists but directory doesn’t, copy will fail; also catching Exception can hide real programmer errors.

         # Destination cubin path (in cache directory)
         dst_cubin_path = os.path.join(cache_path, "kernel.cubin")
+        os.makedirs(cache_path, exist_ok=True)
@@
-        except Exception as e:
+        except (OSError, shutil.Error) as e:
             logger.warning(f"Failed to save cubin to cache: {e}")
tilelang/jit/adapter/cutedsl/wrapper.py (1)

483-529: Mark mapping dicts as ClassVar (RUF012) to avoid “mutable default on class” pitfalls.
Even if you never mutate them, annotating helps tooling and avoids future foot-guns.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3f1def2 and 5617be1.

📒 Files selected for processing (37)
  • .github/workflows/ci.yml (1 hunks)
  • 3rdparty/tvm (1 hunks)
  • CMakeLists.txt (1 hunks)
  • examples/gemm_fp8/example_tilelang_gemm_fp8.py (1 hunks)
  • maint/scripts/run_local_ci_test.sh (1 hunks)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_cutedsl.h (1 hunks)
  • src/target/codegen_py.cc (1 hunks)
  • src/target/codegen_py.h (1 hunks)
  • src/target/rt_mod_cutedsl.cc (1 hunks)
  • src/tl_templates/cuda/nvrtc_std.h (1 hunks)
  • testing/python/jit/test_tilelang_jit_cutedsl.py (1 hunks)
  • tilelang/cache/kernel_cache.py (12 hunks)
  • tilelang/contrib/cutedsl/.gitignore (1 hunks)
  • tilelang/contrib/cutedsl/__init__.py (1 hunks)
  • tilelang/contrib/cutedsl/cpasync.py (1 hunks)
  • tilelang/contrib/cutedsl/gemm_V1.py (1 hunks)
  • tilelang/contrib/cutedsl/ldsm.py (1 hunks)
  • tilelang/contrib/cutedsl/math.py (1 hunks)
  • tilelang/contrib/cutedsl/mbar.py (1 hunks)
  • tilelang/contrib/cutedsl/reduce.py (1 hunks)
  • tilelang/contrib/cutedsl/threadblock_swizzle.py (1 hunks)
  • tilelang/engine/lower.py (2 hunks)
  • tilelang/jit/__init__.py (7 hunks)
  • tilelang/jit/adapter/__init__.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/__init__.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/checks.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/libgen.py (1 hunks)
  • tilelang/jit/adapter/cutedsl/wrapper.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/wrapper.py (1 hunks)
  • tilelang/jit/adapter/utils.py (4 hunks)
  • tilelang/jit/adapter/wrapper.py (7 hunks)
  • tilelang/jit/execution_backend.py (4 hunks)
  • tilelang/jit/kernel.py (7 hunks)
  • tilelang/utils/target.py (3 hunks)
✅ Files skipped from review due to trivial changes (2)
  • src/tl_templates/cuda/nvrtc_std.h
  • tilelang/contrib/cutedsl/.gitignore
🚧 Files skipped from review as they are similar to previous changes (6)
  • 3rdparty/tvm
  • .github/workflows/ci.yml
  • tilelang/jit/adapter/nvrtc/wrapper.py
  • tilelang/engine/lower.py
  • testing/python/jit/test_tilelang_jit_cutedsl.py
  • examples/gemm_fp8/example_tilelang_gemm_fp8.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/contrib/cutedsl/__init__.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • src/target/rt_mod_cutedsl.cc
🧬 Code graph analysis (16)
tilelang/jit/adapter/cutedsl/checks.py (1)
tilelang/jit/__init__.py (2)
  • compile (49-118)
  • compile (350-376)
src/target/codegen_cutedsl.cc (4)
src/target/codegen_py.cc (2)
  • CheckOutermostParenthesesMatch (22-38)
  • CheckOutermostParenthesesMatch (22-22)
tilelang/language/builtin.py (6)
  • get_mbarrier (92-101)
  • no_set_max_nreg (202-204)
  • warpgroup_arrive (292-298)
  • warpgroup_commit_batch (301-307)
  • warpgroup_fence_operand (462-618)
  • set_max_nreg (164-179)
tilelang/contrib/cutedsl/cpasync.py (1)
  • tma_load (59-106)
tilelang/language/math_intrinsics.py (5)
  • ieee_add (149-172)
  • ieee_sub (175-196)
  • ieee_mul (199-220)
  • ieee_fmaf (223-247)
  • ieee_fdiv (309-330)
tilelang/jit/adapter/wrapper.py (3)
tilelang/jit/adapter/utils.py (3)
  • is_metal_target (110-111)
  • is_cutedsl_target (114-115)
  • pythonic_expr (156-286)
tilelang/jit/adapter/cutedsl/wrapper.py (1)
  • _pythonic_expr (565-567)
tilelang/jit/adapter/nvrtc/wrapper.py (1)
  • _pythonic_expr (271-276)
tilelang/contrib/cutedsl/mbar.py (2)
tilelang/carver/template/base.py (1)
  • arch (156-163)
tilelang/language/builtin.py (2)
  • mbarrier_expect_tx (280-289)
  • mbarrier_arrive (262-277)
tilelang/contrib/cutedsl/cpasync.py (1)
src/tl_templates/cuda/copy.h (1)
  • cp_async_wait (20-26)
src/target/codegen_py.h (1)
src/target/codegen_py.cc (34)
  • AddFunction (49-64)
  • AddFunction (49-49)
  • Finish (66-71)
  • Finish (66-66)
  • GetFunctionName_ (73-79)
  • GetFunctionName_ (73-73)
  • RegisterFunction_ (81-102)
  • RegisterFunction_ (81-82)
  • InitFuncState_ (104-109)
  • InitFuncState_ (104-104)
  • PrintFunctionSignature_ (111-132)
  • PrintFunctionSignature_ (111-113)
  • ReserveKeywordsAsUnique_ (134-181)
  • ReserveKeywordsAsUnique_ (134-134)
  • PrintSSAAssign (183-186)
  • PrintSSAAssign (183-184)
  • PrintType (188-217)
  • PrintType (188-189)
  • VisitStmt_ (471-487)
  • VisitStmt_ (471-471)
  • VisitStmt_ (489-491)
  • VisitStmt_ (489-489)
  • CastFromTo_ (592-600)
  • CastFromTo_ (592-593)
  • PrintBinaryExpr_ (602-620)
  • PrintBinaryExpr_ (602-605)
  • PrintCallExtern_ (634-647)
  • PrintCallExtern_ (634-638)
  • GetBufferRef_ (650-662)
  • GetBufferRef_ (650-652)
  • RegisterHandleType_ (664-672)
  • RegisterHandleType_ (664-665)
  • HandleTypeMatch_ (674-680)
  • HandleTypeMatch_ (674-675)
tilelang/utils/target.py (1)
tilelang/language/ast/ir.py (1)
  • target (1677-1707)
tilelang/contrib/cutedsl/ldsm.py (1)
tilelang/language/proxy.py (1)
  • make_tensor (278-279)
tilelang/jit/adapter/cutedsl/adapter.py (8)
tilelang/engine/param.py (1)
  • KernelParam (13-105)
tilelang/jit/adapter/wrapper.py (4)
  • wrap (144-145)
  • wrap (858-878)
  • wrap (885-912)
  • host_func (557-567)
tilelang/jit/adapter/cutedsl/checks.py (1)
  • check_cutedsl_available (33-79)
tilelang/jit/adapter/cutedsl/libgen.py (1)
  • CuTeDSLLibraryGenerator (39-146)
tilelang/utils/target.py (1)
  • determine_target (63-133)
tilelang/jit/adapter/base.py (2)
  • BaseKernelAdapter (11-96)
  • _post_init (95-96)
tilelang/jit/adapter/cutedsl/wrapper.py (2)
  • host_func (550-554)
  • host_func (557-559)
tilelang/jit/adapter/nvrtc/wrapper.py (2)
  • host_func (260-264)
  • host_func (267-269)
src/target/codegen_cutedsl.h (2)
src/target/codegen_py.h (4)
  • tvm (25-247)
  • codegen (26-246)
  • void (86-86)
  • void (92-92)
src/target/codegen_cutedsl.cc (42)
  • PrintFuncDecorator_ (60-63)
  • PrintFuncDecorator_ (60-61)
  • PreFunctionBody_ (65-70)
  • PreFunctionBody_ (65-65)
  • PrintType (141-146)
  • PrintType (141-142)
  • VisitExpr_ (148-152)
  • VisitExpr_ (148-149)
  • VisitExpr_ (154-189)
  • VisitExpr_ (154-155)
  • VisitExpr_ (191-215)
  • VisitExpr_ (191-192)
  • VisitExpr_ (217-224)
  • VisitExpr_ (217-218)
  • VisitExpr_ (225-228)
  • VisitExpr_ (225-226)
  • VisitExpr_ (229-232)
  • VisitExpr_ (229-230)
  • VisitExpr_ (265-629)
  • VisitExpr_ (265-266)
  • VisitExpr_ (631-691)
  • VisitExpr_ (631-632)
  • VisitStmt_ (693-746)
  • VisitStmt_ (693-693)
  • VisitStmt_ (748-801)
  • VisitStmt_ (748-748)
  • VisitStmt_ (803-850)
  • VisitStmt_ (803-803)
  • VisitStmt_ (852-889)
  • VisitStmt_ (852-852)
  • VisitStmt_ (891-915)
  • VisitStmt_ (891-891)
  • VisitStmt_ (917-938)
  • VisitStmt_ (917-917)
  • PrintBinaryExpr_ (1010-1019)
  • PrintBinaryExpr_ (1010-1013)
  • PrintBinaryIntrinsic_ (1021-1029)
  • PrintBinaryIntrinsic_ (1021-1023)
  • PrintCallExtern_ (1031-1157)
  • PrintCallExtern_ (1031-1035)
  • GetBufferRef_ (1188-1247)
  • GetBufferRef_ (1188-1190)
tilelang/jit/adapter/cutedsl/wrapper.py (4)
src/target/codegen_py.h (1)
  • tvm (25-247)
tilelang/jit/adapter/utils.py (3)
  • extract_python_func_declaration (57-85)
  • pythonic_expr (156-286)
  • parse_tma_descriptor_args (389-494)
tilelang/jit/adapter/nvrtc/wrapper.py (4)
  • host_func (260-264)
  • host_func (267-269)
  • _pythonic_expr (271-276)
  • update_lib_code (519-573)
src/tl_templates/cuda/nvrtc_std.h (1)
  • max (126-126)
tilelang/jit/adapter/utils.py (1)
src/target/codegen_py.h (1)
  • tvm (25-247)
src/target/codegen_py.cc (2)
src/target/codegen_cutedsl.cc (2)
  • CheckOutermostParenthesesMatch (23-39)
  • CheckOutermostParenthesesMatch (23-23)
src/target/codegen_py.h (2)
  • PrintStmt_ (113-113)
  • PrintExpr_ (119-121)
tilelang/jit/adapter/__init__.py (1)
tilelang/jit/adapter/cutedsl/adapter.py (1)
  • CuTeDSLKernelAdapter (21-341)
tilelang/jit/adapter/cutedsl/__init__.py (4)
tilelang/jit/adapter/cutedsl/checks.py (1)
  • check_cutedsl_available (33-79)
tilelang/jit/adapter/cutedsl/adapter.py (1)
  • CuTeDSLKernelAdapter (21-341)
tilelang/jit/adapter/cutedsl/wrapper.py (1)
  • TLCuTeDSLSourceWrapper (473-1252)
tilelang/jit/adapter/cutedsl/libgen.py (1)
  • CuTeDSLLibraryGenerator (39-146)
tilelang/jit/execution_backend.py (3)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/env.py (1)
  • use_gemm_v1 (284-290)
tilelang/jit/adapter/cutedsl/checks.py (1)
  • check_cutedsl_available (33-79)
🪛 Cppcheck (2.18.0)
src/target/rt_mod_cutedsl.cc

[error] 62-62: syntax error

(syntaxError)

🪛 Ruff (0.14.8)
tilelang/jit/adapter/cutedsl/checks.py

47-47: Do not catch blind exception: Exception

(BLE001)


55-57: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/wrapper.py

896-896: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/cache/kernel_cache.py

301-301: Do not catch blind exception: Exception

(BLE001)


302-302: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


445-445: Do not catch blind exception: Exception

(BLE001)


446-446: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


452-452: Do not catch blind exception: Exception

(BLE001)


453-453: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

tilelang/contrib/cutedsl/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


20-20: from .mbar import * used; unable to detect undefined names

(F403)


21-21: from .cpasync import * used; unable to detect undefined names

(F403)


22-22: from .gemm_V1 import * used; unable to detect undefined names

(F403)


23-23: from .reduce import * used; unable to detect undefined names

(F403)


24-24: from .ldsm import * used; unable to detect undefined names

(F403)


25-25: from .math import * used; unable to detect undefined names

(F403)


26-26: from .threadblock_swizzle import * used; unable to detect undefined names

(F403)


61-61: bar_sync_ptx may be undefined, or defined from star imports

(F405)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/contrib/cutedsl/mbar.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused function argument: timeout_ns

(ARG001)

tilelang/contrib/cutedsl/cpasync.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Prefer TypeError exception for invalid type

(TRY004)


29-29: Avoid specifying long messages outside the exception class

(TRY003)


35-35: Prefer TypeError exception for invalid type

(TRY004)


35-35: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/utils/target.py

104-104: Consider [*list(target_dict["keys"]), "cutedsl"] instead of concatenation

Replace with [*list(target_dict["keys"]), "cutedsl"]

(RUF005)

tilelang/contrib/cutedsl/ldsm.py

11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/contrib/cutedsl/gemm_V1.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


151-151: Unused function argument: use_wgmma

(ARG001)


152-152: Unused function argument: wg_wait

(ARG001)


192-192: Unused function argument: use_wgmma

(ARG001)


193-193: Unused function argument: wg_wait

(ARG001)


222-222: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


233-233: Unused method argument: stride_A

(ARG002)


233-233: Unused method argument: stride_B

(ARG002)


233-233: Unused method argument: offset_A

(ARG002)


233-233: Unused method argument: offset_B

(ARG002)


381-381: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


392-392: Unused method argument: stride_A

(ARG002)


392-392: Unused method argument: stride_B

(ARG002)


392-392: Unused method argument: offset_A

(ARG002)


392-392: Unused method argument: offset_B

(ARG002)


424-424: Unpacked variable tma_tensor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tilelang/jit/adapter/cutedsl/adapter.py

109-109: Unused class method argument: pass_configs

(ARG003)


197-197: Unused method argument: kernel_only

(ARG002)


253-253: Do not catch blind exception: Exception

(BLE001)


273-275: Avoid specifying long messages outside the exception class

(TRY003)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/cutedsl/wrapper.py

483-499: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


503-514: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


516-529: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


594-594: Avoid specifying long messages outside the exception class

(TRY003)


893-893: Unused method argument: code

(ARG002)


980-980: Consider [*inner_args, "stream: CUstream"] instead of concatenation

Replace with [*inner_args, "stream: CUstream"]

(RUF005)


991-991: Unused method argument: function_name

(ARG002)


1043-1043: Unused method argument: device_index

(ARG002)


1069-1069: Unpacked variable dtype is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1069-1069: Unpacked variable globalAddress is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1237-1237: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/utils.py

41-41: Unused function argument: annotation

(ARG001)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


85-85: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/adapter/cutedsl/__init__.py

6-11: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


16-16: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/adapter/cutedsl/libgen.py

20-20: Starting a process with a partial executable path

(S607)


28-28: Starting a process with a partial executable path

(S607)


78-78: Unused method argument: timeout

(ARG002)


78-78: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


92-92: Unused noqa directive (non-enabled: SIM115)

Remove unused noqa directive

(RUF100)


99-99: subprocess call: check for execution of untrusted input

(S603)


100-108: Starting a process with a partial executable path

(S607)


115-115: Unused noqa directive (non-enabled: SIM115)

Remove unused noqa directive

(RUF100)


136-136: subprocess call: check for execution of untrusted input

(S603)


138-138: Avoid specifying long messages outside the exception class

(TRY003)


146-146: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/execution_backend.py

79-81: Avoid specifying long messages outside the exception class

(TRY003)

🪛 Shellcheck (0.11.0)
maint/scripts/run_local_ci_test.sh

[warning] 21-21: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (35)
maint/scripts/run_local_ci_test.sh (1)

17-20: Two-phase pytest approach is well designed.

The separation of parallel tests (excluding CuTeDSL) from isolated GEMM v1 tests is sound and well documented. The environment variable is correctly scoped to only the CuTeDSL invocation, preventing unintended side effects on other tests. The single-worker flag appropriately prevents xdist contention on the GPU.

Also applies to: 22-22, 24-25

CMakeLists.txt (1)

217-220: LGTM!

The new source files for CuTeDSL codegen and runtime are correctly added to the CUDA build configuration. The placement within the USE_CUDA block is appropriate.

tilelang/jit/adapter/utils.py (2)

114-116: LGTM!

Clean implementation for detecting CuTeDSL targets by checking both the target kind and the presence of the "cutedsl" key.


156-158: LGTM!

Good addition of the floor_div_op parameter. The default "/" preserves backward compatibility for C/C++ codegen, while allowing Python codegen to pass "//" for integer division semantics.

Also applies to: 166-169, 240-240

tilelang/contrib/cutedsl/gemm_V1.py (1)

138-177: LGTM on gemm_sr design decision.

The comment "wgmma doesn't support gemm_sr, only use SM80" is helpful. The unused use_wgmma and wg_wait parameters maintain API consistency with other gemm_* functions.

src/target/codegen_cutedsl.h (2)

21-25: LGTM on LOOP_UNROLL_THRESHOLD.

The threshold constant is well-documented with the rationale for its value. Being const gives it internal linkage, avoiding ODR violations across translation units.


27-94: LGTM on class structure.

The CodeGenTileLangCuTeDSL class is well-organized with clear separation of concerns:

  • Protected override methods for CuTeDSL-specific expression/statement handling
  • Virtual helper methods for specialized operations (vec load/store, buffer access)
  • Private state for mbarrier naming and unroll factors

The inheritance from CodeGenTileLangPY is appropriate for the Python codegen path.

tilelang/contrib/cutedsl/math.py (1)

7-8: No action required on cute.math._math_op usage.

There is no public alternative for divf in CUTLASS CuTe. The _math_op wrapper is the necessary mechanism for integrating MLIR operations (like arith.divf) into the cute.math interface, making this usage appropriate and intentional rather than a dependency on an unstable internal API.

tilelang/jit/adapter/__init__.py (1)

7-7: LGTM!

The import follows the established pattern for re-exporting kernel adapters from this package's __init__.py. The # noqa: F401 directive is consistent with other imports in this file.

tilelang/jit/adapter/nvrtc/adapter.py (1)

78-80: LGTM!

The change from tuple unpacking to dictionary-based access aligns with the updated TLPyWrapper.wrap() return signature. The NVRTC adapter correctly extracts only the fields it needs (host_func and function_names), maintaining backward compatibility while enabling the broader dictionary-based output required for CuTeDSL integration.

tilelang/jit/execution_backend.py (3)

77-90: LGTM! Validation logic is well-structured.

The helper function encapsulates the GEMM v1 and CuTeDSL availability checks cleanly. Converting ImportError to ValueError preserves the function's error contract while keeping the actionable message.


93-107: LGTM!

The auto-resolution logic correctly prioritizes CuTeDSL when the target has the cutedsl key, validating prerequisites before returning. The early return prevents falling through to the generic CUDA default.


123-126: LGTM!

Explicit cutedsl backend selection also validates GEMM v1 and CuTeDSL availability, ensuring consistent prerequisite checking regardless of how the backend is selected.

tilelang/contrib/cutedsl/mbar.py (2)

22-30: LGTM!

The mbarrier_cp_async_arrive implementation correctly uses the @dsl_user_op decorator and properly forwards the loc and ip parameters to the underlying NVVM intrinsic.


33-38: LGTM!

The fence functions correctly delegate to the cutlass arch implementations with appropriate parameters.

src/target/rt_mod_cutedsl.cc (3)

9-39: LGTM!

The ExtractFuncInfo function correctly extracts function metadata including parameter types, grid constant handling, and launch parameter tags.


62-66: LGTM!

The FFI registration follows TVM conventions. The Cppcheck syntax error is a false positive due to the TVM macro expansion.


41-60: Confirmed: This follows the established pattern for "without compile" code paths.

The ("ptx", "ptx") parameters are intentional. The same format is used identically in BuildTileLangCUDAWithoutCompile in rt_mod_cuda.cc (line 102), confirming this is the expected behavior for the "without compile" code generation path. While the generated code is CuTeDSL Python rather than actual PTX, the format strings appear to be placeholders or conventionally used for the intermediate representation handled by post-processing callbacks.

tilelang/jit/adapter/cutedsl/checks.py (1)

67-79: LGTM!

The filesystem-based check for the "cute" submodule is a smart optimization to avoid expensive imports while still verifying CuTe support is available. The check correctly handles both package (cute/) and module (cute.py) cases.

tilelang/jit/kernel.py (3)

320-333: LGTM!

The CuTeDSL adapter instantiation follows the established pattern of other backends, with proper target validation via is_cutedsl_target.


405-416: LGTM!

The database loading path for CuTeDSL correctly delegates to CuTeDSLKernelAdapter.from_database with the expected parameters.


467-476: LGTM!

The source access methods correctly include "cutedsl" in the set of backends that delegate to the adapter for source retrieval.

tilelang/jit/adapter/wrapper.py (3)

885-912: LGTM! Clean dispatch pattern for CuTeDSL target.

The conditional routing based on is_cutedsl_target and the dictionary-based return contract provide a clean interface for the new backend. The structure with host_func, function_names, and launcher-related fields aligns well with the CuTeDSL adapter requirements.

One minor note: the error message on line 896 could be moved to a custom exception class per TRY003, but this is acceptable for now given it's a configuration error that should be rare.


203-205: Good fix for C/CUDA integer division operator.

The comment correctly explains that C/C++ uses / for integer division, not //. This ensures the generated CUDA source compiles correctly.


333-376: Good typo fix: tma_descripter_inittma_descriptor_init.

Renaming the variable from tma_descripter_init to tma_descriptor_init improves code clarity.

src/target/codegen_cutedsl.cc (2)

265-629: Comprehensive intrinsic handling with clear unsupported-op fallbacks.

The extensive handling of TileLang intrinsics with explicit LOG(FATAL) for unsupported operations provides clear error messages. The pattern of delegating to CodeGenTileLangPY::VisitExpr_ for unhandled cases is appropriate for the inheritance hierarchy.


852-889: Well-structured unroll loop handling.

The ForNode visitor correctly handles:

  • Non-unrolled loops by delegating to parent
  • Unroll factor from attributes
  • range_constexpr for small constant loops
  • Full unroll flag when appropriate

This aligns with CuTeDSL's Python-based loop constructs.

tilelang/jit/__init__.py (1)

430-430: LGTM! Consistent ExecutionBackend type extension.

The cutedsl backend is consistently added across all relevant type annotations and function signatures. The type alias ExecutionBackend centralizes the definition, making future updates easier.

src/target/codegen_py.h (1)

32-48: Well-documented public API for CodeGenTileLangPY.

The public interface with AddFunction and Finish methods is clean and well-documented. The Doxygen comments clearly explain the purpose of each method.

tilelang/cache/kernel_cache.py (2)

31-34: LGTM! Clear constant definitions for CuTeDSL artifacts.

The new constants (LAUNCHER_LIB_PATH, LAUNCHER_CPP_PATH, CUTEDSL_CUBIN_PATH) are well-named and align with the CuTeDSL artifact structure.


221-225: Good: Cache path propagation for CuTeDSL cubin saving.

Setting kernel.adapter._cache_path enables the CuTeDSL adapter to save the cubin after first execution. The conditional check for execution_backend == "cutedsl" is appropriate.

tilelang/contrib/cutedsl/cpasync.py (3)

39-43: Conditional cp.async wrapper looks fine.
[cute.jit] wrapper is minimal and keeps the “conditional” behavior localized.


149-168: TMA store arrive/wait wrappers look consistent with the intended PTX mapping.


200-212: Prefetch helper looks OK; consider arch gating only if NVVM op is arch-specific.

tilelang/jit/adapter/cutedsl/adapter.py (1)

338-341: prim_func accessor is clean and keeps adapter logic decoupled from module structure.

@LeiWang1999
Copy link
Member

Awesome Contributions!

@LeiWang1999 LeiWang1999 self-requested a review December 13, 2025 13:36
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

♻️ Duplicate comments (3)
src/target/codegen_py.cc (3)

81-102: Collision check needed for name_hint fallback path.

As noted in previous reviews, the name_hint fallback at line 97 needs the same collision validation as the global_symbol path. This prevents duplicate function names when multiple functions share the same name_hint.


188-217: Missing break statements and type coverage.

As noted in previous reviews, the switch cases lack break statements causing fallthrough, and float16/default cases are missing. These must be fixed to generate correct Python type annotations.


522-536: Inconsistent expression printing in for-loop bounds.

As noted in previous reviews, lines 529-530 stream expressions directly instead of using PrintExpr_, which can produce TVM IR representation instead of valid Python syntax.

🧹 Nitpick comments (2)
src/target/codegen_cutedsl.cc (2)

1067-1080: Add validation for unsupported nested template syntax.

The comment states "Nested '<' is not supported," but there's no validation to detect or reject this case. If nested templates are encountered, the code will silently produce incorrect output.

Add a check to detect and fail fast on unsupported input:

   // Replace "<...>" with "(...)". Nested "<" is not supported
   {
     auto pos_left = global_symbol_str.find('<');
     while (pos_left != std::string::npos) {
       auto pos_right = global_symbol_str.find('>', pos_left + 1);
       if (pos_right != std::string::npos) {
         auto args =
             global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
+        // Check for nested '<' which is not supported
+        if (args.find('<') != std::string::npos) {
+          LOG(FATAL) << "Nested template arguments are not supported: " << global_symbol_str;
+        }
         ReplaceAll(args, "true", "True");
         ReplaceAll(args, "false", "False");

778-778: Use LOG(FATAL) for consistency with other error paths.

Line 778 uses ICHECK(false) to report an unsupported scope, while lines 758, 760, and 795 use LOG(FATAL) for the same purpose.

For consistency with the rest of this file, consider:

-      ICHECK(false) << "Unsupported scope: " << scope;
+      LOG(FATAL) << "Unsupported scope: " << scope;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5617be1 and 98fd67a.

📒 Files selected for processing (2)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_py.cc (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • src/target/codegen_py.cc
🧬 Code graph analysis (2)
src/target/codegen_cutedsl.cc (1)
src/target/codegen_py.cc (4)
  • CheckOutermostParenthesesMatch (22-38)
  • CheckOutermostParenthesesMatch (22-22)
  • RemoveOutermostParentheses (40-46)
  • RemoveOutermostParentheses (40-40)
src/target/codegen_py.cc (1)
src/target/codegen_py.h (2)
  • PrintStmt_ (113-113)
  • PrintExpr_ (119-121)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (10)
src/target/codegen_py.cc (10)

22-47: LGTM!

The parentheses matching and removal logic is correct. The algorithm properly tracks nesting depth and validates that the outermost parentheses enclose the entire expression.


49-79: LGTM!

The function registration and code generation flow is well-structured. State management and scope handling are appropriate for a code generator.


104-181: LGTM!

The function state initialization, signature generation, and keyword reservation are well-implemented. The keyword list is comprehensive and includes both Python keywords and type names that could cause conflicts.


183-186: LGTM!

The SSA assignment printing correctly removes unnecessary outermost parentheses for cleaner output.


219-255: LGTM!

The expression visitors for variables, integers, floats, and casts are correctly implemented. Note that the float16 handling at lines 247-250 depends on the fix for PrintType flagged in previous reviews.


262-366: LGTM!

The cast, arithmetic, logical, comparison, select, and ramp expression visitors are correctly implemented and generate appropriate Python syntax.


368-451: LGTM!

The call expression visitor comprehensively handles built-in operations, external calls, and function references. The fallthrough to LOG(FATAL) ensures unhandled operations are caught during development.


453-520: LGTM!

Buffer load/store, let statements, and allocation are correctly implemented. The Python list allocation [None] * size is appropriate for creating mutable storage.


538-579: LGTM!

The while, if/else, sequence, and evaluate statement visitors correctly generate Python control flow structures.


592-680: LGTM!

The utility methods for type casting, binary expressions, external calls, buffer references, and handle type tracking are correctly implemented and provide appropriate validation.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
requirements-test-cuda.txt (1)

10-11: Verify nvidia-cutclass-dsl version and consider clarifying GEMM v1 requirement.

The nvidia-cutlass-dsl version 4.3.1 is compatible with CUDA 12.x, so the dependency constraint is appropriate. However, per the PR objectives, CuTeDSL currently requires GEMM v1, and users must set TILELANG_USE_GEMM_V1=1 before importing tilelang. Consider updating the comment to document this prerequisite for clarity.

Consider revising the comment to include the GEMM v1 requirement:

-# CuTeDSL (CUTLASS Python DSL with CuTe support)
+# CuTeDSL (CUTLASS Python DSL with CuTe support; requires GEMM v1 - set TILELANG_USE_GEMM_V1=1)

This would help users understand the prerequisite when reviewing test requirements.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1a94479 and cdd8b61.

📒 Files selected for processing (1)
  • requirements-test-cuda.txt (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist

Copy link
Collaborator

@kurisu6912 kurisu6912 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lucifer1004 for your great contribution to tilelang!

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/target/codegen_cutedsl.cc (1)

843-843: Define the missing constant LOOP_UNROLL_THRESHOLD.

The constant LOOP_UNROLL_THRESHOLD is referenced but not defined in this file or in any included headers. This will cause a compilation error.

Add the constant definition at the top of the file or include the appropriate header:

 namespace tvm {
 namespace codegen {
 namespace {
+constexpr int LOOP_UNROLL_THRESHOLD = 16;  // Adjust value as appropriate

 void ReplaceAll(std::string &str, const std::string &from,
🧹 Nitpick comments (1)
src/target/codegen_cutedsl.cc (1)

25-32: Consider extracting ReplaceAll to shared utilities.

The ReplaceAll function is a general-purpose string utility that could be shared across codegen modules.

If this pattern is used in multiple codegen files, consider adding it to codegen_utils.h/cc alongside the parentheses utilities for consistency.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cdd8b61 and 4154b79.

📒 Files selected for processing (5)
  • CMakeLists.txt (1 hunks)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_py.cc (1 hunks)
  • src/target/codegen_utils.cc (1 hunks)
  • src/target/codegen_utils.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • CMakeLists.txt
  • src/target/codegen_py.cc
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cutedsl.cc
🧬 Code graph analysis (1)
src/target/codegen_utils.h (1)
src/target/codegen_utils.cc (4)
  • CheckOutermostParenthesesMatch (11-27)
  • CheckOutermostParenthesesMatch (11-11)
  • RemoveOutermostParentheses (29-35)
  • RemoveOutermostParentheses (29-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (7)
src/target/codegen_cutedsl.cc (4)

555-555: Verify fastmath configuration from pass config.

The fastmath=True parameter is hardcoded. Consider checking the pass configuration to allow users to control this behavior, as suggested in previous reviews.

Based on past review comments, this should respect the pass configuration for fastmath settings.


622-624: Clarify the logic for checking buffer reference suffix.

The code checks if ref.back() == ')' to determine whether to append .load(). Based on the comment block at lines 1155-1163, this is intentional: only form (4) ending with ')' represents a full tensor that needs .load(), while other forms represent scalar accesses that don't.

This logic correctly distinguishes between tensor forms (requiring .load()) and scalar accesses. The comment block at lines 1155-1163 documents this design clearly.


156-156: Scientific notation is appropriate for Python codegen.

The use of std::scientific instead of hexfloat is intentional for CuTeDSL, which generates Python code. Python doesn't support C-style hexfloat literals, so decimal scientific notation is the correct choice here.

The comment at line 154 explicitly states "For CuTeDSL, use Python decimal format instead of C hexfloat", which aligns with the implementation.


6-6: Good: Utility duplication resolved.

The inclusion of codegen_utils.h addresses the previous concern about duplicated utility functions (CheckOutermostParenthesesMatch and RemoveOutermostParentheses).

src/target/codegen_utils.h (1)

1-34: LGTM: Well-structured utility header.

The header file follows best practices with proper include guards, clear documentation, and a clean API. The functions address the code duplication concern raised in previous reviews.

src/target/codegen_utils.cc (2)

11-27: LGTM: Correct parentheses matching logic.

The implementation correctly validates that the first and last characters are matching parentheses that enclose the entire string. The balance counter approach ensures that intermediate closing parentheses don't prematurely match.


29-35: LGTM: Safe parentheses removal.

The function safely removes outermost parentheses only when they match, preventing incorrect transformations.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

♻️ Duplicate comments (4)
src/target/codegen_py.cc (3)

47-68: Prevent duplicate Python def names in the name_hint path (add collision check).

global_symbol path checks ContainsName, but the name_hint fallback reserves blindly (Line 63-65). If two GlobalVars share the same name_hint, you’ll emit duplicate function definitions.

   auto function_name = [&]() -> ffi::String {
     if (auto global_symbol =
             func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
       auto name = global_symbol.value();
       ICHECK(!func_name_supply_->ContainsName(name))
           << "Function " << gvar << " must use global symbol " << name
           << ", but this name has already been used.";
       func_name_supply_->ReserveName(name);
       return name;
     } else {
-      func_name_supply_->ReserveName(gvar->name_hint);
-      return gvar->name_hint;
+      auto name = gvar->name_hint;
+      ICHECK(!func_name_supply_->ContainsName(name))
+          << "Function " << gvar << " must use name_hint " << name
+          << ", but this name has already been used.";
+      func_name_supply_->ReserveName(name);
+      return name;
     }
   }();

235-238: Escape Python string literals in StringImm emission.

Raw op->value can contain ", \, newlines, etc., generating invalid Python.

 void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op,
                                    std::ostream &os) { // NOLINT(*)
-  os << "\"" << op->value << "\"";
+  os << "\"";
+  for (char c : std::string(op->value)) {
+    switch (c) {
+      case '\"': os << "\\\""; break;
+      case '\\': os << "\\\\"; break;
+      case '\n': os << "\\n"; break;
+      case '\r': os << "\\r"; break;
+      case '\t': os << "\\t"; break;
+      default: os << c; break;
+    }
+  }
+  os << "\"";
 }

561-570: Escape assert message strings (reuse the same escaping as StringImm).

Same escaping bug as VisitExpr_(StringImmNode).

 void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) {
   std::string cond = PrintExpr_(op->condition);
   PrintIndent();
   if (const auto *str = op->message.as<StringImmNode>()) {
-    stream << "assert " << cond << ", \"" << str->value << "\"\n";
+    stream << "assert " << cond << ", \"";
+    for (char c : std::string(str->value)) {
+      switch (c) {
+        case '\"': stream << "\\\""; break;
+        case '\\': stream << "\\\\"; break;
+        case '\n': stream << "\\n"; break;
+        case '\r': stream << "\\r"; break;
+        case '\t': stream << "\\t"; break;
+        default: stream << c; break;
+      }
+    }
+    stream << "\"\n";
   } else {
     stream << "assert " << cond << "\n";
   }
   PrintStmt_(op->body);
 }
src/target/codegen_cutedsl.cc (1)

241-605: Keep an eye on VisitExpr_(CallNode*) size/complexity (already flagged previously).
This is still a large monolithic dispatcher; any further additions will get risky to maintain and test.

🧹 Nitpick comments (1)
src/target/codegen_cutedsl.cc (1)

47-113: DTypeToString: t.is_scalar() vs t.is_void() is inconsistent—verify TVM DataType semantics.
You ICHECK(t.is_scalar()) and then handle t.is_void(). If void is not considered scalar, this is unreachable; if it is, OK. Consider either removing the scalar check for void, or rewriting to make intent explicit.

Also applies to: 116-121

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c09af46 and f3362bc.

📒 Files selected for processing (4)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_py.cc (1 hunks)
  • src/target/codegen_utils.cc (1 hunks)
  • src/target/codegen_utils.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_utils.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/target/codegen_cutedsl.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (9)
tilelang/jit/adapter/cutedsl/adapter.py (3)

146-153: from_database() must set adapter.libpath (and should align the “source” field it reads) to avoid runtime AttributeError.
Right now _save_cubin_to_cache_if_needed() reads self.libpath (Line 241) but from_database() never sets it; also open(kernel_lib_path).read() is stored in kernel_global_source (Line 151) while get_kernel_source() returns device_kernel_source (Line 211), so this looks inconsistent.

@@
         adapter.lib_generator.assign_compile_flags(compile_flags)
         adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
-        adapter.kernel_global_source = open(kernel_lib_path).read()
+        adapter.libpath = kernel_lib_path
+        adapter.device_kernel_source = open(kernel_lib_path).read()
         adapter.pymodule = adapter.lib_generator.pymodule

192-200: Guard buffer.strides is None in _process_dynamic_symbolic().
TVM can represent compact buffers with strides=None; iterating it will crash.

@@
         for i, param in enumerate(params):
             if param not in buffer_map:
                 continue
             buffer = buffer_map[param]
+            if buffer.strides is None:
+                continue
             for j, stride in enumerate(buffer.strides):
                 if isinstance(stride, tir.Var):
                     unique_push_back(stride, (1, i, j))

262-333: Argument marshalling is unsafe: scalar params + PrimFunc param-index mapping can index the wrong ins[] (or go OOB).
dynamic_symbolic_map stores PrimFunc param indices (Line 165-166), but _wrap_forward_from_prebuild_lib() indexes into ins[...] (inputs-only) at Lines 293-317; this breaks as soon as you have outputs, scalars, or non-tensor params.

A safer pattern is: materialize param_values in PrimFunc-param order, allocate outputs into it, and resolve dynamic vars against param_values[...] (ensuring the referenced value is a tensor).

@@
-    def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
+    def _wrap_forward_from_prebuild_lib(self, *ins: Any, stream: int | None = None):
@@
-        ins_idx = 0
-        args = []
+        ins_idx = 0
+        # Map provided inputs into PrimFunc-param order (excluding allocated outputs).
+        param_values: list[Any] = [None] * len(self.params)
+        for i in range(len(self.params)):
+            if i in self.result_idx:
+                continue
+            param_values[i] = ins[ins_idx]
+            ins_idx += 1
+
+        first_tensor = next((v for v in param_values if isinstance(v, torch.Tensor)), None)
+        if first_tensor is None:
+            raise ValueError("Expected at least one torch.Tensor argument to infer device")
+
+        args: list[Any] = []
@@
             if i in self.result_idx:
                 dtype = self.param_dtypes[i]
                 shape = []
                 # Now working with native Python list, no FFI calls needed
                 for s in self.param_shapes[i]:
                     if isinstance(s, tir.Var):
-                        ref_id, ref_tensor_idx, ref_dim_idx = self.dynamic_symbolic_map[s]
+                        ref_id, ref_param_idx, ref_dim_idx = self.dynamic_symbolic_map[s]
+                        ref_val = param_values[ref_param_idx]
+                        if not isinstance(ref_val, torch.Tensor):
+                            raise TypeError(
+                                f"Dynamic var {s} refers to non-tensor param at index {ref_param_idx}"
+                            )
                         if ref_id == 0:
-                            shape.append(ins[ref_tensor_idx].shape[ref_dim_idx])
+                            shape.append(ref_val.shape[ref_dim_idx])
                         elif ref_id == 1:
                             # Stride vars are not expected in output shapes, but handle defensively.
-                            shape.append(ins[ref_tensor_idx].stride()[ref_dim_idx])
+                            shape.append(ref_val.stride()[ref_dim_idx])
                         else:
                             raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")
                     else:  # Already converted to Python int during initialization
                         shape.append(s)
-                device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
-                tensor = torch.empty(*shape, dtype=dtype, device=device)
+                tensor = torch.empty(*shape, dtype=dtype, device=first_tensor.device)
+                param_values[i] = tensor
             else:
-                tensor = ins[ins_idx]
-                ins_idx += 1
+                tensor = param_values[i]
             args.append(tensor)
@@
         for sym in self.dynamic_symbolic_order:
-            ref_id, buffer_idx, dim_idx = self.dynamic_symbolic_map[sym]
+            ref_id, ref_param_idx, dim_idx = self.dynamic_symbolic_map[sym]
+            ref_val = param_values[ref_param_idx]
+            if not isinstance(ref_val, torch.Tensor):
+                raise TypeError(f"Dynamic symbolic var {sym} refers to non-tensor param at index {ref_param_idx}")
             if ref_id == 0:
-                args.append(ins[buffer_idx].shape[dim_idx])
+                args.append(ref_val.shape[dim_idx])
             elif ref_id == 1:
-                args.append(ins[buffer_idx].stride()[dim_idx])
+                args.append(ref_val.stride()[dim_idx])
             else:
                 raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")
src/target/codegen_cutedsl.cc (6)

24-31: Guard ReplaceAll() against empty from (infinite loop).

 void ReplaceAll(std::string &str, const std::string &from,
                 const std::string &to) {
+  ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty";
   auto pos = str.find(from);
   while (pos != std::string::npos) {
     str.replace(pos, from.size(), to);
     pos = str.find(from, pos + to.size());
   }
 }

193-200: Integer division emission uses // (Python floor), which can diverge from TIR semantics for negatives.

If negative operands are possible, either (a) emit a trunc-toward-zero form, or (b) enforce/ICHECK non-negativity for both operands before using //.


375-426: eviction_policy_names_ indexing and need_reduce assume IntImm + in-range; add validation to avoid OOB/null deref crashes.

At minimum, validate op->args.back() / op->args[...-2] are IntImmNode and bounds-check the eviction policy index before subscripting.


607-667: Vector load/store: guard lane math + don’t assume RampNode in non-contiguous path.

Please add:

  • ICHECK_GE(value_lanes, element_dtype.lanes())
  • ICHECK_EQ(value_lanes % element_dtype.lanes(), 0)
  • Only attempt arith::ramp(..., value_lanes / element_dtype.lanes()) when the quotient is > 0
  • In non-contiguous path, ICHECK(index.as<RampNode>()) << "..."; once before the loop (not per-iteration)

Also applies to: 669-722


737-742: shared.dyn hardcodes alignment/div_by=1024; make dtype-aware to avoid incorrect alignment assumptions.

Consider using op->dtype.bytes() * op->dtype.lanes() (like the shared path) for both alignment= and div_by=.


1213-1222: GetBufferRef_: buffer_size silently truncates if bitwidths aren’t divisible; fail fast.

   } else {
-    int buffer_size =
-        t.bits() * t.lanes() /
-        (buffer_element_dtype.bits() * buffer_element_dtype.lanes());
+    const int num = t.bits() * t.lanes();
+    const int den =
+        buffer_element_dtype.bits() * buffer_element_dtype.lanes();
+    ICHECK_GT(den, 0);
+    ICHECK_EQ(num % den, 0)
+        << "Cannot form view: requested dtype " << t
+        << " is not divisible by buffer dtype " << buffer_element_dtype;
+    const int buffer_size = num / den;
 
     std::ostringstream os;
     os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str << ", ("
        << buffer_size << ",), div_by=" << buffer_size << ")";
     return os.str();
   }
🧹 Nitpick comments (1)
tilelang/jit/adapter/cutedsl/adapter.py (1)

203-212: Consider keeping get_kernel_source() behavior consistent with BaseKernelAdapter.get_kernel_source().
Base adapter supports kernel_only=False; here kernel_only is unused and you always return device_kernel_source. If intentional, consider dropping the param or mirroring the base behavior (e.g., returning host+device when available).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f3362bc and 702c1b1.

📒 Files selected for processing (2)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/target/codegen_cutedsl.cc
🧬 Code graph analysis (2)
src/target/codegen_cutedsl.cc (2)
src/target/codegen_py.h (2)
  • PrintExpr_ (119-121)
  • PrintStmt_ (113-113)
src/target/codegen_utils.cc (2)
  • RemoveOutermostParentheses (29-35)
  • RemoveOutermostParentheses (29-29)
tilelang/jit/adapter/cutedsl/adapter.py (3)
tilelang/utils/target.py (1)
  • determine_target (63-133)
tilelang/jit/adapter/base.py (1)
  • _post_init (95-96)
tilelang/jit/adapter/nvrtc/adapter.py (4)
  • _process_dynamic_symbolic (155-175)
  • prim_func (265-267)
  • get_kernel_source (177-188)
  • _convert_torch_func (254-262)
🪛 Ruff (0.14.8)
tilelang/jit/adapter/cutedsl/adapter.py

48-48: Avoid specifying long messages outside the exception class

(TRY003)


112-112: Unused class method argument: pass_configs

(ARG003)


124-124: Avoid specifying long messages outside the exception class

(TRY003)


203-203: Unused method argument: kernel_only

(ARG002)


259-259: Do not catch blind exception: Exception

(BLE001)


279-281: Avoid specifying long messages outside the exception class

(TRY003)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


318-318: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build SDist

@lucifer1004 lucifer1004 marked this pull request as draft December 15, 2025 13:21
@lucifer1004

This comment was marked as outdated.

@lucifer1004 lucifer1004 marked this pull request as ready for review December 15, 2025 16:24
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/target/codegen_cutedsl.cc (1)

671-734: Vector BufferStore path still lacks the lane‑count safety checks added for loads

The vector load path defends against invalid lane relationships:

if (value_lanes == element_dtype.lanes()) {
  ...
} else {
  ICHECK_GE(value_lanes, element_dtype.lanes())
      << "Unsupported load/store: value lanes < buffer element lanes";
  ...
  if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()).Match(index)) {
    ...
  } else {
    const RampNode* ramp = index.as<RampNode>();
    ICHECK(ramp) << "Expected Ramp index for vectorized non-contiguous access";
    ...
  }
}

The store path, however, still does:

int value_lanes = value_dtype.lanes();
...
} else {
  bool is_contiguous = false;
  arith::PVar<PrimExpr> base;
  if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
          .Match(index_expr)) {
    is_contiguous = true;
  }
  ...
  for (int i = 0; i < value_lanes; ++i) {
    const RampNode *ramp = index_expr.as<RampNode>();
    ICHECK(ramp);
    ...
  }
}

If malformed IR ever produces value_lanes < element_dtype.lanes(), value_lanes / element_dtype.lanes() becomes 0 and the pattern match degenerates; the non‑contiguous branch also assumes index_expr is a RampNode without checking, which can crash.

Mirroring the load‑side checks would make this more robust:

   } else {
-    bool is_contiguous = false;
-    arith::PVar<PrimExpr> base;
-    if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
-            .Match(index_expr)) {
-      is_contiguous = true;
-    }
+    ICHECK_GE(value_lanes, element_dtype.lanes())
+        << "Unsupported store: value lanes < buffer element lanes";
+    bool is_contiguous = false;
+    arith::PVar<PrimExpr> base;
+    const int stride_count = value_lanes / element_dtype.lanes();
+    if (stride_count > 0 &&
+        arith::ramp(base, 1, stride_count).Match(index_expr)) {
+      is_contiguous = true;
+    }
@@
-    } else {
+    } else {
       ICHECK(element_dtype.is_scalar())
           << "buffer element type for non-contiguous store must be scalar "
              "currently";
@@
-      for (int i = 0; i < value_lanes; ++i) {
-        const RampNode *ramp = index_expr.as<RampNode>();
-        ICHECK(ramp);
+      const RampNode *ramp = index_expr.as<RampNode>();
+      ICHECK(ramp)
+          << "Expected Ramp index for vectorized non-contiguous store";
+      for (int i = 0; i < value_lanes; ++i) {
         auto idx_expr =
             arith::Analyzer().Simplify(ramp->base + ramp->stride * i);

This keeps load/store behavior symmetric and avoids silent misbehavior on bad IR.

Also applies to: 736-789

🧹 Nitpick comments (4)
tilelang/contrib/cutedsl/__init__.py (4)

7-15: Star imports and noqa: F401 pragmas: fine but noisy for tooling

The from ... import * plus # noqa: F401 pattern works here as an aggregator, but Ruff is flagging unused noqa directives and can’t see that bar_sync_ptx comes from reduce.py. If you want cleaner static-analysis output, consider:

  • Dropping the # noqa: F401 comments now that F401 isn’t enabled, and/or
  • Replacing import * with explicit re‑exports and an __all__ in this package.

Given this is a convenience facade, it’s reasonable to leave as‑is if you’re okay with the Ruff noise.

Also applies to: 20-26, 61-63


45-49: Tighten make_tensor_at_offset around div_by and as_numeric usage

This helper is central to all tl.make_tensor_at_offset(...) codegen, so the div_by contract matters:

  • You import as_numeric from cutlass.base_dsl.typing but call cutlass.as_numeric(offset) instead. If these are aliases, using the imported as_numeric makes the dependency clearer and avoids relying on a top‑level re‑export.
  • It may be worth asserting div_by >= 1 (and maybe documenting that it’s in “elements” units) to catch accidental zero/negative values before they hit cute.assume(..., divby=div_by).

Both changes are small but make the semantics of the div_by path easier to reason about, especially since there are open questions around div_by usage elsewhere in this PR.


51-59: Consider guarding shuffle_elect against non‑warp‑multiple thread_extent

The implementation assumes thread_extent is 0 or a multiple of 32 (thread_extent // 32 is used as the number of warps per group). If someone accidentally passes a non‑multiple (e.g. 48), this will compute thread_extent // 32 == 1 but the semantics become unclear.

A light defensive check such as:

 def shuffle_elect(thread_extent):
-    # thread_extent is the number of threads of a warpgroup
+    # thread_extent is the number of threads of a warpgroup
+    assert thread_extent == 0 or thread_extent % 32 == 0, \
+        "shuffle_elect expects thread_extent to be 0 or a multiple of 32"

would make misuse much easier to diagnose, and it matches the intended warp‑group semantics from the C++ intrinsics.


66-100: pack_half2 / AtomicAdd implementation looks correct; only minor polish possible

  • pack_half2 correctly bitcasts fp16 values to i16 and uses a single mov.b32 to pack them. Since it’s a pure bit‑manipulation op, you could set has_side_effects=False on the inline asm to give NVVM more freedom to optimize, but functionally this is fine.
  • AtomicAdd correctly routes to nvvm.atomicrmw for Float32/Int32 with relaxed GPU scope and returns the updated value wrapped in the pointer’s dtype. The explicit ValueError for unsupported dtypes is good fail‑fast behavior; if you expect bf16 or other integer widths in the near future, this is the obvious extension point.

No blocking issues here; these helpers look ready for use as‑is.

Also applies to: 103-127

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9021d8c and f996239.

📒 Files selected for processing (3)
  • src/target/codegen_cutedsl.cc (1 hunks)
  • src/target/codegen_cutedsl.h (1 hunks)
  • tilelang/contrib/cutedsl/__init__.py (1 hunks)
🧰 Additional context used
🧠 Learnings (7)
📚 Learning: 2025-12-15T07:23:46.620Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:46.620Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.

Applied to files:

  • tilelang/contrib/cutedsl/__init__.py
  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/contrib/cutedsl/__init__.py
  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-12-15T07:48:36.835Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:36.835Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.

Applied to files:

  • tilelang/contrib/cutedsl/__init__.py
  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/target/codegen_cutedsl.cc
📚 Learning: 2025-12-15T08:56:18.649Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/reduce.py:161-184
Timestamp: 2025-12-15T08:56:18.649Z
Learning: In Tilelang's CUDA backend and CuTeDSL backend, barrier IDs 1 and 2 are reserved for internal use (such as in AllReduce operations). User-defined barriers should use IDs starting from 3 to avoid synchronization conflicts.

Applied to files:

  • src/target/codegen_cutedsl.cc
🧬 Code graph analysis (2)
tilelang/contrib/cutedsl/__init__.py (5)
tilelang/language/builtin.py (1)
  • sync_threads (702-709)
src/tl_templates/cuda/intrin.h (2)
  • warpgroup_reg_alloc (124-126)
  • warpgroup_reg_dealloc (128-130)
tilelang/language/proxy.py (1)
  • make_tensor (278-279)
tilelang/tileop/gemm/gemm_base.py (1)
  • mbar (127-128)
tilelang/contrib/cutedsl/reduce.py (1)
  • bar_sync_ptx (97-108)
src/target/codegen_cutedsl.cc (3)
tilelang/contrib/cutedsl/cpasync.py (1)
  • tma_load (60-107)
tilelang/language/math_intrinsics.py (1)
  • ieee_add (149-172)
tilelang/language/reduce.py (3)
  • warp_reduce_sum (332-345)
  • warp_reduce_max (348-361)
  • warp_reduce_min (364-377)
🪛 Ruff (0.14.8)
tilelang/contrib/cutedsl/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


20-20: from .mbar import * used; unable to detect undefined names

(F403)


21-21: from .cpasync import * used; unable to detect undefined names

(F403)


22-22: from .gemm_V1 import * used; unable to detect undefined names

(F403)


23-23: from .reduce import * used; unable to detect undefined names

(F403)


24-24: from .ldsm import * used; unable to detect undefined names

(F403)


25-25: from .math import * used; unable to detect undefined names

(F403)


26-26: from .threadblock_swizzle import * used; unable to detect undefined names

(F403)


62-62: bar_sync_ptx may be undefined, or defined from star imports

(F405)


127-127: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
🔇 Additional comments (7)
src/target/codegen_cutedsl.cc (6)

25-39: Loop unroll threshold and cutlass.range_constexpr usage look good

Defining LOOP_UNROLL_THRESHOLD = 64 in the anonymous namespace and using it to decide between cutlass.range_constexpr vs cutlass.range(..., unroll_full=True) is a sensible trade‑off: it avoids the CuTe DSL compile‑time explosion for long static loops while still fully unrolling shorter ones.

The ForNode handling correctly:

  • Simplifies min/extent with arith::Analyzer,
  • Respects explicit pragma_unroll_factor when present,
  • Falls back to the Python codegen for non‑unrolled loops.

No issues here from a correctness or maintainability standpoint.

Also applies to: 894-931


81-147: DType translation and float literal emission are robust

  • DTypeToString covers the supported CuTeDSL scalar dtypes (floats, the various float8/float6/float4 encodings, ints/uints, bool) and fails fast with LOG(FATAL) for anything unsupported, which is appropriate at codegen time.
  • PrintType enforces scalar‑only printing for CuTeDSL types, catching accidental vector types early.
  • VisitExpr_(FloatImmNode*) now emits infinities/NaNs via float('inf') / float('nan') and finite values via float.fromhex(...) with std::hexfloat, which gives full‑precision round‑trip behavior and matches the earlier review feedback.

This section looks correct and matches CuTeDSL expectations.

Also applies to: 150-156, 163-199


227-247: Integer div / fast‑math mapping is acceptable under current assumptions

The combination of:

  • CanonicalizeFastmathFunctionName_ mapping common C math names to tl.* intrinsics, and
  • VisitExpr_(DivNode*) using:
    • // for integer types, and
    • tl.divf(..., fastmath=True) vs tl.divf(...) depending on enable_fastmath_

is consistent with the rest of the backend and PassContext‑controlled fast‑math policy.

Given earlier clarification that integer division here is only used for non‑negative indices, Python // semantics are fine; if that ever changes, we’d need a truncating‑toward‑zero path. For now, this looks good.

Also applies to: 52-67


280-491: CallNode CuTeDSL lowering is comprehensive; tl_shuffle_elect wiring matches intent

The CallNode visitor:

  • Cleanly handles the cp.async/mbarrier/TMA family with argument validation (e.g., eviction policy checks, tma_store IntImm guards) and clear LOG(FATAL) paths for unsupported variants.
  • Routes tl.tl_gemm through PrintCallExtern_ and adds the correct tl.gemm_* keyword arguments.
  • Maps tl.pack_b16 to tl.pack_half2 and uses tl.shuffle_elect in conjunction with IfThenElse’s with cute.arch.elect_one(): to mimic the C++ tl_shuffle_elect semantics (one elected lane per warp group).

The number of LOG(FATAL) branches is expected at this stage of CuTeDSL support and will make missing intrinsics obvious at codegen time. No functional issues stand out in this block.

Also applies to: 611-612, 933-947


791-839: Dynamic/shared allocation and buffer pointer/ref helpers look consistent

  • The AllocateNode handler’s shared.dyn branch emits a “fake” 1‑element tensor via tl.make_tensor(tl.get_dyn_smem(dtype, alignment=1024), (1,)), which matches the intent documented in the comment (“no bound check; just get its iterator”).
  • GetBufferPtr_ and GetBufferRef_ consistently use .iterator and, when handle types don’t match, wrap with tl.recast_ptr(..., dtype=...), before building tl.make_tensor_at_offset(...) views. The special‑case for scalar element types in local/shared/shared.dyn/shared.barrier scopes falling back to vid[index] matches CuTeDSL’s layout expectations.

Given the earlier clarification that tl.make_rmem_tensor and related APIs accept both integer and tuple shapes, the generated (buffer_size,) and (1,) shapes are valid. No issues here.

Also applies to: 1209-1297


1314-1352: PrintStorageSync_ improvements are good; only string‑literal assumption remains

PrintStorageSync_ now:

  • Distinguishes "warp" vs "shared"/"shared.dyn" vs "global",
  • Validates barrier_id / thread_count as IntImmNode and surfaces clear error messages, and
  • Fails fast on unknown scopes with LOG(FATAL).

The remaining assumption is that args[0] is always a StringImmNode; if TIR ever relaxes that, you’d need a defensive check there as well, but as long as tvm_storage_sync is only emitted with literal scopes this is fine.

src/target/codegen_cutedsl.h (1)

21-97: Header/API surface for CodeGenTileLangCuTeDSL matches the implementation

The header cleanly exposes:

  • The CuTeDSL‑specific overrides for expr/stmt visitors and type printing,
  • The helper hooks (PrintVec*, PrintBinary*, PrintCallExtern_, GetBufferPtr_/GetBufferRef_, BindThreadIndex_, PrintStorageSync_), and
  • Internal state (mbarrier name, unroll factors, eviction policy table, fast‑math flag).

Signatures and visibility match the .cc definitions, and marking the class final is appropriate given it’s not meant for further extension. No changes needed here.

@lucifer1004
Copy link
Collaborator Author

Ready for merge now.

@LeiWang1999
Copy link
Member

@codex review

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/jit/adapter/cutedsl/adapter.py (2)

148-153: Consider assigning pass_configs in from_database.

The from_database classmethod receives pass_configs but doesn't assign it to the adapter instance (unlike the constructor at line 74). While the library generator receives it via assign_pass_configs (line 149), storing it on the adapter would maintain consistency with other adapters and support future use cases.

Consider adding:

adapter.pass_configs = pass_configs
 adapter.lib_generator = CuTeDSLLibraryGenerator(adapter.target, adapter.verbose)
+adapter.pass_configs = pass_configs
 adapter.lib_generator.assign_compile_flags(compile_flags)

206-214: Document why kernel_only parameter is unused.

The kernel_only parameter exists for interface consistency with other adapters but isn't used in the CuTeDSL implementation since there's only device source (no separate host source in the same sense as other backends).

Consider adding a brief comment or docstring note explaining this behavior:

def get_kernel_source(self, kernel_only: bool = True) -> str | None:
    """Get the CUDA kernel source code.
    
    Note: kernel_only parameter is accepted for interface consistency
    but ignored since CuTeDSL only provides device source.
    """
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f996239 and 32fe9b4.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml (1 hunks)
  • CMakeLists.txt (1 hunks)
  • tilelang/cache/kernel_cache.py (10 hunks)
  • tilelang/engine/lower.py (2 hunks)
  • tilelang/jit/__init__.py (7 hunks)
  • tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
  • tilelang/jit/execution_backend.py (4 hunks)
  • tilelang/jit/kernel.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tilelang/engine/lower.py
  • CMakeLists.txt
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • .github/workflows/ci.yml
🧬 Code graph analysis (3)
tilelang/jit/kernel.py (2)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/jit/adapter/cutedsl/adapter.py (2)
  • from_database (102-156)
  • get_kernel_source (206-214)
tilelang/jit/adapter/nvrtc/adapter.py (1)
tilelang/jit/adapter/nvrtc/wrapper.py (2)
  • host_func (260-264)
  • host_func (267-269)
tilelang/jit/execution_backend.py (3)
tilelang/jit/adapter/utils.py (1)
  • is_cutedsl_target (114-115)
tilelang/env.py (1)
  • use_gemm_v1 (284-290)
tilelang/jit/adapter/cutedsl/checks.py (1)
  • check_cutedsl_available (33-79)
🪛 Ruff (0.14.8)
tilelang/jit/execution_backend.py

81-83: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/adapter/cutedsl/adapter.py

48-48: Avoid specifying long messages outside the exception class

(TRY003)


112-112: Unused class method argument: pass_configs

(ARG003)


124-124: Avoid specifying long messages outside the exception class

(TRY003)


206-206: Unused method argument: kernel_only

(ARG002)


262-262: Do not catch blind exception: Exception

(BLE001)


282-284: Avoid specifying long messages outside the exception class

(TRY003)


297-297: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Avoid specifying long messages outside the exception class

(TRY003)


319-319: Avoid specifying long messages outside the exception class

(TRY003)


333-333: Avoid specifying long messages outside the exception class

(TRY003)


339-339: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
🔇 Additional comments (17)
.github/workflows/ci.yml (2)

373-373: LGTM! Proper test isolation.

Excluding the CuTeDSL test file from the standard CUDA test suite is the right approach to prevent the GEMM v1 requirement from affecting other tests.


376-392: Excellent test isolation strategy!

The dedicated CuTeDSL test step with GEMM v1 environment variable and single-process execution is well-designed:

  • Setting TILELANG_USE_GEMM_V1=1 at the step level ensures proper isolation from GEMM v2 tests
  • Using --numprocesses=1 appropriately avoids GPU contention for these tests
  • Clear comments explain the rationale for both decisions

This approach cleanly separates CuTeDSL tests while maintaining consistency with the existing CI patterns.

tilelang/cache/kernel_cache.py (3)

298-344: CuTeDSL-specific save logic looks correct.

The branching for CuTeDSL launcher artifacts (lines 327-343) correctly handles the optional C++ launcher library and source. The conditional logic appropriately separates CuTeDSL from other backends.

As noted in the existing comments, consolidating backend-specific cache logic into the adapters themselves would improve maintainability in future iterations.


421-461: CuTeDSL load logic handles empty sources correctly.

The code correctly skips loading device/host kernel sources for CuTeDSL (lines 443-461) and initializes them as empty strings instead. This aligns with CuTeDSL's approach of reading kernel content directly from the Python module path.

The validation at line 472 properly accounts for CuTeDSL's alternative contract (empty sources are acceptable).


225-230: Cache path propagation enables post-compilation artifact saving.

Setting _cache_path on the CuTeDSL adapter (lines 228-229) allows the adapter to save the cubin after first execution—a necessary deferred step since CuTeDSL generates the cubin during the first run rather than at compile time.

This is a clean integration point between the cache and adapter layers.

tilelang/jit/adapter/nvrtc/adapter.py (1)

79-81: Wrapper API updated to dictionary-based return.

The change from direct tuple unpacking to dictionary access aligns with the expanded wrapper contract that now returns additional keys (e.g., tma_cpp_init_code, launcher_cpp_code) for CuTeDSL support. The NVRTC adapter correctly extracts only the keys it needs while ignoring the extras.

This is a clean, backward-compatible refactoring.

tilelang/jit/execution_backend.py (3)

35-36: CuTeDSL backend properly restricted to CuTeDSL targets.

The short-circuit at lines 35-36 ensures that "cutedsl" is only available when the target explicitly includes the cutedsl key. Plain CUDA targets (without the key) will fall through to line 37 and receive the standard backend list without "cutedsl".

This addresses the past review concern about preventing users from selecting "cutedsl" on non-CuTeDSL targets—they'll now get a clear validation error at line 112 instead of an assertion failure later.


79-93: GEMM v1 validation provides actionable errors.

The _require_gemm_v1_for_cutedsl() helper cleanly validates prerequisites:

  1. Checks TILELANG_USE_GEMM_V1 environment variable (lines 80-83)
  2. Lazily imports and runs check_cutedsl_available() (lines 85-92)
  3. Converts ImportError to ValueError while preserving the actionable message

This gives users clear feedback about missing dependencies or misconfiguration before compilation begins.


96-98: Auto and explicit backend resolution both enforce CuTeDSL prerequisites.

Both the auto-resolution path (lines 96-98) and the explicit "cutedsl" request path (lines 126-127) call _require_gemm_v1_for_cutedsl() to validate prerequisites before proceeding. This ensures consistent enforcement regardless of how the user selects the backend.

Also applies to: 126-127

tilelang/jit/kernel.py (3)

327-340: CuTeDSL adapter creation follows established pattern.

The new branch for execution_backend == "cutedsl" (lines 327-340) correctly:

  • Validates the target has the cutedsl key (line 328)
  • Passes all required parameters matching the adapter signature
  • Follows the same structure as other backends (tvm_ffi, ctypes, nvrtc, etc.)

This maintains consistency across the codebase.


412-423: Database restoration path includes CuTeDSL.

The from_database path (lines 412-423) mirrors the construction path, ensuring cached CuTeDSL kernels can be properly restored. This is essential for the kernel cache integration demonstrated in kernel_cache.py.


474-474: Source access methods recognize CuTeDSL adapter.

Including "cutedsl" in the backend checks at lines 474 and 482 ensures get_kernel_source() and get_host_source() will delegate to the CuTeDSL adapter's methods rather than falling back to artifact sources. This is the correct behavior for adapters that manage their own source representations.

Also applies to: 482-482

tilelang/jit/__init__.py (1)

52-52: Execution backend type annotations consistently updated.

All public API entry points (compile, par_compile, JITImpl, jit, lazy_jit) and the ExecutionBackend type alias now include "cutedsl" in their type signatures. This provides proper type checking and IDE support for the new backend option.

The updates are consistent and complete across the module.

Also applies to: 67-67, 121-121, 138-138, 259-259, 427-427, 476-476

tilelang/jit/adapter/cutedsl/adapter.py (4)

85-96: Library generation and source reading correct for CuTeDSL.

The sequence at lines 85-96 properly:

  1. Generates the CuTeDSL Python module via lib_generator.compile_lib() and load_lib()
  2. Reads the generated Python module content into device_kernel_source (line 96)

For CuTeDSL, the Python module IS the kernel representation, so storing it in device_kernel_source maintains consistency with the adapter interface while reflecting CuTeDSL's Python-based approach.


158-204: Dynamic symbol processing correctly handles shapes and strides.

The _process_dynamic_symbolic method (lines 158-204) properly:

  • Follows CuTeDSL's ordering semantics (shapes first, then strides)
  • Uses tuple encoding (id, param_index, dim_index) where id=0 for shapes, id=1 for strides
  • Handles None strides safely (lines 198-199)
  • Maintains insertion order via unique_push_back

This implementation aligns with the wrapper's get_dynamic_symbolic_set() and ensures correct argument passing during execution.


265-353: Input marshalling correctly handles scalars and dynamic shapes.

The _wrap_forward_from_prebuild_lib method properly:

  • Materializes all parameters in PrimFunc order (lines 286-293)
  • Allocates output tensors using the first tensor's device (lines 295-297, 322)
  • Resolves dynamic shapes by indexing param_values (not ins) to avoid mismatched indices (lines 309-319)
  • Appends dynamic symbolic values in the correct order (lines 329-339)

This fixes the critical issue from past reviews where indexing into ins would fail with scalars or outputs present.


225-263: Post-execution cubin caching is a clever deferred-save pattern.

The _save_cubin_to_cache_if_needed method handles CuTeDSL's unique requirement: the cubin is generated during first execution rather than at compile time. Using the _cache_path set by kernel_cache.py (line 236), it copies the cubin to the cache directory after the first run.

The one-time flag (line 231) and existence checks (lines 249, 255) prevent redundant copies. The exc_info=True logging (line 263) aids debugging if the copy fails.

@LeiWang1999 LeiWang1999 mentioned this pull request Dec 17, 2025
7 tasks
@LeiWang1999
Copy link
Member

@codex review

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

kernel_lib_path = os.path.join(cache_path, kernel_lib_path)

# Save an extra Python file for NVRTC
if self.execution_backend == "nvrtc":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a bit hacky now that we have so many backends. Could we refactor this part to make it more maintainable? If you don’t have time right now, we can spin it off into a separate thread/issue—otherwise, LGTM.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants