feat(cutedsl): add CuTeDSL backend #1421
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Adds a CuTeDSL backend and integration across build, C++/Python codegens, runtime builder, JIT adapter/wrapper/lib generator, Cutlass-DSL contrib helpers, kernel caching/launcher artifact support, target mapping and JIT routing for a "cutedsl" execution backend, plus tests and CI adjustments to isolate GEMM v1 tests.
Sequence Diagram(s)
sequenceDiagram
participant User
participant JIT as tilelang.jit
participant Lower as tilelang.engine.lower
participant CodeGen as CodeGenTileLangCuTeDSL
participant LibGen as CuTeDSLLibraryGenerator
participant Wrapper as TLCuTeDSLSourceWrapper
participant Adapter as CuTeDSLKernelAdapter
participant CUDA as CUDA Runtime
User->>JIT: compile(func, target="cutedsl")
JIT->>Lower: lower(IRModule, target)
Lower->>CodeGen: BuildTileLangCuTeDSLWithoutCompile(mod, target)
CodeGen-->>Lower: device source + metadata
Lower->>LibGen: request compile_lib (kernel.py, kernel.cubin, launcher)
LibGen->>LibGen: write kernel.py, optionally compile launcher.cpp → launcher_lib.so
LibGen-->>Wrapper: provide artifacts (py host func, cubin, launcher)
JIT->>Adapter: construct CuTeDSLKernelAdapter(..., artifacts)
User->>Adapter: call converted torch function
Adapter->>CUDA: launch via launcher_lib / cubin (stream, args)
CUDA-->>User: kernel results
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
3f1def2 to 5617be1
Actionable comments posted: 9
Note: Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
Caution: Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)
160-160: Missing "cutedsl" in `from_database` `execution_backend` type hint.
The `from_database` class method's `execution_backend` parameter type hint on line 160 doesn't include "cutedsl", but the `__init__` method on line 66 does. This inconsistency could cause type checking issues.

    -        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"],
    +        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"],
🟡 Minor comments (14)
tilelang/contrib/cutedsl/mbar.py-17-19 (1)
17-19: Unused `timeout_ns` parameter.
The `timeout_ns` parameter is declared with a default value but never passed to `arch.mbarrier_wait`. If the underlying implementation supports timeouts, this should be forwarded; otherwise, remove the parameter to avoid confusion.

    -def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000):
    +def mbarrier_wait(mbar_ptr: Pointer, phase: Int):
         """Waits on a mbarrier with a specified phase."""
         arch.mbarrier_wait(mbar_ptr, phase)

Alternatively, if `arch.mbarrier_wait` accepts a timeout parameter, forward it:

     def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000):
         """Waits on a mbarrier with a specified phase."""
    -    arch.mbarrier_wait(mbar_ptr, phase)
    +    arch.mbarrier_wait(mbar_ptr, phase, timeout_ns)

src/target/rt_mod_cutedsl.cc-41-60 (1)
41-60: Confirm: `CUDAModuleCreate` parameters are semantically inconsistent with the actual content type.
In `BuildTileLangCuTeDSLWithoutCompile`, the function generates Python source code via `CodeGenTileLangCuTeDSL`, yet passes `"ptx", "ptx"` to `CUDAModuleCreate`. This mirrors the same pattern in `BuildTileLangCUDAWithoutCompile` (line 102 of rt_mod_cuda.cc), which also generates source code but uses "ptx" as a placeholder format.
However, this design is problematic: "ptx" is a real GPU compilation format (Parallel Thread Execution), not an appropriate identifier for uncompiled source code. In the compiled path (`BuildTileLangCUDA`), the first parameter is an actual binary and the second is the format identifier. In the uncompiled paths, using "ptx" as both a placeholder and format identifier for Python source is misleading and inconsistent with NVIDIA's standard format definitions (as used in `tilelang/contrib/nvrtc.py` and `nvcc.py`, which validate that formats must be "ptx", "cubin", or "fatbin").
Consider either: (1) using distinct format identifiers for source code (e.g., "python" or "cuda_src"), (2) clarifying the design with a code comment explaining why "ptx" is used as a placeholder, or (3) investigating if this is unintended and the correct format should reflect the actual content type.
src/target/codegen_cutedsl.cc-1105-1105 (1)
1105-1105: Typo: "ommited" should be "omitted".
Minor typo in comment.
tilelang/contrib/cutedsl/gemm_V1.py-179-218 (1)
179-218: `gemm_rr` also ignores `use_wgmma` and `wg_wait` parameters.
Same issue as `gemm_sr`: the parameters are accepted but unused.
tilelang/contrib/cutedsl/gemm_V1.py-232-247 (1)
232-247: Instance caching with `__new__` may cause issues with mutable state.
The singleton pattern caches instances by args, but `__init__` is still called every time `__new__` returns an instance. The `hasattr(self, "initialized")` check is meant to prevent re-initialization, but `initialized` is never actually set to `True`.

         if not hasattr(self, "initialized"):
             self.cta_tiler = (M, N, K)
             ...
             self.clear_accum = clear_accum
    +        self.initialized = True

tilelang/jit/adapter/cutedsl/libgen.py-78-78 (1)
78-78: Unused `timeout` parameter.
The `timeout` parameter is declared but never used in `compile_lib`. Either implement timeout support for subprocess calls or remove the parameter to match the actual behavior.

    -    def compile_lib(self, timeout: float = None):
    +    def compile_lib(self, timeout: float | None = None):
             target = self.target
             if is_cutedsl_target(target):
                 ...
    +            # Pass timeout to subprocess.run calls if provided
    +            subprocess.run(..., timeout=timeout)

Committable suggestion skipped: line range outside the PR's diff.
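A minimal sketch of what forwarding a timeout to `subprocess.run` could look like. The helper name `run_compile` is hypothetical and is not part of the PR; the Python interpreter stands in for the real compiler command.

```python
import subprocess
import sys
from typing import Optional


def run_compile(cmd: list[str], timeout: Optional[float] = None) -> str:
    """Run cmd; raise RuntimeError on failure or when timeout (seconds) expires."""
    try:
        result = subprocess.run(
            cmd, check=False, capture_output=True, text=True, timeout=timeout
        )
    except subprocess.TimeoutExpired as exc:
        # subprocess.run kills the child before re-raising TimeoutExpired.
        raise RuntimeError(f"Command timed out after {timeout} s: {cmd}") from exc
    if result.returncode != 0:
        raise RuntimeError(f"Command failed: {result.stderr}")
    return result.stdout


# Hypothetical stand-in for the compiler invocation.
output = run_compile([sys.executable, "-c", "print('ok')"], timeout=30.0)

try:
    run_compile([sys.executable, "-c", "import time; time.sleep(5)"], timeout=0.2)
    timed_out = False
except RuntimeError:
    timed_out = True
```

With `timeout=None` the call blocks indefinitely, which matches the behavior of a parameter that is accepted but never forwarded.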
tilelang/contrib/cutedsl/cpasync.py-170-197 (1)
170-197: Docstring parameter `size` vs actual parameter `cp_size` mismatch.
The docstring documents a `size` parameter (lines 181-182) but the actual parameter is `cp_size`. Also, the docstring for `cp_size` (lines 185-186) duplicates the description.

         """
         Asynchronously copy data from global memory to shared memory.

         :param dst: Destination pointer in shared memory
         :type dst: Pointer
         :param src: Source pointer in global memory
         :type src: Pointer
    -    :param size: Size of the copy in bytes
    -    :type size: Int
    +    :param cp_size: Size of the copy in bytes
    +    :type cp_size: Int
         :param modifier: Cache modifier
    -    :type modifier: Int
    -    :param cp_size: Optional copy size override
    -    :type cp_size: Int
    +    :type modifier: nvvm.LoadCacheModifierKind
    +    :param src_size: Optional source size override (defaults to cp_size)
    +    :type src_size: Int
         """

tilelang/contrib/cutedsl/gemm_V1.py-138-176 (1)
138-176: `gemm_sr` ignores `use_wgmma` and `wg_wait` parameters.
The function signature includes `use_wgmma` and `wg_wait` but the comment states "wgmma doesn't support gemm_sr" and always uses SM80. These parameters should either be removed or a warning/error raised when `use_wgmma=True`.

     def gemm_sr(
         ...
         use_wgmma=None,
         wg_wait=0,
         ...
     ):
         """GEMM with A from shared memory and B from register/fragment"""
    -    # wgmma doesn't support gemm_sr, only use SM80
    +    if use_wgmma:
    +        raise ValueError("wgmma does not support gemm_sr (A from shared, B from register)")
         gemm = Gemm_SM80(...)

tilelang/contrib/cutedsl/reduce.py-140-159 (1)
140-159: Potential infinite recursion if `scale` is never reached.
The `run` method recurses until `offset == self.scale`. If `scale` is set incorrectly (e.g., larger than initial `threads/2`), this could recurse indefinitely. Consider adding a guard.

         return (
             x
             if offset == self.scale
    -        else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
    +        else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
    +        if offset > self.scale
    +        else x  # Reached minimum offset
         )

Alternatively, document the invariant that `scale` must be a power-of-2 divisor of `threads/2`.
tilelang/contrib/cutedsl/gemm_V1.py-380-411 (1)
380-411: Same issues in `Gemm_SM90`: `_instances` annotation and missing `initialized = True`.
Applying the same fixes as for `Gemm_SM80`.

    +from typing import ClassVar
    +
     class Gemm_SM90:
    -    _instances = {}  # cache instances for the same arguments
    +    _instances: ClassVar[dict] = {}
         ...
         def __init__(self, ...):
             if not hasattr(self, "initialized"):
                 ...
                 self.clear_accum = clear_accum
    +            self.initialized = True

src/target/codegen_py.h-162-162 (1)
162-162: Typo in comment.
Minor typo: "statment" should be "statement".

    -  // statment
    +  // statement

tilelang/jit/adapter/cutedsl/adapter.py-93-93 (1)
93-93: File handle not properly closed.
Using `open().read()` without a context manager or explicit close can leak file handles.

    -        self.device_kernel_source = open(self.libpath).read()
    +        with open(self.libpath) as f:
    +            self.device_kernel_source = f.read()

tilelang/jit/adapter/cutedsl/wrapper.py-1224-1224 (1)
1224-1224: Incorrect type annotation.
`list[str] = None` suggests the variable holds a list, but it's initialized to `None`. Use `Optional[list[str]]` or `list[str] | None`.

    -        function_params: list[str] = None
    +        function_params: list[str] | None = None

tilelang/jit/adapter/cutedsl/adapter.py-145-145 (1)
145-145: Same file handle issue in `from_database`.

    -        adapter.kernel_global_source = open(kernel_lib_path).read()
    +        with open(kernel_lib_path) as f:
    +            adapter.kernel_global_source = f.read()
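The two file-handle comments above both suggest the same pattern. As a self-contained illustration (the helper `read_text` and the throwaway `kernel.py` file are assumptions made for this sketch, not code from the PR):

```python
import os
import tempfile


def read_text(path: str) -> str:
    """Read a whole text file; the handle is closed even if read() raises."""
    with open(path, encoding="utf-8") as f:
        return f.read()


# A throwaway file stands in for the generated kernel source.
with tempfile.TemporaryDirectory() as tmp_dir:
    kernel_path = os.path.join(tmp_dir, "kernel.py")
    with open(kernel_path, "w", encoding="utf-8") as f:
        f.write("# generated kernel source\n")
    source = read_text(kernel_path)
```

A bare `open(path).read()` relies on the garbage collector to close the handle, whereas the `with` block closes it deterministically when the block exits.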
🧹 Nitpick comments (42)
src/tl_templates/cuda/nvrtc_std.h (1)
176-176: Good practice: documenting closing preprocessor directive.
Adding a comment to the closing `#endif` improves readability, especially when the corresponding `#ifdef __CUDACC_RTC__` (line 20) is far above. This clearly documents intent and aids future maintenance.
maint/scripts/run_local_ci_test.sh (1)
21-21: Consider error handling for cd command.
While unlikely to fail in practice, adding error handling improves script robustness.
Apply this diff:

    -cd testing/python
    +cd testing/python || exit 1

Based on static analysis hints.
tilelang/jit/adapter/utils.py (1)
41-54: Unused `annotation` parameter.
The `annotation` parameter is declared but never used in the function body. Unlike `match_declare_kernel` and `match_declare_kernel_cpu`, which use their annotation parameter, this function always matches `@cute.kernel`.
Either remove the unused parameter or incorporate it into the pattern:

    -def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
    +def match_declare_kernel_cutedsl(source: str) -> int:
         # Match decorator followed by function definition across lines
         # \s+ allows any whitespace including newlines between decorator and def
         pattern = r"@cute\.kernel\s+def\s+(\w+)"

tilelang/utils/target.py (2)
99-106: LGTM - CuTeDSL target creation logic.
The approach of creating a CUDA target and augmenting its keys with "cutedsl" correctly establishes CuTeDSL as a CUDA-based backend. This enables downstream dispatch logic to identify and route to CuTeDSL-specific codegen paths.
Minor style improvement per Ruff suggestion - use unpacking syntax:

    -        target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
    +        target_dict["keys"] = [*list(target_dict["keys"]), "cutedsl"]
127-132: Redundant Target instance check.
Lines 127-128 early-return if `return_var` is a `Target`, but lines 130-131 perform the same check immediately after. The early return is only reachable when `return_var` is a `Target` (from line 106 or 110), making the second check at line 130 dead code in that path.
Consider consolidating:

         if isinstance(return_var, Target):
             return return_var
    -    if return_object:
    -        if isinstance(return_var, Target):
    -            return return_var
    +    if return_object:
             return Target(return_var)
         return return_var

tilelang/contrib/cutedsl/mbar.py (1)
6-12: Remove unnecessary `noqa` directives.
The `# noqa: F401` comments are unnecessary since these imports are intentional re-exports for the module's public API, and the F401 rule isn't enabled in your configuration.

    -from cutlass.cute.typing import Pointer, Int, Boolean  # noqa: F401
    -from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op  # noqa: F401
    +from cutlass.cute.typing import Pointer, Int, Boolean
    +from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op

    -from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive  # noqa: F401
    -from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx  # noqa: F401
    -from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc  # noqa: F401
    +from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive
    +from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx
    +from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc

tilelang/contrib/cutedsl/__init__.py (4)
23-29: Consider defining `__all__` in submodules for better traceability.
The star imports make it difficult to trace which symbols are exposed. While this pattern is common for re-export modules, consider adding explicit `__all__` definitions in each submodule (`.mbar`, `.cpasync`, etc.) to document the intended public API and help static analysis tools.
53-60: Verify warp size assumption in `shuffle_elect`.
The function assumes a warp size of 32 (hardcoded divisor). While this is correct for current NVIDIA GPUs, consider adding a comment or using a named constant for clarity:

     def shuffle_elect(thread_extent):
         # thread_extent is the number of threads of a warpgroup
    +    WARP_SIZE = 32
         warp_idx = cute.arch.warp_idx()
         warp_idx = cute.arch.make_warp_uniform(warp_idx)
         if thread_extent == 0:
             return warp_idx == 0
         else:
    -        return (warp_idx % (thread_extent // 32)) == 0
    +        return (warp_idx % (thread_extent // WARP_SIZE)) == 0
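To make the warp-size arithmetic concrete, here is a pure-Python model of only the warp-selection predicate shown in the snippet. The function `elect_warp` is a hypothetical stand-in: it takes the warp index as a plain integer instead of calling `cute.arch.warp_idx()`.

```python
WARP_SIZE = 32  # threads per warp on current NVIDIA GPUs


def elect_warp(warp_idx: int, thread_extent: int) -> bool:
    """Model of the predicate: is this warp the first one of its group?

    thread_extent == 0 elects only warp 0; otherwise the first warp in every
    group of thread_extent threads is elected.
    """
    if thread_extent == 0:
        return warp_idx == 0
    warps_per_group = thread_extent // WARP_SIZE
    return warp_idx % warps_per_group == 0


# With 128-thread groups there are 4 warps per group, so warps 0 and 4
# are elected among the first 8 warps.
elected = [w for w in range(8) if elect_warp(w, thread_extent=128)]
only_first = [w for w in range(8) if elect_warp(w, thread_extent=0)]
```

Note that a `thread_extent` between 1 and 31 would make `warps_per_group` zero and raise `ZeroDivisionError` in this model, so the sketch assumes `thread_extent` is either 0 or at least one full warp.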
88-98: Consider removing `has_side_effects=True` for pure data packing.
The `mov.b32` instruction is a pure data movement with no side effects; it only packs two 16-bit values into a 32-bit register. Setting `has_side_effects=True` may prevent beneficial compiler optimizations like dead code elimination or instruction reordering.

         packed_xy = llvm.inline_asm(
             Int32.mlir_type,
             [x_i16, y_i16],
             "mov.b32 $0, {$1, $2};",
             "=r,h,h",
    -        has_side_effects=True,
    +        has_side_effects=False,
             is_align_stack=False,
             asm_dialect=llvm.AsmDialect.AD_ATT,
             loc=loc,
             ip=ip,
         )
105-130: Consider expanding `AtomicAdd` to support additional dtypes.
The function currently supports only `Float32` and `Int32`. Common use cases might also require `Float64` (double precision) or `Int64` support. Consider documenting this limitation or expanding support if needed.
tilelang/contrib/cutedsl/math.py (1)
11-44: Add return type annotations for consistency.
The `divf` function has a return type annotation, but the other math functions are missing them. For API consistency and better IDE support:

    -def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True):
    +def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True) -> Union[TensorSSA, Numeric]:
         return cute.math.exp(x, fastmath=fastmath)

Apply similar changes to `exp2`, `log`, `log2`, `log10`, `tan`, `cos`, `sin`, and `sqrt`.
src/target/codegen_cutedsl.cc (4)
22-39: Duplicate utility function - consider extracting to shared header.
`CheckOutermostParenthesesMatch` is also implemented in `src/target/codegen_py.cc` (per relevant snippets). Consider extracting to a shared utility header to avoid code duplication.
91-109: Consider explicit error messages for unsupported Float8 variants.
Several Float8 variants (e3m4, e4m3b11fnuz, e4m3fnuz, e5m2fnuz) silently fall through to the `LOG(FATAL)` at line 134, but the error message won't indicate which specific variant was attempted. Adding explicit error messages would improve debuggability.

       } else if (t.is_float8_e3m4()) {
    -    // unsupported
    +    LOG(FATAL) << "Float8 e3m4 is not supported in CuTeDSL";
       } else if (t.is_float8_e4m3()) {
265-629: Consider refactoring the large CallNode visitor.
This visitor function spans ~365 lines with a long if-else chain. While it follows the established TVM codegen pattern, consider extracting groups of related operations (e.g., TMA ops, mbarrier ops, math intrinsics) into separate helper methods for improved maintainability.
1072-1073: Variable shadowing: `args` shadows outer parameter.
The inner `args` variable shadows the function parameter `args` from line 1033. Consider renaming to `template_args` for clarity.

       auto pos_right = global_symbol_str.find('>', pos_left + 1);
       if (pos_right != std::string::npos) {
    -    auto args =
    +    auto template_str =
             global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
    -    ReplaceAll(args, "true", "True");
    -    ReplaceAll(args, "false", "False");
    -    global_symbol_str.replace(pos_left, args.size() + 2, "(" + args + ")");
    +    ReplaceAll(template_str, "true", "True");
    +    ReplaceAll(template_str, "false", "False");
    +    global_symbol_str.replace(pos_left, template_str.size() + 2, "(" + template_str + ")");
       }

tilelang/contrib/cutedsl/ldsm.py (2)
11-12: Remove unused `noqa` directives.
Per static analysis hints, the `# noqa: F401` directives on lines 11-12 are unused. The imports `ir` from line 11 and `Pointer, Int32` from line 12 are actually used in the code (Pointer in type hints, Int32 implicitly via cute.Int32).

    -from cutlass._mlir import ir  # noqa: F401
    -from cutlass.cute.typing import Pointer, Int32  # noqa: F401
    +from cutlass._mlir import ir
    +from cutlass.cute.typing import Pointer, Int32
25-33: Consider supporting `num=1` in `_ldmatrix` helper.
The assertion `assert num in [2, 4]` excludes the single-matrix case, requiring separate handling in `ptx_ldmatrix_x1`. Consider extending the helper to support `num=1` to reduce code duplication and ensure consistency between transposed and non-transposed variants.

     def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None):
         """Internal helper for ldmatrix operations"""
         layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
    -    assert num in [2, 4]
    -    ret_type = llvm.StructType.get_literal([T.i32()] * num)
    -    out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
    +    assert num in [1, 2, 4], f"ldmatrix supports num=1, 2, or 4, got {num}"
    +    if num == 1:
    +        out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=layout, loc=loc, ip=ip)
    +        out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
    +        out[0] = cute.Int32(out_i32)
    +        return
    +    ret_type = llvm.StructType.get_literal([T.i32()] * num)
    +    out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
         out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num)
         for i in range(num):
             out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip))

tilelang/cache/kernel_cache.py (2)
319-339: Consider adding null check for `lib_generator` before accessing `launcher_libpath`.
The code checks `lib_gen` is truthy, but if `lib_gen` exists but `launcher_libpath` is an empty string (falsy but not None), the condition passes but `_load_binary` will fail. The current check is reasonable but could be more explicit.

    -            lib_gen = getattr(kernel.adapter, "lib_generator", None)
    -            if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath:
    +            lib_gen = getattr(kernel.adapter, "lib_generator", None)
    +            if lib_gen is not None and getattr(lib_gen, "launcher_libpath", None):
439-453: Static analysis: Consider using `logging.exception` for better stack traces.
Per static analysis hints, using `self.logger.exception()` instead of `self.logger.error()` in the except blocks would automatically include the stack trace, improving debuggability.

             except Exception as e:
    -            self.logger.error(f"Error loading kernel source code from disk: {e}")
    +            self.logger.exception(f"Error loading kernel source code from disk: {e}")

tilelang/contrib/cutedsl/reduce.py (1)
161-184: Hopper path uses hardcoded barrier IDs 1 and 2.
Using fixed barrier IDs (1, 2) in `run_hopper` could conflict with other barriers in the kernel. Consider making these configurable or documenting this constraint.
tilelang/contrib/cutedsl/cpasync.py (3)
2-15: Remove unnecessary `noqa` directives.
Per static analysis, the `# noqa: F401` comments are unnecessary since F401 is not enabled. These can be safely removed for cleaner code.

    -from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op  # noqa: F401
    +from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op

    -from cutlass._mlir.dialects import nvvm, cute_nvgpu  # noqa: F401
    +from cutlass._mlir.dialects import nvvm, cute_nvgpu
21-36: Use `TypeError` for invalid type errors.
Per static analysis hints, `TypeError` is more appropriate than `ValueError` when the issue is an incorrect type.

         else:
    -        raise ValueError(f"Invalid source type: {type(src)}")
    +        raise TypeError(f"Expected cute.Tensor or cute.Pointer for src, got {type(src).__name__}")

         else:
    -        raise ValueError(f"Invalid destination type: {type(dst)}")
    +        raise TypeError(f"Expected cute.Tensor or cute.Pointer for dst, got {type(dst).__name__}")
72-106: The `tma_load` function overloads based on `crd` type; document this behavior.
When `crd` is not a tuple, the function reinterprets parameters differently (treating `tma_desc` as `smem_ptr`, `mbar` as `gmem_ptr`, etc.). This is confusing and error-prone. Consider splitting into two functions or adding clearer documentation.
The comment on line 76 mentions this is a BUG related to API differences. Consider either:
- Creating a separate function for the non-tuple case
- Adding explicit documentation in the docstring about this dual behavior
tilelang/jit/adapter/cutedsl/libgen.py (3)
17-36: Module-level subprocess calls block import and swallow errors silently.
Running `tvm-ffi-config` at module import time can slow down imports and the empty fallback silently hides configuration issues. Consider lazy initialization or logging a warning when falling back.

    +import logging
    +
    +_logger = logging.getLogger(__name__)
    +
     try:
         tvm_cxxflags = (
             subprocess.check_output(
                 ["tvm-ffi-config", "--cxxflags"],
                 text=True,
             )
             .strip()
             .split()
         )
         tvm_ldflags = (
             subprocess.check_output(
                 ["tvm-ffi-config", "--ldflags"],
                 text=True,
             )
             .strip()
             .split()
         )
     except (subprocess.CalledProcessError, FileNotFoundError):
    +    _logger.debug("tvm-ffi-config not available, C++ launcher compilation may fail")
         tvm_cxxflags = []
         tvm_ldflags = []
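The lazy-initialization alternative mentioned in this comment could be sketched as follows. The helper name `get_config_flags` is hypothetical; the PR itself keeps the lookup at module level. The sketch defers the subprocess call to first use and caches the result.

```python
import functools
import logging
import subprocess

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def get_config_flags(tool: str, option: str) -> tuple[str, ...]:
    """Run `tool option` on first use and cache the whitespace-split flags.

    Returns an empty tuple, after logging a warning, when the tool is
    missing or exits with a non-zero status.
    """
    try:
        output = subprocess.check_output([tool, option], text=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        logger.warning("%s %s unavailable (%s); using no extra flags", tool, option, exc)
        return ()
    return tuple(output.strip().split())


# Nothing is spawned at import time; only this first call runs a subprocess.
# A deliberately nonexistent tool name exercises the fallback path.
missing = get_config_flags("definitely-not-a-real-tool-xyz", "--cxxflags")
```

In the setting of this review, the call sites would presumably be `get_config_flags("tvm-ffi-config", "--cxxflags")` and `get_config_flags("tvm-ffi-config", "--ldflags")`, evaluated when the launcher is compiled rather than when the module is imported.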
91-112: Temporary files not cleaned up on failure.
`tma_src` is created but never explicitly cleaned up if compilation fails. The file will persist until system cleanup.
Consider using a context manager or explicit cleanup:

    -        tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
             if self.tma_cpp_init_code is not None:
    +            tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
    +            try:
                     with open(tma_src.name, "w") as f:
                         f.write(self.tma_cpp_init_code)
                     ...
    +            finally:
    +                os.unlink(tma_src.name)
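The explicit-cleanup variant can be exercised in isolation. In this sketch the function `run_with_temp_source` is hypothetical and the Python interpreter stands in for the C++ compiler; the point is only that `finally` removes the file whether or not the command succeeds.

```python
import os
import subprocess
import sys
import tempfile


def run_with_temp_source(code: str) -> tuple[int, str]:
    """Write code to a temp file, run it, and always delete the file.

    Returns the process return code and the temp path that was used.
    """
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
    try:
        with tmp:  # closes the handle so the child process can read the file
            tmp.write(code)
        result = subprocess.run(
            [sys.executable, tmp.name], check=False, capture_output=True, text=True
        )
        return result.returncode, tmp.name
    finally:
        # Runs on success, on a non-zero exit, and on any exception above.
        os.unlink(tmp.name)


# The child exits with status 3 to mimic a failed compilation.
returncode, used_path = run_with_temp_source("raise SystemExit(3)")
```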
136-138: Include stdout in error message for better diagnostics.
When compilation fails, `result.stdout` may also contain useful information (warnings, notes). Consider including both.

             result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True)
             if result.returncode != 0:
    -            raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}")
    +            raise RuntimeError(
    +                f"Failed to compile C++ launcher:\nstdout: {result.stdout}\nstderr: {result.stderr}"
    +            )

tilelang/contrib/cutedsl/gemm_V1.py (2)
221-228: Mutable class attribute `_instances` should use `ClassVar`.
Per static analysis, mutable class attributes like `_instances = {}` should be annotated with `typing.ClassVar` to clarify they're shared across instances.

    +from typing import ClassVar
    +
     class Gemm_SM80:
    -    _instances = {}  # cache instances for the same arguments
    +    _instances: ClassVar[dict] = {}  # cache instances for the same arguments
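The per-arguments instance cache discussed for `Gemm_SM80` and `Gemm_SM90` can be reproduced in a few lines. The class `CachedGemm` below is a hypothetical stand-in, not the PR's class; it combines the `ClassVar` annotation with the `initialized` flag that the review notes is never set.

```python
from typing import ClassVar


class CachedGemm:
    """Hypothetical stand-in for a class that caches one instance per argument tuple."""

    _instances: ClassVar[dict[tuple, "CachedGemm"]] = {}
    init_count: ClassVar[int] = 0  # counts real initializations, for demonstration

    def __new__(cls, M: int, N: int, K: int):
        key = (M, N, K)
        if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
        return cls._instances[key]

    def __init__(self, M: int, N: int, K: int):
        # Python calls __init__ after every CachedGemm(...) expression,
        # including when __new__ returned an already-cached object.
        if getattr(self, "initialized", False):
            return
        self.cta_tiler = (M, N, K)
        CachedGemm.init_count += 1
        self.initialized = True  # without this line the guard above never fires


a = CachedGemm(128, 128, 32)
b = CachedGemm(128, 128, 32)  # same arguments: cached object, no re-initialization
c = CachedGemm(64, 64, 32)    # different arguments: new object
```

If the last line of `__init__` is dropped, `init_count` grows on every construction, so any state stored on the cached object would be overwritten each time it is looked up.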
424-424: Unused variable `tma_tensor` from tuple unpacking.
The static analysis correctly identifies that `tma_tensor` is never used. Prefix with underscore to indicate intentional discard.

    -        tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
    +        tma_atom, _tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(

testing/python/jit/test_tilelang_jit_cutedsl.py (4)
9-52: Consider consolidating duplicated `matmul` and `matmul_jit_kernel` functions.
These two functions are nearly identical, with the only difference being their names. This duplication increases maintenance burden. Consider using a single function or a factory pattern.

    -def matmul_jit_kernel(
    -    M,
    -    N,
    -    K,
    -    block_M,
    -    block_N,
    -    block_K,
    -    trans_A,
    -    trans_B,
    -    in_dtype,
    -    out_dtype,
    -    accum_dtype,
    -    num_stages,
    -    threads,
    -):
    -    A_shape = (K, M) if trans_A else (M, K)
    -    B_shape = (N, K) if trans_B else (K, N)
    -    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    -    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    -
    -    @T.prim_func
    -    def main(
    -        A: T.Tensor(A_shape, in_dtype),
    -        B: T.Tensor(B_shape, in_dtype),
    -        C: T.Tensor((M, N), out_dtype),
    -    ):
    -        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    -            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
    -            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
    -            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
    -            T.clear(C_local)
    -            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
    -                if trans_A:
    -                    T.copy(A[k * block_K, by * block_M], A_shared)
    -                else:
    -                    T.copy(A[by * block_M, k * block_K], A_shared)
    -                if trans_B:
    -                    T.copy(B[bx * block_N, k * block_K], B_shared)
    -                else:
    -                    T.copy(B[k * block_K, bx * block_N], B_shared)
    -                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
    -            T.copy(C_local, C[by * block_M, bx * block_N])
    -
    -    return main
    +# Reuse `matmul` for JIT kernel tests
    +matmul_jit_kernel = matmul

Also applies to: 55-98
137-143: Inefficient tensor creation pattern.
Creating a tensor and then transposing it allocates memory twice. Consider creating the tensor with the correct shape directly when transposed.

    -    A = torch.randn(M, K, dtype=in_dtype).cuda()
    -    B = torch.randn(K, N, dtype=in_dtype).cuda()
    -
    -    if trans_A:
    -        A = A.T
    -    if trans_B:
    -        B = B.T
    +    A_shape = (K, M) if trans_A else (M, K)
    +    B_shape = (N, K) if trans_B else (K, N)
    +    A = torch.randn(*A_shape, dtype=in_dtype).cuda()
    +    B = torch.randn(*B_shape, dtype=in_dtype).cuda()
244-248: Multi-stream test lacks correctness verification.
The test launches kernels on multiple streams but doesn't verify the results are correct. This only tests that the kernel doesn't crash, not that it produces correct results with concurrent execution.

         num_streams = 4
    +    streams = []
    +    outputs = []
         for _ in range(num_streams):
             stream = torch.cuda.Stream()
    +        streams.append(stream)
    +        tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
    +        outputs.append(tensor_c)
             with torch.cuda.stream(stream):
    -            matmul_kernel(tensor_a, tensor_b, tensor_c)
    +            matmul_kernel(tensor_a, tensor_b, tensor_c)
    +
    +    # Wait for all streams and verify results
    +    for stream in streams:
    +        stream.synchronize()
    +
    +    tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
    +    for tensor_c in outputs:
    +        tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
310-315: Unused `check_hopper` function.
This utility function is defined but never called in the test file. If it's intended to skip tests on non-Hopper hardware, it should be used as a pytest skip decorator or guard.

    +import pytest
    +
     def check_hopper():
         if not torch.cuda.is_available():
             return False
         props = torch.cuda.get_device_properties(0)
         compute_capability = props.major, props.minor
         return compute_capability == (9, 0)
    +
    +requires_hopper = pytest.mark.skipif(
    +    not check_hopper(),
    +    reason="Requires Hopper GPU (SM 9.0)"
    +)

Then apply `@requires_hopper` decorator to tests that need Hopper hardware.
tilelang/jit/adapter/cutedsl/adapter.py (4)
50-63: Duplicated parameter shape processing logic.
The parameter shape conversion logic is duplicated between `__init__` and `from_database`. Consider extracting to a helper method.

    +    @staticmethod
    +    def _convert_param_shapes(params: list[KernelParam]) -> list[list]:
    +        """Convert parameter shapes to native Python lists."""
    +        param_shapes = []
    +        for param in params:
    +            native_shape = []
    +            for dim in param.shape:
    +                if isinstance(dim, tir.IntImm):
    +                    native_shape.append(int(dim))
    +                elif isinstance(dim, tir.Var):
    +                    native_shape.append(dim)
    +                else:
    +                    native_shape.append(dim)
    +            param_shapes.append(native_shape)
    +        return param_shapes

Then use `self.param_shapes = self._convert_param_shapes(params)` in both `__init__` and `from_database`.
Also applies to: 124-136
250-254: Broad exception catch may mask bugs.
Catching generic `Exception` can hide unexpected errors. Consider catching more specific exceptions like `OSError` or `IOError`.

                 try:
                     shutil.copy2(src_cubin_path, dst_cubin_path)
                     logger.debug(f"Saved CuTeDSL cubin to cache: {dst_cubin_path}")
    -            except Exception as e:
    +            except OSError as e:
                     logger.warning(f"Failed to save cubin to cache: {e}")
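A small self-contained check of the narrower `except` clause. The helper `copy_to_cache` and the file names are assumptions made for illustration; only `shutil.copy2` and the `OSError` hierarchy are taken as given.

```python
import logging
import os
import shutil
import tempfile

logger = logging.getLogger(__name__)


def copy_to_cache(src_path: str, dst_path: str) -> bool:
    """Best-effort copy: log and return False on filesystem errors only."""
    try:
        shutil.copy2(src_path, dst_path)
    except OSError as exc:  # includes FileNotFoundError and PermissionError
        logger.warning("Failed to save artifact to cache: %s", exc)
        return False
    return True


with tempfile.TemporaryDirectory() as tmp_dir:
    src = os.path.join(tmp_dir, "kernel.cubin")
    with open(src, "wb") as f:
        f.write(b"\x00")
    copied = copy_to_cache(src, os.path.join(tmp_dir, "cached.cubin"))
    # A missing source raises FileNotFoundError, which OSError covers.
    missing = copy_to_cache(
        os.path.join(tmp_dir, "absent.cubin"), os.path.join(tmp_dir, "out.cubin")
    )
```

Errors unrelated to the filesystem, such as a `TypeError` from passing `None` as a path, would propagate instead of being logged and ignored, which is the behavior the comment argues for.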
231-232: Move imports to module level.
Importing `os` and `shutil` inside a method adds overhead on each call. These are standard library modules and should be imported at module level.

     from __future__ import annotations
     import logging
    +import os
    +import shutil
     from typing import Any, Callable

Then remove the local imports at lines 231-232.
21-22: Class attribute `pymodule` may be unintentionally shared.
Mutable or stateful class attributes can be shared across instances. While `None` itself isn't mutable, this pattern can cause confusion if the attribute is meant to be instance-specific.
Consider initializing `pymodule` in `__init__` or `from_database` only:

     class CuTeDSLKernelAdapter(BaseKernelAdapter):
    -    pymodule = None
    +    # pymodule is initialized in __init__ and from_database

src/target/codegen_cutedsl.h (2)
25-25: Global constant in header may cause ODR violations.
`const int64_t LOOP_UNROLL_THRESHOLD` defined in a header file could cause one-definition-rule (ODR) issues if this header is included in multiple translation units.

    -const int64_t LOOP_UNROLL_THRESHOLD = 64;
    +inline constexpr int64_t LOOP_UNROLL_THRESHOLD = 64;

Or move the definition to the `.cc` file and declare `extern const` here.
78-82: Documentation comment format issue.
The Doxygen `\param` comment format is incorrect - it should specify the parameter name without the type.

       /*!
        * \brief Print expr representing the thread tag
    -   * \param IterVar iv The thread index to be binded;
    +   * \param iv The thread index to be bound
        */
       virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*)

tilelang/jit/adapter/cutedsl/wrapper.py (2)
483-529: Class-level mutable attributes should use `ClassVar`.
While these dictionaries are used as read-only lookups, annotating them with `ClassVar` documents the intent and satisfies static analysis.

    +from typing import Any, ClassVar
    +
     class TLCuTeDSLSourceWrapper(TLCUDASourceWrapper):
    -    _TYPE_MAP = {
    +    _TYPE_MAP: ClassVar[dict[str, str]] = {
             "float32": "cutlass.Float32",
             ...
         }
    -    _CXX_TYPE_MAP = {
    +    _CXX_TYPE_MAP: ClassVar[dict[str, str]] = {
             "float32": "float",
             ...
         }
    -    _CTYPES_MAP = {
    +    _CTYPES_MAP: ClassVar[dict[str, str]] = {
             "buffer": "ctypes.c_uint64",
             ...
         }
1067-1067: Prefix unused unpacked variables with underscore.

The variables `dtype` and `globalAddress` are unpacked but never used, which triggers linter warnings.

    -_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
    +_, _dtype, tensor_rank, _globalAddress, *remaining_args = args[1:]

src/target/codegen_py.cc (2)
471-487: Remove unused variable `buffer_var`.

The variable `buffer_var` on line 479 is declared but never used.

     DataType value_dtype = op->value.dtype();
     DataType element_dtype = op->buffer->dtype;
     PrimExpr index_expr = op->indices[0];
    -Var buffer_var = op->buffer->data;
453-469: Remove unused variable `buffer_var`.

The variable `buffer_var` on line 462 is declared but never used.

     DataType value_dtype = op->dtype;
     PrimExpr index = op->indices[0];
    -Var buffer_var = op->buffer->data;
     DataType element_dtype = op->buffer->dtype;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (37)
- .github/workflows/ci.yml (1 hunks)
- 3rdparty/tvm (1 hunks)
- CMakeLists.txt (1 hunks)
- examples/gemm_fp8/example_tilelang_gemm_fp8.py (1 hunks)
- maint/scripts/run_local_ci_test.sh (1 hunks)
- src/target/codegen_cutedsl.cc (1 hunks)
- src/target/codegen_cutedsl.h (1 hunks)
- src/target/codegen_py.cc (1 hunks)
- src/target/codegen_py.h (1 hunks)
- src/target/rt_mod_cutedsl.cc (1 hunks)
- src/tl_templates/cuda/nvrtc_std.h (1 hunks)
- testing/python/jit/test_tilelang_jit_cutedsl.py (1 hunks)
- tilelang/cache/kernel_cache.py (12 hunks)
- tilelang/contrib/cutedsl/.gitignore (1 hunks)
- tilelang/contrib/cutedsl/__init__.py (1 hunks)
- tilelang/contrib/cutedsl/cpasync.py (1 hunks)
- tilelang/contrib/cutedsl/gemm_V1.py (1 hunks)
- tilelang/contrib/cutedsl/ldsm.py (1 hunks)
- tilelang/contrib/cutedsl/math.py (1 hunks)
- tilelang/contrib/cutedsl/mbar.py (1 hunks)
- tilelang/contrib/cutedsl/reduce.py (1 hunks)
- tilelang/contrib/cutedsl/threadblock_swizzle.py (1 hunks)
- tilelang/engine/lower.py (2 hunks)
- tilelang/jit/__init__.py (7 hunks)
- tilelang/jit/adapter/__init__.py (1 hunks)
- tilelang/jit/adapter/cutedsl/__init__.py (1 hunks)
- tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
- tilelang/jit/adapter/cutedsl/checks.py (1 hunks)
- tilelang/jit/adapter/cutedsl/libgen.py (1 hunks)
- tilelang/jit/adapter/cutedsl/wrapper.py (1 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
- tilelang/jit/adapter/nvrtc/wrapper.py (1 hunks)
- tilelang/jit/adapter/utils.py (4 hunks)
- tilelang/jit/adapter/wrapper.py (7 hunks)
- tilelang/jit/execution_backend.py (4 hunks)
- tilelang/jit/kernel.py (7 hunks)
- tilelang/utils/target.py (3 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- examples/gemm_fp8/example_tilelang_gemm_fp8.py
- src/target/rt_mod_cutedsl.cc
- .github/workflows/ci.yml
- testing/python/jit/test_tilelang_jit_cutedsl.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/contrib/cutedsl/__init__.py
🧬 Code graph analysis (19)
examples/gemm_fp8/example_tilelang_gemm_fp8.py (2)
- tilelang/jit/__init__.py (3): jit (434-434), jit (438-448), jit (451-521)
- tilelang/jit/kernel.py (1): out_idx (629-630)

tilelang/jit/adapter/utils.py (2)
- src/target/codegen_cutedsl.h (1): tvm (18-97)
- src/target/codegen_py.h (1): tvm (25-247)

tilelang/contrib/cutedsl/threadblock_swizzle.py (4)
- tilelang/carver/template/base.py (1): arch (156-163)
- tilelang/jit/__init__.py (3): jit (434-434), jit (438-448), jit (451-521)
- tilelang/carver/roller/rasterization.py (1): panel_width (14-16)
- tilelang/carver/roller/hint.py (1): stride (45-46)

tilelang/jit/adapter/nvrtc/wrapper.py (1)
- tilelang/jit/adapter/utils.py (1): pythonic_expr (156-286)

tilelang/jit/adapter/__init__.py (1)
- tilelang/jit/adapter/cutedsl/adapter.py (1): CuTeDSLKernelAdapter (21-341)

tilelang/jit/execution_backend.py (3)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target (114-115)
- tilelang/env.py (1): use_gemm_v1 (284-290)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available (33-79)

src/target/codegen_cutedsl.cc (2)
- src/target/codegen_py.cc (2): CheckOutermostParenthesesMatch (22-38), CheckOutermostParenthesesMatch (22-22)
- src/target/codegen_py.h (1): PrintExpr_ (119-121)

tilelang/contrib/cutedsl/mbar.py (2)
- tilelang/carver/template/base.py (1): arch (156-163)
- tilelang/language/builtin.py (2): mbarrier_expect_tx (280-289), mbarrier_arrive (262-277)

tilelang/contrib/cutedsl/ldsm.py (1)
- tilelang/language/proxy.py (1): make_tensor (278-279)

tilelang/jit/kernel.py (2)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target (114-115)
- tilelang/jit/adapter/cutedsl/adapter.py (3): CuTeDSLKernelAdapter (21-341), from_database (99-149), get_kernel_source (197-205)

tilelang/utils/target.py (1)
- tilelang/language/ast/ir.py (1): target (1677-1707)

tilelang/contrib/cutedsl/reduce.py (3)
- src/tl_templates/cuda/reduce.h (1): T (178-250)
- src/tl_templates/cuda/nvrtc_std.h (1): min (124-124)
- tilelang/language/proxy.py (1): make_tensor (278-279)

tilelang/engine/lower.py (1)
- tilelang/language/ast/ir.py (1): target (1677-1707)

tilelang/contrib/cutedsl/cpasync.py (1)
- src/tl_templates/cuda/copy.h (1): cp_async_wait (20-26)

src/target/codegen_cutedsl.h (1)
- src/target/codegen_py.h (1): codegen (26-246)

tilelang/cache/kernel_cache.py (1)
- tilelang/autotuner/param.py (4): _save_kernel_to_disk (176-260), _safe_write_file (158-166), _load_binary (152-155), _safe_write_executable (169-174)

tilelang/jit/adapter/cutedsl/libgen.py (3)
- tilelang/jit/adapter/libgen.py (1): LibraryGenerator (23-172)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target (114-115)
- tilelang/jit/adapter/cutedsl/wrapper.py (2): host_func (550-554), host_func (557-559)

tilelang/jit/adapter/wrapper.py (1)
- tilelang/jit/adapter/utils.py (4): is_cutedsl_target (114-115), pythonic_expr (156-286), parse_tma_descriptor_args (389-494), is_cuda_target (98-99)

src/target/codegen_py.cc (1)
- src/target/codegen_py.h (2): PrintStmt_ (113-113), PrintExpr_ (119-121)
🪛 Cppcheck (2.18.0)
src/target/rt_mod_cutedsl.cc
[error] 62-62: syntax error
(syntaxError)
🪛 Ruff (0.14.8)
tilelang/jit/adapter/utils.py
41-41: Unused function argument: annotation
(ARG001)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/contrib/cutedsl/__init__.py
9-9: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
18-18: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
23-23: from .mbar import * used; unable to detect undefined names
(F403)
24-24: from .cpasync import * used; unable to detect undefined names
(F403)
25-25: from .gemm_V1 import * used; unable to detect undefined names
(F403)
26-26: from .reduce import * used; unable to detect undefined names
(F403)
27-27: from .ldsm import * used; unable to detect undefined names
(F403)
28-28: from .math import * used; unable to detect undefined names
(F403)
29-29: from .threadblock_swizzle import * used; unable to detect undefined names
(F403)
64-64: bar_sync_ptx may be undefined, or defined from star imports
(F405)
129-129: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/execution_backend.py
79-81: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/contrib/cutedsl/mbar.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
12-12: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused function argument: timeout_ns
(ARG001)
tilelang/jit/adapter/cutedsl/__init__.py
6-11: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
13-13: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
16-16: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/contrib/cutedsl/ldsm.py
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
12-12: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/utils/target.py
104-104: Consider [*list(target_dict["keys"]), "cutedsl"] instead of concatenation
Replace with [*list(target_dict["keys"]), "cutedsl"]
(RUF005)
tilelang/contrib/cutedsl/cpasync.py
2-2: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
29-29: Prefer TypeError exception for invalid type
(TRY004)
29-29: Avoid specifying long messages outside the exception class
(TRY003)
35-35: Prefer TypeError exception for invalid type
(TRY004)
35-35: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/cutedsl/adapter.py
109-109: Unused class method argument: pass_configs
(ARG003)
197-197: Unused method argument: kernel_only
(ARG002)
253-253: Do not catch blind exception: Exception
(BLE001)
273-275: Avoid specifying long messages outside the exception class
(TRY003)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/cutedsl/wrapper.py
483-499: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
503-514: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
516-529: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
594-594: Avoid specifying long messages outside the exception class
(TRY003)
893-893: Unused method argument: code
(ARG002)
978-978: Consider [*inner_args, "stream: CUstream"] instead of concatenation
Replace with [*inner_args, "stream: CUstream"]
(RUF005)
989-989: Unused method argument: function_name
(ARG002)
1041-1041: Unused method argument: device_index
(ARG002)
1067-1067: Unpacked variable dtype is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1067-1067: Unpacked variable globalAddress is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1235-1235: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/cutedsl/checks.py
47-47: Do not catch blind exception: Exception
(BLE001)
55-57: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/contrib/cutedsl/gemm_V1.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
151-151: Unused function argument: use_wgmma
(ARG001)
152-152: Unused function argument: wg_wait
(ARG001)
192-192: Unused function argument: use_wgmma
(ARG001)
193-193: Unused function argument: wg_wait
(ARG001)
222-222: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
233-233: Unused method argument: stride_A
(ARG002)
233-233: Unused method argument: stride_B
(ARG002)
233-233: Unused method argument: offset_A
(ARG002)
233-233: Unused method argument: offset_B
(ARG002)
381-381: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
392-392: Unused method argument: stride_A
(ARG002)
392-392: Unused method argument: stride_B
(ARG002)
392-392: Unused method argument: offset_A
(ARG002)
392-392: Unused method argument: offset_B
(ARG002)
424-424: Unpacked variable tma_tensor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tilelang/cache/kernel_cache.py
301-301: Do not catch blind exception: Exception
(BLE001)
302-302: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
445-445: Do not catch blind exception: Exception
(BLE001)
446-446: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
452-452: Do not catch blind exception: Exception
(BLE001)
453-453: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
tilelang/jit/adapter/cutedsl/libgen.py
20-20: Starting a process with a partial executable path
(S607)
28-28: Starting a process with a partial executable path
(S607)
78-78: Unused method argument: timeout
(ARG002)
78-78: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
92-92: Unused noqa directive (non-enabled: SIM115)
Remove unused noqa directive
(RUF100)
99-99: subprocess call: check for execution of untrusted input
(S603)
100-108: Starting a process with a partial executable path
(S607)
115-115: Unused noqa directive (non-enabled: SIM115)
Remove unused noqa directive
(RUF100)
136-136: subprocess call: check for execution of untrusted input
(S603)
138-138: Avoid specifying long messages outside the exception class
(TRY003)
146-146: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/wrapper.py
896-896: Avoid specifying long messages outside the exception class
(TRY003)
🪛 Shellcheck (0.11.0)
maint/scripts/run_local_ci_test.sh
[warning] 21-21: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
Actionable comments posted: 18
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)
149-163: Missing `"cutedsl"` in `from_database` type hint.

The `execution_backend` parameter type hint is missing `"cutedsl"`, but lines 405-416 handle the cutedsl case in `_create_adapter_from_database`.

     def from_database(
         cls,
         func: PrimFunc,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         params: list[KernelParam],
         target: str | Target,
         target_host: str | Target,
         out_idx: list[int] | int,
    -    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"],
    +    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"],
         pass_configs: dict[str, Any] | None = None,
         compile_flags: list[str] | None = None,
     ):
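`Literal` hints are not enforced at runtime, so an omission like this only shows up in static checkers. One way to keep the hint and a runtime check in sync is a shared alias, sketched below. The alias `ExecutionBackend` and the helper `validate_backend` are hypothetical names introduced for illustration; only the backend strings come from the diff above.

```python
from typing import Literal, get_args

# Assumed alias; the member strings are the backends listed in the diff above.
ExecutionBackend = Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]


def validate_backend(name: str) -> str:
    """Hypothetical helper: reject names that are not members of the Literal."""
    allowed = get_args(ExecutionBackend)
    if name not in allowed:
        raise ValueError(f"unknown execution backend {name!r}; expected one of {allowed}")
    return name


chosen = validate_backend("cutedsl")
```

With a single alias, adding a backend means touching one definition rather than every signature that repeats the `Literal`.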
♻️ Duplicate comments (9)
tilelang/contrib/cutedsl/threadblock_swizzle.py (2)
25-38: Zero stride risk already flagged.

The stride calculation on line 35 can result in zero when the final panel has fewer than `gridDim.x` blocks, causing division by zero on line 36. This was flagged in a previous review.
41-54: Zero stride risk already flagged.

Same issue applies to `rasterization2DColumn` on line 51.

tilelang/contrib/cutedsl/reduce.py (2)
16-27: Unsupported third parameter for `nvvm.fmin` already flagged.

The NVVM `fmin` intrinsic only accepts two operands. Passing `c` as a third parameter will fail. This was flagged in a previous review.
30-41: Same issue applies to `nvvm.fmax`.

The `max` function has the same unsupported third parameter issue.

tilelang/contrib/cutedsl/ldsm.py (1)
69-72: Bug: `ptx_ldmatrix_x1_trans` will fail due to assertion in `_ldmatrix`.

This issue was already identified in a previous review. The `_ldmatrix` helper asserts `num in [2, 4]` at line 28, but `ptx_ldmatrix_x1_trans` calls `_ldmatrix` with `num=1`. This will cause an `AssertionError` at runtime.

The fix should mirror the inline handling used in `ptx_ldmatrix_x1` (lines 51-54):

     @dsl_user_op
     def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
         """Load 1 matrix (8x8) with transpose from shared memory"""
    -    _ldmatrix(smem_ptr, local_ptr, 1, True, loc, ip)
    +    out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip)
    +    out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
    +    out[0] = cute.Int32(out_i32)

src/target/codegen_py.h (1)
28-28: Avoid `using namespace` in header files.

This was flagged in a previous review. `using namespace tir;` in a header file pollutes the namespace for all translation units that include this header, potentially causing name collisions.

Instead, use explicit `tir::` qualification in declarations or move the `using` directive inside method implementations in the `.cc` file.

src/target/codegen_py.cc (3)
81-102: Add collision check for `gvar->name_hint` before reserving it (duplicate).

This matches the earlier review: the `name_hint` path can silently generate duplicate Python defs.

     } else {
    -  func_name_supply_->ReserveName(gvar->name_hint);
    -  return gvar->name_hint;
    +  auto name = gvar->name_hint;
    +  ICHECK(!func_name_supply_->ContainsName(name))
    +      << "Function " << gvar << " must use name hint " << name
    +      << ", but this name has already been used.";
    +  func_name_supply_->ReserveName(name);
    +  return name;
     }
188-217: `PrintType` switch fallthrough prints incorrect types (duplicate); also float16 path is inconsistent.

Missing `break` causes `intbool`-style output. Past review already flagged this.

     void CodeGenTileLangPY::PrintType(DataType type, std::ostream &os) { // NOLINT(*)
       if (type.is_float()) {
    -    if (type.bits() == 32 || type.bits() == 64) {
    +    if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) {
           os << "float";
    +    } else {
    +      LOG(FATAL) << "Unsupported float bit-width: " << type.bits();
         }
       } else if (type.is_uint()) {
         switch (type.bits()) {
         case 8:
         case 16:
         case 32:
         case 64: {
           os << "int";
    +      break;
         }
         case 1:
           os << "bool";
    +      break;
    +    default:
    +      LOG(FATAL) << "Unsupported uint bit-width: " << type.bits();
         }
       } else if (type.is_int()) {
         switch (type.bits()) {
         case 8:
         case 16:
         case 32:
         case 64: {
           os << "int";
    +      break;
         }
    +    default:
    +      LOG(FATAL) << "Unsupported int bit-width: " << type.bits();
         }
       } else {
         LOG(FATAL) << "Cannot convert type " << type << " to Python type";
       }
     }
522-536: For-loop bounds stream TVM IR directly instead of Python expr printing (duplicate).

Past review already flagged this; should use `PrintExpr_`.

     if (is_zero(op->min)) {
       PrintExpr_(op->extent, stream);
     } else {
    -  stream << op->min << ", "
    -         << arith::Analyzer().Simplify(op->extent + op->min);
    +  PrintExpr_(op->min, stream);
    +  stream << ", ";
    +  PrimExpr upper = arith::Analyzer().Simplify(op->extent + op->min);
    +  PrintExpr_(upper, stream);
     }
🧹 Nitpick comments (28)
tilelang/jit/adapter/utils.py (2)
41-54: Remove unused `annotation` parameter or use it in the pattern.

The `annotation` parameter defaults to `"@cute.kernel"` but is never used in the function body. The regex pattern is hardcoded to `@cute\.kernel`. Either remove the parameter or incorporate it into the pattern for consistency with sibling functions like `match_declare_kernel`.

    -def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
    +def match_declare_kernel_cutedsl(source: str) -> int:
         # Match decorator followed by function definition across lines
         # \s+ allows any whitespace including newlines between decorator and def
         pattern = r"@cute\.kernel\s+def\s+(\w+)"

Or, to use the parameter:

    -    pattern = r"@cute\.kernel\s+def\s+(\w+)"
    +    pattern = rf"{re.escape(annotation)}\s+def\s+(\w+)"
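As a quick check of the second option, the sketch below builds the pattern from the `annotation` argument. `find_kernel_decl` is a hypothetical helper, not the repository's `match_declare_kernel_cutedsl`, and returning the offset of the function name is an assumption about what the caller needs.

```python
import re


def find_kernel_decl(source: str, annotation: str = "@cute.kernel") -> int:
    """Hypothetical helper: offset of the function name that follows `annotation`."""
    # re.escape makes the "." in "@cute.kernel" match a literal dot only.
    pattern = rf"{re.escape(annotation)}\s+def\s+(\w+)"
    match = re.search(pattern, source)
    if match is None:
        raise ValueError(f"no function decorated with {annotation} found")
    return match.start(1)


src = "@cute.kernel\ndef main_kernel(a, b):\n    pass\n"
pos = find_kernel_decl(src)
name = src[pos:pos + len("main_kernel")]
```

Because the annotation is escaped, a lookalike such as `@cuteXkernel` no longer matches, which an unescaped `.` would have allowed.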
78-85: Regex may fail on signatures with nested parentheses.

The pattern `\([^)]*\)` matches until the first `)`, which breaks for signatures whose parameter list contains parentheses of its own, such as a tuple default in `def kernel(shape=(1, 2))`. Bracketed hints like `Callable[[int], int]` are not affected, since they contain no `)`. If such signatures are possible, consider a more robust approach (e.g., iterative parenthesis matching or a simpler heuristic).

For current use cases with CuTeDSL kernels, this is likely acceptable.
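The iterative parenthesis matching mentioned above could look roughly like the sketch below. A signature whose parameter list contains its own parentheses, such as a tuple default, is the case where `\([^)]*\)` stops early. `find_signature_end` is a hypothetical helper, and it deliberately ignores parentheses inside string literals or comments.

```python
def find_signature_end(source: str, open_idx: int) -> int:
    """Return the index of the ")" that closes the "(" at `open_idx`."""
    if source[open_idx] != "(":
        raise ValueError("open_idx must point at '('")
    depth = 0
    for i in range(open_idx, len(source)):
        ch = source[i]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                # Depth back to zero: this is the matching close.
                return i
    raise ValueError("unbalanced parentheses")


sig = "def kernel(shape=(1, 2), n=3): pass"
start = sig.index("(")
end = find_signature_end(sig, start)
params = sig[start + 1:end]
```

Here `params` covers the whole parameter list, whereas the regex would have ended at the `)` of the tuple default.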
tilelang/contrib/cutedsl/math.py (1)
11-44: Consider adding return type annotations for consistency.

The `divf` function has a return type annotation but the other functions (`exp`, `exp2`, `log`, etc.) do not. Adding consistent type hints improves maintainability.

    -def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True):
    +def exp(x: Union[TensorSSA, Numeric], fastmath: bool = True) -> Union[TensorSSA, Numeric]:
         return cute.math.exp(x, fastmath=fastmath)

tilelang/contrib/cutedsl/reduce.py (1)
155-159: Recursive AllReduce may hit stack limits for large thread counts.

The recursive call pattern `AllReduce(...).run(...)` builds up stack frames. For very large `threads` values (e.g., 1024), this results in ~10 recursive calls, which is fine, but the pattern could be refactored to an iterative one if deeper recursion is ever needed.

Also applies to: 180-184
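The "~10 calls" figure is log2(1024). A plain-Python model of an iterative version is sketched below. It is not CuTeDSL code, and it assumes the reduction is an XOR butterfly that halves the offset at each level, which the review does not spell out; `butterfly_all_reduce` is a hypothetical name.

```python
from typing import Callable, List, Tuple


def butterfly_all_reduce(values: List[int], op: Callable[[int, int], int]) -> Tuple[List[int], int]:
    """Host-side model: iterative XOR-butterfly all-reduce over a power-of-two lane count."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("lane count must be a non-zero power of two")
    lanes = list(values)
    levels = 0
    offset = n // 2
    while offset >= 1:
        # Each lane combines its value with the lane whose index differs in one bit.
        lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(n)]
        offset //= 2
        levels += 1
    return lanes, levels


lanes, levels = butterfly_all_reduce(list(range(1024)), max)
```

The loop runs once per level, so 1024 lanes take 10 iterations and no call stack grows with the thread count.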
tilelang/contrib/cutedsl/gemm_V1.py (3)
222-222: Annotate mutable class attribute with `ClassVar`.

The `_instances` dict is a mutable class attribute shared across all instances. Annotate it with `typing.ClassVar` for clarity and type checker compatibility.

    +from typing import ClassVar
    +
     class Gemm_SM80:
    -    _instances = {}  # cache instances for the same arguments
    +    _instances: ClassVar[dict] = {}  # cache instances for the same arguments

Apply the same change to `Gemm_SM90._instances`.

Also applies to: 381-381
424-424: Prefix unused variable with underscore.

The unpacked `tma_tensor` is never used. Prefix it with `_` to indicate it's intentionally unused.

    -tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
    +tma_atom, _tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
3-3: Remove unused `noqa` directive.

The `# noqa: F401` comment on line 3 is flagged as unused by Ruff.

    -import cutlass.utils as utils  # noqa: F401
    +import cutlass.utils as utils

src/target/codegen_cutedsl.h (1)
86-94: Consider documenting the eviction policy enum mapping.

The `eviction_policy_names_` vector maps indices to policy strings. A brief comment explaining the index-to-policy correspondence would improve maintainability.

    +// Maps eviction policy indices to their string representations
    +// Index 0: EVICT_NORMAL, Index 1: EVICT_FIRST, Index 2: EVICT_LAST
     std::vector<std::string> eviction_policy_names_ = {
         "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};

tilelang/jit/adapter/cutedsl/__init__.py (1)
6-16: Consider aligning `__all__` order with import order for consistency.

The `__all__` list places `check_cutedsl_available` last, while imports have it first. Consider either sorting `__all__` alphabetically or matching the import order for easier maintenance.

     __all__ = [
    +    "check_cutedsl_available",
         "CuTeDSLKernelAdapter",
    -    "TLCuTeDSLSourceWrapper",
         "CuTeDSLLibraryGenerator",
    -    "check_cutedsl_available",
    +    "TLCuTeDSLSourceWrapper",
     ]

tilelang/utils/target.py (2)
99-106: Consider using unpacking for cleaner list construction.

The logic correctly transforms `cutedsl*` target strings to CUDA-based Targets with the `cutedsl` key appended. Per static analysis suggestion, consider using unpacking for the list concatenation.

    -target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
    +target_dict["keys"] = [*target_dict["keys"], "cutedsl"]
127-132: Dead code: lines 130-131 are now unreachable.

The early return at lines 127-128 ensures that if `return_var` is a `Target`, the function returns immediately. This makes the `isinstance(return_var, Target)` check at line 130 unreachable.

Either remove the dead code or consolidate the logic:

     if isinstance(return_var, Target):
         return return_var
     if return_object:
    -    if isinstance(return_var, Target):
    -        return return_var
         return Target(return_var)
     return return_var

tilelang/jit/adapter/cutedsl/checks.py (1)
43-49: Consider narrowing the exception type.

The broad `except Exception` is documented as "best-effort", which is reasonable, but consider catching a more specific set such as `(ValueError, TypeError, AttributeError)` to avoid accidentally swallowing unexpected errors (`KeyboardInterrupt` is not caught by `except Exception`, but other programming errors might be).

     except _importlib_metadata.PackageNotFoundError:
         dist_version = None
    -except Exception:
    +except (ValueError, TypeError, AttributeError, OSError):
         # Metadata is best-effort; don't block internal/nonstandard installs here.
         dist_version = None

tilelang/contrib/cutedsl/__init__.py (3)
20-26: Consider using explicit imports instead of star imports.

Star imports make it harder to track what symbols are available and can lead to namespace pollution. Consider defining an `__all__` in each submodule and/or using explicit imports in this init file.

That said, for a facade/convenience module like this, star imports are acceptable if the submodules are well-maintained.
85-95: Consider setting `has_side_effects=False` for the pack operation.

The `mov.b32` instruction is a pure bitwise packing operation with no side effects. Setting `has_side_effects=True` may prevent useful compiler optimizations like dead code elimination or instruction reordering.

     packed_xy = llvm.inline_asm(
         Int32.mlir_type,
         [x_i16, y_i16],
         "mov.b32 $0, {$1, $2};",
         "=r,h,h",
    -    has_side_effects=True,
    +    has_side_effects=False,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,
         loc=loc,
         ip=ip,
     )
102-127: Consider extending AtomicAdd to support more data types.

The current implementation only supports `Float32` and `Int32`. For broader utility, consider adding support for `Float16` (in CUDA, the packed `__half2` atomicAdd is available on SM_60+ and the scalar `__half` variant on SM_70+) and `Int64`.

Would you like me to help extend this to support additional data types?
src/target/codegen_cutedsl.cc (2)
22-47: Consider extracting duplicated utility functions.

`CheckOutermostParenthesesMatch` and `RemoveOutermostParentheses` are duplicated from `src/target/codegen_py.cc` (lines 21-37 in the external context). Consider extracting these into a shared utility header to reduce code duplication.

    // Could create a shared header like codegen_utils.h
    // and move these functions there for reuse across codegen backends
91-109: Silent fallthrough for unsupported float8 variants.

Several float8 variants (e3m4, e4m3b11fnuz, e4m3fnuz, e5m2fnuz) have empty branches that silently fall through to the error at line 134. While this works, adding explicit `LOG(WARNING)` or comments would clarify intent.

     } else if (t.is_float8_e3m4()) {
    -  // unsupported
    +  // Float8E3M4 is unsupported in CuTeDSL
     } else if (t.is_float8_e4m3()) {

tilelang/jit/adapter/cutedsl/libgen.py (2)
78-78: Unused `timeout` parameter.

The `timeout` parameter is declared but never used in the method body. Either implement timeout handling for the subprocess calls or remove the parameter.

    -def compile_lib(self, timeout: float = None):
    +def compile_lib(self, timeout: float | None = None):
    +    # Note: timeout is reserved for future use with subprocess timeouts

Or if timeout support is needed:

    -result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True)
    +result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout)
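If the timeout is honored, the caller also has to decide what happens when it fires, since `subprocess.run` raises `subprocess.TimeoutExpired` after killing the child. A minimal sketch follows; `run_compile` is a hypothetical helper, and the "compiler" is just the current Python interpreter so the example runs anywhere.

```python
import subprocess
import sys
from typing import List, Optional


def run_compile(cmd: List[str], timeout: Optional[float] = None) -> subprocess.CompletedProcess:
    """Hypothetical helper: run a compile command and fail loudly on timeout."""
    try:
        return subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired as exc:
        # Surface the timeout as an error that names the offending command.
        raise RuntimeError(f"compile command timed out after {timeout}s: {cmd}") from exc


# Stand-in for a real compiler invocation.
result = run_compile([sys.executable, "-c", "print('ok')"], timeout=30)
```

With `timeout=None` the behavior is unchanged from the current code, so threading the parameter through is backward compatible.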
92-112: Temporary files not cleaned up on success.

The `NamedTemporaryFile` objects created with `delete=False` are never explicitly cleaned up after successful compilation. While this may be acceptable for debugging, consider cleaning up temp files in production paths.

     tma_src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
     if self.tma_cpp_init_code is not None:
         with open(tma_src.name, "w") as f:
             f.write(self.tma_cpp_init_code)
     # ... compilation code ...
    +
    +# Clean up temp source file
    +os.unlink(tma_src.name)

tilelang/contrib/cutedsl/ldsm.py (1)
11-12: Remove unused `noqa` directives.

Static analysis indicates the `# noqa: F401` comments on lines 11-12 are unnecessary since the rule is not enabled. These can be safely removed.

    -from cutlass._mlir import ir  # noqa: F401
    -from cutlass.cute.typing import Pointer, Int32  # noqa: F401
    +from cutlass._mlir import ir
    +from cutlass.cute.typing import Pointer, Int32

tilelang/cache/kernel_cache.py (2)
423-431: Required files validation includes launcher for CuTeDSL.

Good validation that ensures the launcher library exists before attempting to load a cached CuTeDSL kernel. This prevents confusing errors later in the loading process.

Minor: Consider using a generator expression instead of list comprehension in the `all()` call on line 431:

    -if not all([os.path.exists(file) for file in required_files]):
    +if not all(os.path.exists(file) for file in required_files):
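The practical difference is short-circuiting: with a generator, `all()` stops at the first missing file instead of probing every path first. The sketch below shows this with a recording stand-in for `os.path.exists`; `all_exist`, `fake_exists`, and the file name `missing.so` are hypothetical.

```python
def all_exist(paths, exists):
    """Return True if `exists(p)` holds for every path, stopping at the first miss."""
    return all(exists(p) for p in paths)


checked = []


def fake_exists(path):
    # Stand-in for os.path.exists that records which paths were probed.
    checked.append(path)
    return path != "missing.so"


ok = all_exist(["kernel.py", "missing.so", "kernel.cubin"], fake_exists)
```

After the call, `checked` holds only the first two paths, because the generator is never advanced past the failing one.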
301-302: Consider using `logging.exception` for better stack traces.

Per static analysis hints, `logging.exception` provides better debugging context than `logging.error` when catching exceptions, as it automatically includes the stack trace.

     except Exception as e:
    -    self.logger.error(f"Error saving kernel source code to disk: {e}")
    +    self.logger.exception("Error saving kernel source code to disk")

This pattern applies to similar locations at lines 314, 370, 380, 446, 453, and 462.

tilelang/contrib/cutedsl/cpasync.py (3)
tilelang/contrib/cutedsl/cpasync.py (3)
1-19: Remove ineffectivenoqa: F401directives (RUF100) and consider dropping unused constants.
Ruff reports thesenoqacomments are unused, so they add noise. AlsoBYTES_PER_TENSORMAP/BYTES_PER_POINTERaren’t referenced in this module.
21-37: Don't use `assert` for user input validation; raise `ValueError` and prefer `TypeError` for bad operand types.

`assert size in [16, 8, 4]` can be optimized away and produces weaker errors than explicit exceptions; also the invalid-type branches should be `TypeError` (per Ruff TRY004).

     def cp_async_gs(size, dst, dst_offset, src, src_offset):
    -    assert size in [16, 8, 4]
    +    if size not in (4, 8, 16):
    +        raise ValueError("cp_async_gs size must be one of {4, 8, 16} bytes")
         mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
         if isinstance(src, cute.Tensor):
             src_ptr = src.iterator
         elif isinstance(src, cute.Pointer):
             src_ptr = src
         else:
    -        raise ValueError(f"Invalid source type: {type(src)}")
    +        raise TypeError("cp_async_gs src must be cute.Tensor or cute.Pointer")
         if isinstance(dst, cute.Tensor):
             dst_ptr = dst.iterator
         elif isinstance(dst, cute.Pointer):
             dst_ptr = dst
         else:
    -        raise ValueError(f"Invalid destination type: {type(dst)}")
    +        raise TypeError("cp_async_gs dst must be cute.Tensor or cute.Pointer")
         cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)
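"Optimized away" refers to Python dropping `assert` statements when run with `-O`, so the check silently disappears. The sketch below isolates the validation into a hypothetical `check_copy_size` helper that does not depend on `cute` or `nvvm`; the allowed sizes are the ones from the diff above, and rejecting `bool` is an extra assumption.

```python
def check_copy_size(size: int) -> int:
    """Hypothetical validator mirroring the size check suggested above."""
    # bool is a subclass of int, so it is rejected explicitly (an assumption).
    if isinstance(size, bool) or not isinstance(size, int):
        raise TypeError(f"size must be an int, got {type(size).__name__}")
    if size not in (4, 8, 16):
        raise ValueError(f"size must be one of 4, 8, or 16 bytes, got {size}")
    return size


validated = check_copy_size(16)
```

Unlike an `assert`, both branches survive `python -O`, and callers can tell a wrong type from a wrong value by the exception class.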
170-198: Ensure `size` is a plain int before building `IntegerAttr` (defensive).

If `cp_size`/`src_size` come in as non-int "Int" wrappers, `ir.IntegerAttr.get(..., size)` can misbehave.

    -size = src_size if src_size else cp_size
    +size = int(src_size) if src_size is not None else int(cp_size)
     nvvm.cp_async_shared_global(
         dst=dst.llvm_ptr,
         src=src.llvm_ptr,
         size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size),

tilelang/jit/adapter/cutedsl/adapter.py (2)
197-206: `kernel_only` is unused (ARG002); either honor it or drop it.
216-255: Create cache dir and avoid catching `Exception` broadly.

If `_cache_path` is set but the directory doesn't exist, the copy will fail; also, catching `Exception` can hide real programmer errors.

     # Destination cubin path (in cache directory)
     dst_cubin_path = os.path.join(cache_path, "kernel.cubin")
    +os.makedirs(cache_path, exist_ok=True)
    @@
    -except Exception as e:
    +except (OSError, shutil.Error) as e:
         logger.warning(f"Failed to save cubin to cache: {e}")

tilelang/jit/adapter/cutedsl/wrapper.py (1)
483-529: Mark mapping dicts as `ClassVar` (RUF012) to avoid “mutable default on class” pitfalls.
Even if you never mutate them, annotating helps tooling and avoids future foot-guns.
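A minimal sketch of the annotation being suggested. The class name `SourceWrapperSketch`, the attribute `TYPE_MAP`, and its entries are hypothetical stand-ins, not the actual contents of `wrapper.py`:

```python
from typing import ClassVar, Dict


class SourceWrapperSketch:
    """Hypothetical stand-in for a wrapper class that holds lookup tables."""

    # ClassVar tells linters and type checkers that this dict is shared
    # class-level state, not a per-instance default.
    TYPE_MAP: ClassVar[Dict[str, str]] = {
        "float32": "Float32",  # illustrative entries only
        "int32": "Int32",
    }

    def map_type(self, dtype: str) -> str:
        # Read-only lookup; the table itself is never mutated here.
        return self.TYPE_MAP[dtype]


wrapper = SourceWrapperSketch()
mapped = wrapper.map_type("float32")
```

The runtime behavior is unchanged by the annotation; the benefit is purely for tooling such as Ruff and static type checkers.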
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (37)
- .github/workflows/ci.yml (1 hunks)
- 3rdparty/tvm (1 hunks)
- CMakeLists.txt (1 hunks)
- examples/gemm_fp8/example_tilelang_gemm_fp8.py (1 hunks)
- maint/scripts/run_local_ci_test.sh (1 hunks)
- src/target/codegen_cutedsl.cc (1 hunks)
- src/target/codegen_cutedsl.h (1 hunks)
- src/target/codegen_py.cc (1 hunks)
- src/target/codegen_py.h (1 hunks)
- src/target/rt_mod_cutedsl.cc (1 hunks)
- src/tl_templates/cuda/nvrtc_std.h (1 hunks)
- testing/python/jit/test_tilelang_jit_cutedsl.py (1 hunks)
- tilelang/cache/kernel_cache.py (12 hunks)
- tilelang/contrib/cutedsl/.gitignore (1 hunks)
- tilelang/contrib/cutedsl/__init__.py (1 hunks)
- tilelang/contrib/cutedsl/cpasync.py (1 hunks)
- tilelang/contrib/cutedsl/gemm_V1.py (1 hunks)
- tilelang/contrib/cutedsl/ldsm.py (1 hunks)
- tilelang/contrib/cutedsl/math.py (1 hunks)
- tilelang/contrib/cutedsl/mbar.py (1 hunks)
- tilelang/contrib/cutedsl/reduce.py (1 hunks)
- tilelang/contrib/cutedsl/threadblock_swizzle.py (1 hunks)
- tilelang/engine/lower.py (2 hunks)
- tilelang/jit/__init__.py (7 hunks)
- tilelang/jit/adapter/__init__.py (1 hunks)
- tilelang/jit/adapter/cutedsl/__init__.py (1 hunks)
- tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
- tilelang/jit/adapter/cutedsl/checks.py (1 hunks)
- tilelang/jit/adapter/cutedsl/libgen.py (1 hunks)
- tilelang/jit/adapter/cutedsl/wrapper.py (1 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
- tilelang/jit/adapter/nvrtc/wrapper.py (1 hunks)
- tilelang/jit/adapter/utils.py (4 hunks)
- tilelang/jit/adapter/wrapper.py (7 hunks)
- tilelang/jit/execution_backend.py (4 hunks)
- tilelang/jit/kernel.py (7 hunks)
- tilelang/utils/target.py (3 hunks)
✅ Files skipped from review due to trivial changes (2)
- src/tl_templates/cuda/nvrtc_std.h
- tilelang/contrib/cutedsl/.gitignore
🚧 Files skipped from review as they are similar to previous changes (6)
- 3rdparty/tvm
- .github/workflows/ci.yml
- tilelang/jit/adapter/nvrtc/wrapper.py
- tilelang/engine/lower.py
- testing/python/jit/test_tilelang_jit_cutedsl.py
- examples/gemm_fp8/example_tilelang_gemm_fp8.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/contrib/cutedsl/__init__.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
src/target/rt_mod_cutedsl.cc
🧬 Code graph analysis (16)
tilelang/jit/adapter/cutedsl/checks.py (1)
- tilelang/jit/__init__.py (2): compile(49-118), compile(350-376)
src/target/codegen_cutedsl.cc (4)
- src/target/codegen_py.cc (2): CheckOutermostParenthesesMatch(22-38), CheckOutermostParenthesesMatch(22-22)
- tilelang/language/builtin.py (6): get_mbarrier(92-101), no_set_max_nreg(202-204), warpgroup_arrive(292-298), warpgroup_commit_batch(301-307), warpgroup_fence_operand(462-618), set_max_nreg(164-179)
- tilelang/contrib/cutedsl/cpasync.py (1): tma_load(59-106)
- tilelang/language/math_intrinsics.py (5): ieee_add(149-172), ieee_sub(175-196), ieee_mul(199-220), ieee_fmaf(223-247), ieee_fdiv(309-330)
tilelang/jit/adapter/wrapper.py (3)
- tilelang/jit/adapter/utils.py (3): is_metal_target(110-111), is_cutedsl_target(114-115), pythonic_expr(156-286)
- tilelang/jit/adapter/cutedsl/wrapper.py (1): _pythonic_expr(565-567)
- tilelang/jit/adapter/nvrtc/wrapper.py (1): _pythonic_expr(271-276)
tilelang/contrib/cutedsl/mbar.py (2)
- tilelang/carver/template/base.py (1): arch(156-163)
- tilelang/language/builtin.py (2): mbarrier_expect_tx(280-289), mbarrier_arrive(262-277)
tilelang/contrib/cutedsl/cpasync.py (1)
- src/tl_templates/cuda/copy.h (1): cp_async_wait(20-26)
src/target/codegen_py.h (1)
- src/target/codegen_py.cc (34): AddFunction(49-64), AddFunction(49-49), Finish(66-71), Finish(66-66), GetFunctionName_(73-79), GetFunctionName_(73-73), RegisterFunction_(81-102), RegisterFunction_(81-82), InitFuncState_(104-109), InitFuncState_(104-104), PrintFunctionSignature_(111-132), PrintFunctionSignature_(111-113), ReserveKeywordsAsUnique_(134-181), ReserveKeywordsAsUnique_(134-134), PrintSSAAssign(183-186), PrintSSAAssign(183-184), PrintType(188-217), PrintType(188-189), VisitStmt_(471-487), VisitStmt_(471-471), VisitStmt_(489-491), VisitStmt_(489-489), CastFromTo_(592-600), CastFromTo_(592-593), PrintBinaryExpr_(602-620), PrintBinaryExpr_(602-605), PrintCallExtern_(634-647), PrintCallExtern_(634-638), GetBufferRef_(650-662), GetBufferRef_(650-652), RegisterHandleType_(664-672), RegisterHandleType_(664-665), HandleTypeMatch_(674-680), HandleTypeMatch_(674-675)
tilelang/utils/target.py (1)
- tilelang/language/ast/ir.py (1): target(1677-1707)
tilelang/contrib/cutedsl/ldsm.py (1)
- tilelang/language/proxy.py (1): make_tensor(278-279)
tilelang/jit/adapter/cutedsl/adapter.py (8)
- tilelang/engine/param.py (1): KernelParam(13-105)
- tilelang/jit/adapter/wrapper.py (4): wrap(144-145), wrap(858-878), wrap(885-912), host_func(557-567)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available(33-79)
- tilelang/jit/adapter/cutedsl/libgen.py (1): CuTeDSLLibraryGenerator(39-146)
- tilelang/utils/target.py (1): determine_target(63-133)
- tilelang/jit/adapter/base.py (2): BaseKernelAdapter(11-96), _post_init(95-96)
- tilelang/jit/adapter/cutedsl/wrapper.py (2): host_func(550-554), host_func(557-559)
- tilelang/jit/adapter/nvrtc/wrapper.py (2): host_func(260-264), host_func(267-269)
src/target/codegen_cutedsl.h (2)
- src/target/codegen_py.h (4): tvm(25-247), codegen(26-246), void(86-86), void(92-92)
- src/target/codegen_cutedsl.cc (42): PrintFuncDecorator_(60-63), PrintFuncDecorator_(60-61), PreFunctionBody_(65-70), PreFunctionBody_(65-65), PrintType(141-146), PrintType(141-142), VisitExpr_(148-152), VisitExpr_(148-149), VisitExpr_(154-189), VisitExpr_(154-155), VisitExpr_(191-215), VisitExpr_(191-192), VisitExpr_(217-224), VisitExpr_(217-218), VisitExpr_(225-228), VisitExpr_(225-226), VisitExpr_(229-232), VisitExpr_(229-230), VisitExpr_(265-629), VisitExpr_(265-266), VisitExpr_(631-691), VisitExpr_(631-632), VisitStmt_(693-746), VisitStmt_(693-693), VisitStmt_(748-801), VisitStmt_(748-748), VisitStmt_(803-850), VisitStmt_(803-803), VisitStmt_(852-889), VisitStmt_(852-852), VisitStmt_(891-915), VisitStmt_(891-891), VisitStmt_(917-938), VisitStmt_(917-917), PrintBinaryExpr_(1010-1019), PrintBinaryExpr_(1010-1013), PrintBinaryIntrinsic_(1021-1029), PrintBinaryIntrinsic_(1021-1023), PrintCallExtern_(1031-1157), PrintCallExtern_(1031-1035), GetBufferRef_(1188-1247), GetBufferRef_(1188-1190)
tilelang/jit/adapter/cutedsl/wrapper.py (4)
- src/target/codegen_py.h (1): tvm(25-247)
- tilelang/jit/adapter/utils.py (3): extract_python_func_declaration(57-85), pythonic_expr(156-286), parse_tma_descriptor_args(389-494)
- tilelang/jit/adapter/nvrtc/wrapper.py (4): host_func(260-264), host_func(267-269), _pythonic_expr(271-276), update_lib_code(519-573)
- src/tl_templates/cuda/nvrtc_std.h (1): max(126-126)
tilelang/jit/adapter/utils.py (1)
- src/target/codegen_py.h (1): tvm(25-247)
src/target/codegen_py.cc (2)
- src/target/codegen_cutedsl.cc (2): CheckOutermostParenthesesMatch(23-39), CheckOutermostParenthesesMatch(23-23)
- src/target/codegen_py.h (2): PrintStmt_(113-113), PrintExpr_(119-121)
tilelang/jit/adapter/__init__.py (1)
- tilelang/jit/adapter/cutedsl/adapter.py (1): CuTeDSLKernelAdapter(21-341)
tilelang/jit/adapter/cutedsl/__init__.py (4)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available(33-79)
- tilelang/jit/adapter/cutedsl/adapter.py (1): CuTeDSLKernelAdapter(21-341)
- tilelang/jit/adapter/cutedsl/wrapper.py (1): TLCuTeDSLSourceWrapper(473-1252)
- tilelang/jit/adapter/cutedsl/libgen.py (1): CuTeDSLLibraryGenerator(39-146)
tilelang/jit/execution_backend.py (3)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target(114-115)
- tilelang/env.py (1): use_gemm_v1(284-290)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available(33-79)
🪛 Cppcheck (2.18.0)
src/target/rt_mod_cutedsl.cc
[error] 62-62: syntax error
(syntaxError)
🪛 Ruff (0.14.8)
tilelang/jit/adapter/cutedsl/checks.py
47-47: Do not catch blind exception: Exception
(BLE001)
55-57: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/wrapper.py
896-896: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/cache/kernel_cache.py
301-301: Do not catch blind exception: Exception
(BLE001)
302-302: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
445-445: Do not catch blind exception: Exception
(BLE001)
446-446: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
452-452: Do not catch blind exception: Exception
(BLE001)
453-453: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
tilelang/contrib/cutedsl/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
9-9: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
20-20: from .mbar import * used; unable to detect undefined names
(F403)
21-21: from .cpasync import * used; unable to detect undefined names
(F403)
22-22: from .gemm_V1 import * used; unable to detect undefined names
(F403)
23-23: from .reduce import * used; unable to detect undefined names
(F403)
24-24: from .ldsm import * used; unable to detect undefined names
(F403)
25-25: from .math import * used; unable to detect undefined names
(F403)
26-26: from .threadblock_swizzle import * used; unable to detect undefined names
(F403)
61-61: bar_sync_ptx may be undefined, or defined from star imports
(F405)
126-126: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/contrib/cutedsl/mbar.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
12-12: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused function argument: timeout_ns
(ARG001)
tilelang/contrib/cutedsl/cpasync.py
2-2: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
29-29: Prefer TypeError exception for invalid type
(TRY004)
29-29: Avoid specifying long messages outside the exception class
(TRY003)
35-35: Prefer TypeError exception for invalid type
(TRY004)
35-35: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/utils/target.py
104-104: Consider [*list(target_dict["keys"]), "cutedsl"] instead of concatenation
Replace with [*list(target_dict["keys"]), "cutedsl"]
(RUF005)
tilelang/contrib/cutedsl/ldsm.py
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
12-12: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/contrib/cutedsl/gemm_V1.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
151-151: Unused function argument: use_wgmma
(ARG001)
152-152: Unused function argument: wg_wait
(ARG001)
192-192: Unused function argument: use_wgmma
(ARG001)
193-193: Unused function argument: wg_wait
(ARG001)
222-222: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
233-233: Unused method argument: stride_A
(ARG002)
233-233: Unused method argument: stride_B
(ARG002)
233-233: Unused method argument: offset_A
(ARG002)
233-233: Unused method argument: offset_B
(ARG002)
381-381: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
392-392: Unused method argument: stride_A
(ARG002)
392-392: Unused method argument: stride_B
(ARG002)
392-392: Unused method argument: offset_A
(ARG002)
392-392: Unused method argument: offset_B
(ARG002)
424-424: Unpacked variable tma_tensor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tilelang/jit/adapter/cutedsl/adapter.py
109-109: Unused class method argument: pass_configs
(ARG003)
197-197: Unused method argument: kernel_only
(ARG002)
253-253: Do not catch blind exception: Exception
(BLE001)
273-275: Avoid specifying long messages outside the exception class
(TRY003)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/cutedsl/wrapper.py
483-499: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
503-514: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
516-529: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
594-594: Avoid specifying long messages outside the exception class
(TRY003)
893-893: Unused method argument: code
(ARG002)
980-980: Consider [*inner_args, "stream: CUstream"] instead of concatenation
Replace with [*inner_args, "stream: CUstream"]
(RUF005)
991-991: Unused method argument: function_name
(ARG002)
1043-1043: Unused method argument: device_index
(ARG002)
1069-1069: Unpacked variable dtype is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1069-1069: Unpacked variable globalAddress is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1237-1237: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/utils.py
41-41: Unused function argument: annotation
(ARG001)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/adapter/cutedsl/__init__.py
6-11: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
13-13: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
16-16: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/adapter/cutedsl/libgen.py
20-20: Starting a process with a partial executable path
(S607)
28-28: Starting a process with a partial executable path
(S607)
78-78: Unused method argument: timeout
(ARG002)
78-78: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
92-92: Unused noqa directive (non-enabled: SIM115)
Remove unused noqa directive
(RUF100)
99-99: subprocess call: check for execution of untrusted input
(S603)
100-108: Starting a process with a partial executable path
(S607)
115-115: Unused noqa directive (non-enabled: SIM115)
Remove unused noqa directive
(RUF100)
136-136: subprocess call: check for execution of untrusted input
(S603)
138-138: Avoid specifying long messages outside the exception class
(TRY003)
146-146: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/execution_backend.py
79-81: Avoid specifying long messages outside the exception class
(TRY003)
🪛 Shellcheck (0.11.0)
maint/scripts/run_local_ci_test.sh
[warning] 21-21: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (35)
maint/scripts/run_local_ci_test.sh (1)
17-20: Two-phase pytest approach is well designed.
The separation of parallel tests (excluding CuTeDSL) from isolated GEMM v1 tests is sound and well documented. The environment variable is correctly scoped to only the CuTeDSL invocation, preventing unintended side effects on other tests. The single-worker flag appropriately prevents xdist contention on the GPU.
Also applies to: 22-22, 24-25
CMakeLists.txt (1)
217-220: LGTM!
The new source files for CuTeDSL codegen and runtime are correctly added to the CUDA build configuration. The placement within the `USE_CUDA` block is appropriate.
tilelang/jit/adapter/utils.py (2)
114-116: LGTM!
Clean implementation for detecting CuTeDSL targets by checking both the target kind and the presence of the "cutedsl" key.
156-158: LGTM!
Good addition of the `floor_div_op` parameter. The default `"/"` preserves backward compatibility for C/C++ codegen, while allowing Python codegen to pass `"//"` for integer division semantics.
Also applies to: 166-169, 240-240
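To illustrate why the operator needs to be configurable, here is a minimal sketch. `render_floor_div` is a hypothetical helper written for this note and is not the actual `pythonic_expr` implementation; only the `floor_div_op` parameter name comes from the change under review:

```python
def render_floor_div(lhs: str, rhs: str, floor_div_op: str = "/") -> str:
    """Render an integer division as source text for the target language.

    The default "/" suits C/C++ output, where "/" on integer operands is
    already integer division. Python output should pass "//" instead.
    """
    return f"({lhs} {floor_div_op} {rhs})"


# Text intended for a C/CUDA source file.
c_expr = render_floor_div("m", "128")

# Text intended for generated Python, where "/" would yield a float.
py_expr = render_floor_div("m", "128", floor_div_op="//")
```

One caveat worth keeping in mind: C `/` truncates toward zero while Python `//` floors, so the two forms agree only for non-negative operands, which is the usual case for index arithmetic.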
tilelang/contrib/cutedsl/gemm_V1.py (1)
138-177: LGTM on `gemm_sr` design decision.
The comment "wgmma doesn't support gemm_sr, only use SM80" is helpful. The unused `use_wgmma` and `wg_wait` parameters maintain API consistency with other gemm_* functions.
src/target/codegen_cutedsl.h (2)
21-25: LGTM on LOOP_UNROLL_THRESHOLD.
The threshold constant is well-documented with the rationale for its value. Being `const` gives it internal linkage, avoiding ODR violations across translation units.
27-94: LGTM on class structure.
The `CodeGenTileLangCuTeDSL` class is well-organized with clear separation of concerns:
- Protected override methods for CuTeDSL-specific expression/statement handling
- Virtual helper methods for specialized operations (vec load/store, buffer access)
- Private state for mbarrier naming and unroll factors
The inheritance from `CodeGenTileLangPY` is appropriate for the Python codegen path.
tilelang/contrib/cutedsl/math.py (1)
7-8: No action required on `cute.math._math_op` usage.
There is no public alternative for `divf` in CUTLASS CuTe. The `_math_op` wrapper is the necessary mechanism for integrating MLIR operations (like `arith.divf`) into the cute.math interface, making this usage appropriate and intentional rather than a dependency on an unstable internal API.
tilelang/jit/adapter/__init__.py (1)
7-7: LGTM!
The import follows the established pattern for re-exporting kernel adapters from this package's `__init__.py`. The `# noqa: F401` directive is consistent with other imports in this file.
tilelang/jit/adapter/nvrtc/adapter.py (1)
78-80: LGTM!
The change from tuple unpacking to dictionary-based access aligns with the updated `TLPyWrapper.wrap()` return signature. The NVRTC adapter correctly extracts only the fields it needs (`host_func` and `function_names`), maintaining backward compatibility while enabling the broader dictionary-based output required for CuTeDSL integration.
tilelang/jit/execution_backend.py (3)
77-90: LGTM! Validation logic is well-structured.
The helper function encapsulates the GEMM v1 and CuTeDSL availability checks cleanly. Converting `ImportError` to `ValueError` preserves the function's error contract while keeping the actionable message.
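The exception-conversion pattern praised above can be sketched as follows. Both function names are hypothetical stand-ins written for this note, and the stub always fails so the conversion path is exercised; the real helper in `execution_backend.py` may be structured differently:

```python
def check_optional_backend_available() -> None:
    """Hypothetical availability check that fails with ImportError."""
    raise ImportError("optional backend package is not installed")


def require_backend(name: str) -> str:
    """Return the backend name, or raise ValueError if it cannot be used."""
    try:
        check_optional_backend_available()
    except ImportError as e:
        # Callers of this function expect ValueError for bad backend
        # choices, so convert while chaining the original cause and
        # keeping its message.
        raise ValueError(f"Backend {name!r} is unavailable: {e}") from e
    return name
```

Chaining with `from e` keeps the original `ImportError` reachable through `__cause__`, so the traceback still shows why the backend could not be loaded.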
93-107: LGTM!
The auto-resolution logic correctly prioritizes CuTeDSL when the target has the `cutedsl` key, validating prerequisites before returning. The early return prevents falling through to the generic CUDA default.
123-126: LGTM!
Explicit `cutedsl` backend selection also validates GEMM v1 and CuTeDSL availability, ensuring consistent prerequisite checking regardless of how the backend is selected.
tilelang/contrib/cutedsl/mbar.py (2)
22-30: LGTM!
The `mbarrier_cp_async_arrive` implementation correctly uses the `@dsl_user_op` decorator and properly forwards the `loc` and `ip` parameters to the underlying NVVM intrinsic.
33-38: LGTM!
The fence functions correctly delegate to the cutlass arch implementations with appropriate parameters.
src/target/rt_mod_cutedsl.cc (3)
9-39: LGTM!
The `ExtractFuncInfo` function correctly extracts function metadata including parameter types, grid constant handling, and launch parameter tags.
62-66: LGTM!
The FFI registration follows TVM conventions. The Cppcheck syntax error is a false positive due to the TVM macro expansion.
41-60: Confirmed: This follows the established pattern for "without compile" code paths.
The `("ptx", "ptx")` parameters are intentional. The same format is used identically in `BuildTileLangCUDAWithoutCompile` in `rt_mod_cuda.cc` (line 102), confirming this is the expected behavior for the "without compile" code generation path. While the generated code is CuTeDSL Python rather than actual PTX, the format strings appear to be placeholders or conventionally used for the intermediate representation handled by post-processing callbacks.
tilelang/jit/adapter/cutedsl/checks.py (1)
67-79: LGTM!
The filesystem-based check for the "cute" submodule is a smart optimization to avoid expensive imports while still verifying CuTe support is available. The check correctly handles both package (`cute/`) and module (`cute.py`) cases.
tilelang/jit/kernel.py (3)
320-333: LGTM!
The CuTeDSL adapter instantiation follows the established pattern of other backends, with proper target validation via `is_cutedsl_target`.
405-416: LGTM!
The database loading path for CuTeDSL correctly delegates to `CuTeDSLKernelAdapter.from_database` with the expected parameters.
467-476: LGTM!
The source access methods correctly include `"cutedsl"` in the set of backends that delegate to the adapter for source retrieval.
tilelang/jit/adapter/wrapper.py (3)
885-912: LGTM! Clean dispatch pattern for CuTeDSL target.
The conditional routing based on `is_cutedsl_target` and the dictionary-based return contract provide a clean interface for the new backend. The structure with `host_func`, `function_names`, and launcher-related fields aligns well with the CuTeDSL adapter requirements.
One minor note: the error message on line 896 could be moved to a custom exception class per TRY003, but this is acceptable for now given it's a configuration error that should be rare.
203-205: Good fix for C/CUDA integer division operator.
The comment correctly explains that C/C++ uses `/` for integer division, not `//`. This ensures the generated CUDA source compiles correctly.
333-376: Good typo fix: `tma_descripter_init` → `tma_descriptor_init`.
Renaming the variable from `tma_descripter_init` to `tma_descriptor_init` improves code clarity.
src/target/codegen_cutedsl.cc (2)
265-629: Comprehensive intrinsic handling with clear unsupported-op fallbacks.
The extensive handling of TileLang intrinsics with explicit `LOG(FATAL)` for unsupported operations provides clear error messages. The pattern of delegating to `CodeGenTileLangPY::VisitExpr_` for unhandled cases is appropriate for the inheritance hierarchy.
852-889: Well-structured unroll loop handling.
The `ForNode` visitor correctly handles:
- Non-unrolled loops by delegating to parent
- Unroll factor from attributes
- `range_constexpr` for small constant loops
- Full unroll flag when appropriate
This aligns with CuTeDSL's Python-based loop constructs.
tilelang/jit/__init__.py (1)
430-430: LGTM! Consistent ExecutionBackend type extension.
The `cutedsl` backend is consistently added across all relevant type annotations and function signatures. The type alias `ExecutionBackend` centralizes the definition, making future updates easier.
src/target/codegen_py.h (1)
32-48: Well-documented public API for CodeGenTileLangPY.
The public interface with `AddFunction` and `Finish` methods is clean and well-documented. The Doxygen comments clearly explain the purpose of each method.
tilelang/cache/kernel_cache.py (2)
31-34: LGTM! Clear constant definitions for CuTeDSL artifacts.
The new constants (`LAUNCHER_LIB_PATH`, `LAUNCHER_CPP_PATH`, `CUTEDSL_CUBIN_PATH`) are well-named and align with the CuTeDSL artifact structure.
221-225: Good: Cache path propagation for CuTeDSL cubin saving.
Setting `kernel.adapter._cache_path` enables the CuTeDSL adapter to save the cubin after first execution. The conditional check for `execution_backend == "cutedsl"` is appropriate.
tilelang/contrib/cutedsl/cpasync.py (3)
39-43: Conditional cp.async wrapper looks fine.
[cute.jit] wrapper is minimal and keeps the “conditional” behavior localized.
149-168: TMA store arrive/wait wrappers look consistent with the intended PTX mapping.
200-212: Prefetch helper looks OK; consider arch gating only if NVVM op is arch-specific.
tilelang/jit/adapter/cutedsl/adapter.py (1)
338-341: `prim_func` accessor is clean and keeps adapter logic decoupled from module structure.
|
Awesome Contributions! |
Actionable comments posted: 5
♻️ Duplicate comments (3)
src/target/codegen_py.cc (3)
81-102: Collision check needed for name_hint fallback path.
As noted in previous reviews, the name_hint fallback at line 97 needs the same collision validation as the global_symbol path. This prevents duplicate function names when multiple functions share the same name_hint.
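The idea behind this request can be sketched in Python. `FunctionRegistrySketch` is a hypothetical illustration of applying one collision check to both naming paths; it is not a port of the C++ `RegisterFunction_`, whose exact logic is not shown in this review:

```python
from typing import Optional, Set


class FunctionRegistrySketch:
    """Hypothetical registry that rejects duplicate function names."""

    def __init__(self) -> None:
        self._used: Set[str] = set()

    def register(self, name_hint: str, global_symbol: Optional[str] = None) -> str:
        # Prefer the explicit global symbol; otherwise fall back to the hint.
        name = global_symbol if global_symbol is not None else name_hint
        # The same collision check runs for both paths, so two functions
        # sharing a name_hint cannot silently overwrite each other.
        if name in self._used:
            raise ValueError(f"Function name {name!r} is already registered")
        self._used.add(name)
        return name


registry = FunctionRegistrySketch()
first = registry.register("main")
renamed = registry.register("main", global_symbol="main_kernel")
```

A second `registry.register("main")` call at this point would raise `ValueError`, which is the fail-fast behavior the comment asks for on the fallback path.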
188-217: Missing break statements and type coverage.
As noted in previous reviews, the switch cases lack break statements causing fallthrough, and float16/default cases are missing. These must be fixed to generate correct Python type annotations.
522-536: Inconsistent expression printing in for-loop bounds.
As noted in previous reviews, lines 529-530 stream expressions directly instead of using `PrintExpr_`, which can produce TVM IR representation instead of valid Python syntax.
🧹 Nitpick comments (2)
src/target/codegen_cutedsl.cc (2)
1067-1080: Add validation for unsupported nested template syntax.
The comment states "Nested '<' is not supported," but there's no validation to detect or reject this case. If nested templates are encountered, the code will silently produce incorrect output.
Add a check to detect and fail fast on unsupported input:

     // Replace "<...>" with "(...)". Nested "<" is not supported
     {
       auto pos_left = global_symbol_str.find('<');
       while (pos_left != std::string::npos) {
         auto pos_right = global_symbol_str.find('>', pos_left + 1);
         if (pos_right != std::string::npos) {
           auto args =
               global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
    +      // Check for nested '<' which is not supported
    +      if (args.find('<') != std::string::npos) {
    +        LOG(FATAL) << "Nested template arguments are not supported: " << global_symbol_str;
    +      }
           ReplaceAll(args, "true", "True");
           ReplaceAll(args, "false", "False");
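For readers who want to experiment with the fail-fast behavior outside the C++ codegen, here is a rough Python analogue. `template_to_call` is a hypothetical function written for this note; it mirrors only the part of the logic visible in the suggested diff, and the rest of the real routine may differ:

```python
def template_to_call(symbol: str) -> str:
    """Rewrite 'name<args>' as 'name(args)', rejecting nested '<'."""
    out = symbol
    left = out.find("<")
    while left != -1:
        right = out.find(">", left + 1)
        if right == -1:
            break  # unmatched '<': leave the remainder untouched
        args = out[left + 1:right]
        # Fail fast instead of silently emitting a malformed call.
        if "<" in args:
            raise ValueError(f"Nested template arguments are not supported: {symbol}")
        # Translate C++ boolean literals to Python ones.
        args = args.replace("true", "True").replace("false", "False")
        out = out[:left] + "(" + args + ")" + out[right + 1:]
        # Resume the search after the text that was just rewritten.
        left = out.find("<", left + len(args) + 2)
    return out


flat = template_to_call("foo<1, true>")
```

Without the nested check, an input such as `foo<bar<1>>` would pair the first `<` with the first `>` and produce mismatched brackets, which is the silent failure the comment warns about.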
778-778: Use `LOG(FATAL)` for consistency with other error paths.
Line 778 uses `ICHECK(false)` to report an unsupported scope, while lines 758, 760, and 795 use `LOG(FATAL)` for the same purpose.
For consistency with the rest of this file, consider:

    -      ICHECK(false) << "Unsupported scope: " << scope;
    +      LOG(FATAL) << "Unsupported scope: " << scope;
📜 Review details
📒 Files selected for processing (2)
- src/target/codegen_cutedsl.cc (1 hunks)
- src/target/codegen_py.cc (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
src/target/codegen_py.cc
🧬 Code graph analysis (2)
src/target/codegen_cutedsl.cc (1)
- src/target/codegen_py.cc (4): CheckOutermostParenthesesMatch(22-38), CheckOutermostParenthesesMatch(22-22), RemoveOutermostParentheses(40-46), RemoveOutermostParentheses(40-40)
src/target/codegen_py.cc (1)
- src/target/codegen_py.h (2): PrintStmt_(113-113), PrintExpr_(119-121)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (10)
src/target/codegen_py.cc (10)
22-47: LGTM!
The parentheses matching and removal logic is correct. The algorithm properly tracks nesting depth and validates that the outermost parentheses enclose the entire expression.
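The depth-tracking check described above can be sketched in Python as follows. The two function names are hypothetical renderings of `CheckOutermostParenthesesMatch` and `RemoveOutermostParentheses`; this is an illustration of the technique, not a line-by-line port of the C++ code:

```python
def outermost_parens_match(expr: str) -> bool:
    """True if the first '(' pairs with the last ')' of the expression."""
    if len(expr) < 2 or expr[0] != "(" or expr[-1] != ")":
        return False
    depth = 0
    for i, ch in enumerate(expr):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            # Closing back to depth 0 before the end means the opening
            # parenthesis was paired with an inner ')', e.g. "(a) + (b)".
            if depth == 0 and i != len(expr) - 1:
                return False
    return depth == 0


def remove_outermost_parens(expr: str) -> str:
    """Strip one enclosing pair only when it wraps the whole expression."""
    return expr[1:-1] if outermost_parens_match(expr) else expr


stripped = remove_outermost_parens("(a + b)")
untouched = remove_outermost_parens("(a) + (b)")
```

Checking only the first and last characters would wrongly strip `(a) + (b)` down to `a) + (b`, which is why the depth counter is needed.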
49-79: LGTM!
The function registration and code generation flow is well-structured. State management and scope handling are appropriate for a code generator.
104-181: LGTM!
The function state initialization, signature generation, and keyword reservation are well-implemented. The keyword list is comprehensive and includes both Python keywords and type names that could cause conflicts.
183-186: LGTM!
The SSA assignment printing correctly removes unnecessary outermost parentheses for cleaner output.
219-255: LGTM!
The expression visitors for variables, integers, floats, and casts are correctly implemented. Note that the float16 handling at lines 247-250 depends on the fix for `PrintType` flagged in previous reviews.
262-366: LGTM!
The cast, arithmetic, logical, comparison, select, and ramp expression visitors are correctly implemented and generate appropriate Python syntax.
368-451: LGTM!
The call expression visitor comprehensively handles built-in operations, external calls, and function references. The fallthrough to `LOG(FATAL)` ensures unhandled operations are caught during development.
453-520: LGTM!
Buffer load/store, let statements, and allocation are correctly implemented. The Python list allocation `[None] * size` is appropriate for creating mutable storage.
538-579: LGTM!
The while, if/else, sequence, and evaluate statement visitors correctly generate Python control flow structures.
592-680: LGTM!
The utility methods for type casting, binary expressions, external calls, buffer references, and handle type tracking are correctly implemented and provide appropriate validation.
Actionable comments posted: 0
🧹 Nitpick comments (1)
requirements-test-cuda.txt (1)
10-11: Verify nvidia-cutlass-dsl version and consider clarifying GEMM v1 requirement.
The nvidia-cutlass-dsl version 4.3.1 is compatible with CUDA 12.x, so the dependency constraint is appropriate. However, per the PR objectives, CuTeDSL currently requires GEMM v1, and users must set `TILELANG_USE_GEMM_V1=1` before importing tilelang. Consider revising the comment to document this prerequisite:

    -# CuTeDSL (CUTLASS Python DSL with CuTe support)
    +# CuTeDSL (CUTLASS Python DSL with CuTe support; requires GEMM v1 - set TILELANG_USE_GEMM_V1=1)

This would help users understand the prerequisite when reviewing test requirements.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- requirements-test-cuda.txt (1 hunk)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
kurisu6912
left a comment
Thanks @lucifer1004 for your great contribution to tilelang!
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/target/codegen_cutedsl.cc (1)
843-843: Define the missing constant `LOOP_UNROLL_THRESHOLD`.
The constant `LOOP_UNROLL_THRESHOLD` is referenced but not defined in this file or in any included headers. This will cause a compilation error.
Add the constant definition at the top of the file or include the appropriate header:

     namespace tvm {
     namespace codegen {
     namespace {
    +constexpr int LOOP_UNROLL_THRESHOLD = 16; // Adjust value as appropriate

     void ReplaceAll(std::string &str, const std::string &from,
🧹 Nitpick comments (1)
src/target/codegen_cutedsl.cc (1)
25-32: Consider extracting `ReplaceAll` to shared utilities.
The `ReplaceAll` function is a general-purpose string utility that could be shared across codegen modules.
If this pattern is used in multiple codegen files, consider adding it to `codegen_utils.h/cc` alongside the parentheses utilities for consistency.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- CMakeLists.txt (1 hunk)
- src/target/codegen_cutedsl.cc (1 hunk)
- src/target/codegen_py.cc (1 hunk)
- src/target/codegen_utils.cc (1 hunk)
- src/target/codegen_utils.h (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (2)
- CMakeLists.txt
- src/target/codegen_py.cc
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/target/codegen_cutedsl.cc
🧬 Code graph analysis (1)
src/target/codegen_utils.h (1)
src/target/codegen_utils.cc (4)
- CheckOutermostParenthesesMatch (11-27)
- CheckOutermostParenthesesMatch (11-11)
- RemoveOutermostParentheses (29-35)
- RemoveOutermostParentheses (29-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (7)
src/target/codegen_cutedsl.cc (4)
555-555: Verify fastmath configuration from pass config.
The `fastmath=True` parameter is hardcoded. Consider checking the pass configuration to allow users to control this behavior, as suggested in previous reviews.
Based on past review comments, this should respect the pass configuration for fastmath settings.
622-624: Clarify the logic for checking buffer reference suffix.
The code checks if `ref.back() == ')'` to determine whether to append `.load()`. Based on the comment block at lines 1155-1163, this is intentional: only form (4) ending with ')' represents a full tensor that needs `.load()`, while other forms represent scalar accesses that don't.
This logic correctly distinguishes between tensor forms (requiring `.load()`) and scalar accesses. The comment block at lines 1155-1163 documents this design clearly.
156-156: Scientific notation is appropriate for Python codegen.
The use of `std::scientific` instead of hexfloat is intentional for CuTeDSL, which generates Python code. Python doesn't support C-style hexfloat literals, so decimal scientific notation is the correct choice here.
The comment at line 154 explicitly states "For CuTeDSL, use Python decimal format instead of C hexfloat", which aligns with the implementation.
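To illustrate the point about literal formats, here is a small sketch in Python (not the codegen's actual C++ code). The helper name `emit_float_literal` is hypothetical; it assumes a double-precision value and uses 17 significant digits so the emitted text round-trips.

```python
import ast


def emit_float_literal(value: float) -> str:
    """Hypothetical helper: format a float as a decimal scientific literal."""
    # One leading digit plus 16 fractional digits = 17 significant digits,
    # which is enough to round-trip an IEEE-754 double.
    return f"{value:.16e}"


literal = emit_float_literal(0.1)
round_tripped = ast.literal_eval(literal)

# float.hex() produces C-style hexfloat text, which is not a Python literal.
hex_text = (3.0).hex()
try:
    ast.literal_eval(hex_text)
    hex_is_literal = True
except (SyntaxError, ValueError):
    hex_is_literal = False
```

The decimal form parses back to the same value, while the hexfloat text is rejected by the Python parser, which matches the rationale given in the comment above.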
6-6: Good: Utility duplication resolved.
The inclusion of `codegen_utils.h` addresses the previous concern about duplicated utility functions (`CheckOutermostParenthesesMatch` and `RemoveOutermostParentheses`).
src/target/codegen_utils.h (1)
1-34: LGTM: Well-structured utility header.
The header file follows best practices with proper include guards, clear documentation, and a clean API. The functions address the code duplication concern raised in previous reviews.
src/target/codegen_utils.cc (2)
11-27: LGTM: Correct parentheses matching logic.
The implementation correctly validates that the first and last characters are matching parentheses that enclose the entire string. The balance counter approach ensures that intermediate closing parentheses don't prematurely match.
29-35: LGTM: Safe parentheses removal.
The function safely removes outermost parentheses only when they match, preventing incorrect transformations.
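The balance-counter idea described above can be sketched in Python as follows. This is an illustrative port under assumptions, not the C++ source in `codegen_utils.cc`; the snake_case function names are hypothetical.

```python
def check_outermost_parentheses_match(s: str) -> bool:
    """Return True only if the first '(' is closed by the final ')'."""
    if len(s) < 2 or s[0] != "(" or s[-1] != ")":
        return False
    depth = 0
    for i, ch in enumerate(s):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            # The opening paren closed before the end: "(a) + (b)" style input.
            if depth == 0 and i != len(s) - 1:
                return False
    return depth == 0


def remove_outermost_parentheses(s: str) -> str:
    """Strip one enclosing pair only when it wraps the whole string."""
    return s[1:-1] if check_outermost_parentheses_match(s) else s
```

For example, `"(a + b)"` becomes `"a + b"`, while `"(a) + (b)"` is left untouched because its first parenthesis closes before the end of the string.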
Actionable comments posted: 8
♻️ Duplicate comments (4)
src/target/codegen_py.cc (3)
47-68: Prevent duplicate Python `def` names in the `name_hint` path (add collision check).
The `global_symbol` path checks `ContainsName`, but the `name_hint` fallback reserves blindly (Line 63-65). If two `GlobalVar`s share the same `name_hint`, you’ll emit duplicate function definitions.

     auto function_name = [&]() -> ffi::String {
       if (auto global_symbol =
               func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
         auto name = global_symbol.value();
         ICHECK(!func_name_supply_->ContainsName(name))
             << "Function " << gvar << " must use global symbol " << name
             << ", but this name has already been used.";
         func_name_supply_->ReserveName(name);
         return name;
       } else {
    -    func_name_supply_->ReserveName(gvar->name_hint);
    -    return gvar->name_hint;
    +    auto name = gvar->name_hint;
    +    ICHECK(!func_name_supply_->ContainsName(name))
    +        << "Function " << gvar << " must use name_hint " << name
    +        << ", but this name has already been used.";
    +    func_name_supply_->ReserveName(name);
    +    return name;
       }
     }();
235-238: Escape Python string literals in `StringImm` emission.
Raw `op->value` can contain `"`, `\`, newlines, etc., generating invalid Python.

     void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op,
                                        std::ostream &os) { // NOLINT(*)
    -  os << "\"" << op->value << "\"";
    +  os << "\"";
    +  for (char c : std::string(op->value)) {
    +    switch (c) {
    +    case '\"': os << "\\\""; break;
    +    case '\\': os << "\\\\"; break;
    +    case '\n': os << "\\n"; break;
    +    case '\r': os << "\\r"; break;
    +    case '\t': os << "\\t"; break;
    +    default: os << c; break;
    +    }
    +  }
    +  os << "\"";
     }
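The escaping rule suggested above can be checked quickly in Python. This is a minimal sketch of the same character mapping, assuming only the five characters handled in the diff; `quote_python_string` is a hypothetical name, not a function from the PR.

```python
import ast

# Same five cases as the suggested C++ switch statement.
_ESCAPES = {'"': '\\"', "\\": "\\\\", "\n": "\\n", "\r": "\\r", "\t": "\\t"}


def quote_python_string(value: str) -> str:
    """Hypothetical helper: emit `value` as a double-quoted Python literal."""
    return '"' + "".join(_ESCAPES.get(c, c) for c in value) + '"'


raw = 'say "hi"\\ then\nnewline\ttab'
quoted = quote_python_string(raw)

# The naive form (no escaping) is not parseable once the text has a newline.
try:
    ast.literal_eval('"' + raw + '"')
    naive_parses = True
except SyntaxError:
    naive_parses = False
```

Parsing `quoted` with `ast.literal_eval` recovers the original text, whereas the unescaped form fails to parse.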
561-570: Escape assert message strings (reuse the same escaping as `StringImm`).
Same escaping bug as `VisitExpr_(StringImmNode)`.

     void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) {
       std::string cond = PrintExpr_(op->condition);
       PrintIndent();
       if (const auto *str = op->message.as<StringImmNode>()) {
    -    stream << "assert " << cond << ", \"" << str->value << "\"\n";
    +    stream << "assert " << cond << ", \"";
    +    for (char c : std::string(str->value)) {
    +      switch (c) {
    +      case '\"': stream << "\\\""; break;
    +      case '\\': stream << "\\\\"; break;
    +      case '\n': stream << "\\n"; break;
    +      case '\r': stream << "\\r"; break;
    +      case '\t': stream << "\\t"; break;
    +      default: stream << c; break;
    +      }
    +    }
    +    stream << "\"\n";
       } else {
         stream << "assert " << cond << "\n";
       }
       PrintStmt_(op->body);
     }
src/target/codegen_cutedsl.cc (1)
241-605: Keep an eye on `VisitExpr_(CallNode*)` size/complexity (already flagged previously).
This is still a large monolithic dispatcher; any further additions will get risky to maintain and test.
🧹 Nitpick comments (1)
src/target/codegen_cutedsl.cc (1)
47-113: `DTypeToString`: `t.is_scalar()` vs `t.is_void()` is inconsistent; verify TVM `DataType` semantics.
You `ICHECK(t.is_scalar())` and then handle `t.is_void()`. If `void` is not considered scalar, this is unreachable; if it is, OK. Consider either removing the scalar check for `void`, or rewriting to make intent explicit.
Also applies to: 116-121
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/target/codegen_cutedsl.cc (1 hunk)
- src/target/codegen_py.cc (1 hunk)
- src/target/codegen_utils.cc (1 hunk)
- src/target/codegen_utils.h (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/target/codegen_utils.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/target/codegen_cutedsl.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
Actionable comments posted: 3
♻️ Duplicate comments (9)
tilelang/jit/adapter/cutedsl/adapter.py (3)
146-153: `from_database()` must set `adapter.libpath` (and should align the “source” field it reads) to avoid runtime AttributeError.
Right now `_save_cubin_to_cache_if_needed()` reads `self.libpath` (Line 241) but `from_database()` never sets it; also `open(kernel_lib_path).read()` is stored in `kernel_global_source` (Line 151) while `get_kernel_source()` returns `device_kernel_source` (Line 211), so this looks inconsistent.

    @@
             adapter.lib_generator.assign_compile_flags(compile_flags)
             adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
    -        adapter.kernel_global_source = open(kernel_lib_path).read()
    +        adapter.libpath = kernel_lib_path
    +        adapter.device_kernel_source = open(kernel_lib_path).read()
             adapter.pymodule = adapter.lib_generator.pymodule
192-200: Guard `buffer.strides is None` in `_process_dynamic_symbolic()`.
TVM can represent compact buffers with `strides=None`; iterating it will crash.

    @@
             for i, param in enumerate(params):
                 if param not in buffer_map:
                     continue
                 buffer = buffer_map[param]
    +            if buffer.strides is None:
    +                continue
                 for j, stride in enumerate(buffer.strides):
                     if isinstance(stride, tir.Var):
                         unique_push_back(stride, (1, i, j))
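A stand-alone sketch of the guard, using stand-in types rather than TVM objects: `FakeBuffer` is a hypothetical substitute for `tir.Buffer`, and plain strings play the role of `tir.Var`. It only illustrates the control flow of the suggestion, not the adapter's real data model.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Union


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for tir.Buffer; only `strides` matters here."""
    strides: Optional[Sequence[Union[int, str]]]


def collect_symbolic_strides(
    buffers: Sequence[FakeBuffer],
) -> Dict[str, Tuple[int, int, int]]:
    """Map each symbolic stride to (ref_id=1, buffer index, dim index)."""
    found: Dict[str, Tuple[int, int, int]] = {}
    for i, buf in enumerate(buffers):
        if buf.strides is None:
            # Nothing to iterate; skipping avoids a TypeError on enumerate(None).
            continue
        for j, stride in enumerate(buf.strides):
            if isinstance(stride, str):  # strings stand in for tir.Var
                found.setdefault(stride, (1, i, j))  # keep the first occurrence
    return found


symbols = collect_symbolic_strides(
    [FakeBuffer(None), FakeBuffer(["s0", 1]), FakeBuffer(["s0", "s1"])]
)
```

Here the first buffer is skipped, `s0` keeps its first location, and `s1` is recorded from the third buffer.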
262-333: Argument marshalling is unsafe: scalar params + PrimFunc param-index mapping can index the wrong `ins[]` (or go OOB).
`dynamic_symbolic_map` stores PrimFunc param indices (Line 165-166), but `_wrap_forward_from_prebuild_lib()` indexes into `ins[...]` (inputs-only) at Lines 293-317; this breaks as soon as you have outputs, scalars, or non-tensor params.
A safer pattern is: materialize `param_values` in PrimFunc-param order, allocate outputs into it, and resolve dynamic vars against `param_values[...]` (ensuring the referenced value is a tensor).

    @@
    -    def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
    +    def _wrap_forward_from_prebuild_lib(self, *ins: Any, stream: int | None = None):
    @@
    -        ins_idx = 0
    -        args = []
    +        ins_idx = 0
    +        # Map provided inputs into PrimFunc-param order (excluding allocated outputs).
    +        param_values: list[Any] = [None] * len(self.params)
    +        for i in range(len(self.params)):
    +            if i in self.result_idx:
    +                continue
    +            param_values[i] = ins[ins_idx]
    +            ins_idx += 1
    +
    +        first_tensor = next((v for v in param_values if isinstance(v, torch.Tensor)), None)
    +        if first_tensor is None:
    +            raise ValueError("Expected at least one torch.Tensor argument to infer device")
    +
    +        args: list[Any] = []
    @@
                 if i in self.result_idx:
                     dtype = self.param_dtypes[i]
                     shape = []
                     # Now working with native Python list, no FFI calls needed
                     for s in self.param_shapes[i]:
                         if isinstance(s, tir.Var):
    -                        ref_id, ref_tensor_idx, ref_dim_idx = self.dynamic_symbolic_map[s]
    +                        ref_id, ref_param_idx, ref_dim_idx = self.dynamic_symbolic_map[s]
    +                        ref_val = param_values[ref_param_idx]
    +                        if not isinstance(ref_val, torch.Tensor):
    +                            raise TypeError(
    +                                f"Dynamic var {s} refers to non-tensor param at index {ref_param_idx}"
    +                            )
                             if ref_id == 0:
    -                            shape.append(ins[ref_tensor_idx].shape[ref_dim_idx])
    +                            shape.append(ref_val.shape[ref_dim_idx])
                             elif ref_id == 1:
                                 # Stride vars are not expected in output shapes, but handle defensively.
    -                            shape.append(ins[ref_tensor_idx].stride()[ref_dim_idx])
    +                            shape.append(ref_val.stride()[ref_dim_idx])
                             else:
                                 raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")
                         else:
                             # Already converted to Python int during initialization
                             shape.append(s)
    -                device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
    -                tensor = torch.empty(*shape, dtype=dtype, device=device)
    +                tensor = torch.empty(*shape, dtype=dtype, device=first_tensor.device)
    +                param_values[i] = tensor
                 else:
    -                tensor = ins[ins_idx]
    -                ins_idx += 1
    +                tensor = param_values[i]
                 args.append(tensor)
    @@
             for sym in self.dynamic_symbolic_order:
    -            ref_id, buffer_idx, dim_idx = self.dynamic_symbolic_map[sym]
    +            ref_id, ref_param_idx, dim_idx = self.dynamic_symbolic_map[sym]
    +            ref_val = param_values[ref_param_idx]
    +            if not isinstance(ref_val, torch.Tensor):
    +                raise TypeError(f"Dynamic symbolic var {sym} refers to non-tensor param at index {ref_param_idx}")
                 if ref_id == 0:
    -                args.append(ins[buffer_idx].shape[dim_idx])
    +                args.append(ref_val.shape[dim_idx])
                 elif ref_id == 1:
    -                args.append(ins[buffer_idx].stride()[dim_idx])
    +                args.append(ref_val.stride()[dim_idx])
                 else:
                     raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")

src/target/codegen_cutedsl.cc (6)
24-31: Guard `ReplaceAll()` against empty `from` (infinite loop).

     void ReplaceAll(std::string &str, const std::string &from,
                     const std::string &to) {
    +  ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty";
       auto pos = str.find(from);
       while (pos != std::string::npos) {
         str.replace(pos, from.size(), to);
         pos = str.find(from, pos + to.size());
       }
     }
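To make the failure mode concrete, here is a Python port of the same find/replace loop. It is a sketch for illustration only; `replace_all` is a hypothetical name. An empty search string matches at every position, so without the guard the loop would never reach the "not found" case.

```python
def replace_all(text: str, old: str, new: str) -> str:
    """Replace every occurrence of `old`, scanning past each replacement."""
    if not old:
        # An empty pattern is found at every offset, so the loop below
        # would never terminate; fail fast instead.
        raise ValueError("replace_all(): `old` must be non-empty")
    pos = text.find(old)
    while pos != -1:
        text = text[:pos] + new + text[pos + len(old):]
        # Resume after the inserted text so replacements are not re-scanned.
        pos = text.find(old, pos + len(new))
    return text


scoped = replace_all("a.b.c", ".", "::")
doubled = replace_all("aaa", "a", "aa")
```

Resuming the search at `pos + len(new)` is what keeps a replacement that contains the pattern (such as `"a"` to `"aa"`) from expanding forever.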
193-200: Integer division emission uses `//` (Python floor), which can diverge from TIR semantics for negatives.
If negative operands are possible, either (a) emit a trunc-toward-zero form, or (b) enforce/ICHECK non-negativity for both operands before using `//`.
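The divergence is easy to demonstrate in plain Python. The sketch below assumes the operation being lowered has truncating (C-style) semantics; `trunc_div` is a hypothetical helper showing one possible trunc-toward-zero form, not what the codegen currently emits.

```python
def trunc_div(a: int, b: int) -> int:
    """Integer division that rounds toward zero, as C and C++ do."""
    q = abs(a) // abs(b)
    # The quotient is negative only when the operand signs differ.
    return q if (a < 0) == (b < 0) else -q


floor_result = -7 // 2          # Python floors toward negative infinity
trunc_result = trunc_div(-7, 2)  # C-style truncation toward zero
```

For non-negative operands the two forms agree, which is why option (b), checking non-negativity before emitting `//`, is also a valid fix.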
375-426: `eviction_policy_names_` indexing and `need_reduce` assume `IntImm` + in-range; add validation to avoid OOB/null deref crashes.
At minimum, validate `op->args.back()` / `op->args[...-2]` are `IntImmNode` and bounds-check the eviction policy index before subscripting.
607-667: Vector load/store: guard lane math + don’t assume `RampNode` in non-contiguous path.
Please add:
- `ICHECK_GE(value_lanes, element_dtype.lanes())`
- `ICHECK_EQ(value_lanes % element_dtype.lanes(), 0)`
- Only attempt `arith::ramp(..., value_lanes / element_dtype.lanes())` when the quotient is > 0
- In non-contiguous path, `ICHECK(index.as<RampNode>()) << "...";` once before the loop (not per-iteration)
Also applies to: 669-722
737-742: `shared.dyn` hardcodes `alignment`/`div_by=1024`; make dtype-aware to avoid incorrect alignment assumptions.
Consider using `op->dtype.bytes() * op->dtype.lanes()` (like the `shared` path) for both `alignment=` and `div_by=`.
1213-1222: `GetBufferRef_`: `buffer_size` silently truncates if bitwidths aren’t divisible; fail fast.

     } else {
    -  int buffer_size =
    -      t.bits() * t.lanes() /
    -      (buffer_element_dtype.bits() * buffer_element_dtype.lanes());
    +  const int num = t.bits() * t.lanes();
    +  const int den =
    +      buffer_element_dtype.bits() * buffer_element_dtype.lanes();
    +  ICHECK_GT(den, 0);
    +  ICHECK_EQ(num % den, 0)
    +      << "Cannot form view: requested dtype " << t
    +      << " is not divisible by buffer dtype " << buffer_element_dtype;
    +  const int buffer_size = num / den;
       std::ostringstream os;
       os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str
          << ", (" << buffer_size << ",), div_by=" << buffer_size << ")";
       return os.str();
     }
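The arithmetic behind the fail-fast check can be sketched in Python as below. The function name and parameters are hypothetical and simply mirror the bit/lane quantities named in the diff.

```python
def view_element_count(
    req_bits: int, req_lanes: int, elem_bits: int, elem_lanes: int
) -> int:
    """Number of buffer elements covered by a requested (bits x lanes) view."""
    num = req_bits * req_lanes
    den = elem_bits * elem_lanes
    if den <= 0:
        raise ValueError("buffer element width must be positive")
    if num % den != 0:
        # Integer division would silently drop the remainder here.
        raise ValueError(
            f"cannot form view: {num} bits is not a multiple of {den} bits"
        )
    return num // den


# A float32x4 view over a float16 buffer spans 8 elements.
count = view_element_count(32, 4, 16, 1)
```

Without the remainder check, a request such as 24 bits over 16-bit elements would yield a size of 1 and hide the mismatch.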
🧹 Nitpick comments (1)
tilelang/jit/adapter/cutedsl/adapter.py (1)
203-212: Consider keeping `get_kernel_source()` behavior consistent with `BaseKernelAdapter.get_kernel_source()`.
Base adapter supports `kernel_only=False`; here `kernel_only` is unused and you always return `device_kernel_source`. If intentional, consider dropping the param or mirroring the base behavior (e.g., returning host+device when available).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/target/codegen_cutedsl.cc (1 hunk)
- tilelang/jit/adapter/cutedsl/adapter.py (1 hunk)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/target/codegen_cutedsl.cc
🧬 Code graph analysis (2)
src/target/codegen_cutedsl.cc (2)
src/target/codegen_py.h (2)
- PrintExpr_ (119-121)
- PrintStmt_ (113-113)
src/target/codegen_utils.cc (2)
- RemoveOutermostParentheses (29-35)
- RemoveOutermostParentheses (29-29)
tilelang/jit/adapter/cutedsl/adapter.py (3)
tilelang/utils/target.py (1)
- determine_target (63-133)
tilelang/jit/adapter/base.py (1)
- _post_init (95-96)
tilelang/jit/adapter/nvrtc/adapter.py (4)
- _process_dynamic_symbolic (155-175)
- prim_func (265-267)
- get_kernel_source (177-188)
- _convert_torch_func (254-262)
🪛 Ruff (0.14.8)
tilelang/jit/adapter/cutedsl/adapter.py
48-48: Avoid specifying long messages outside the exception class
(TRY003)
112-112: Unused class method argument: pass_configs
(ARG003)
124-124: Avoid specifying long messages outside the exception class
(TRY003)
203-203: Unused method argument: kernel_only
(ARG002)
259-259: Do not catch blind exception: Exception
(BLE001)
279-281: Avoid specifying long messages outside the exception class
(TRY003)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
318-318: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
This comment was marked as outdated.
This comment was marked as outdated.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/target/codegen_cutedsl.cc (1)
671-734: Vector BufferStore path still lacks the lane‑count safety checks added for loads
The vector load path defends against invalid lane relationships:

    if (value_lanes == element_dtype.lanes()) {
      ...
    } else {
      ICHECK_GE(value_lanes, element_dtype.lanes())
          << "Unsupported load/store: value lanes < buffer element lanes";
      ...
      if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()).Match(index)) {
        ...
      } else {
        const RampNode* ramp = index.as<RampNode>();
        ICHECK(ramp) << "Expected Ramp index for vectorized non-contiguous access";
        ...
      }
    }

The store path, however, still does:

    int value_lanes = value_dtype.lanes();
    ...
    } else {
      bool is_contiguous = false;
      arith::PVar<PrimExpr> base;
      if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
              .Match(index_expr)) {
        is_contiguous = true;
      }
      ...
      for (int i = 0; i < value_lanes; ++i) {
        const RampNode *ramp = index_expr.as<RampNode>();
        ICHECK(ramp);
        ...
      }
    }

If malformed IR ever produces `value_lanes < element_dtype.lanes()`, `value_lanes / element_dtype.lanes()` becomes 0 and the pattern match degenerates; the non‑contiguous branch also assumes `index_expr` is a `RampNode` without checking, which can crash.
Mirroring the load‑side checks would make this more robust:

       } else {
    -    bool is_contiguous = false;
    -    arith::PVar<PrimExpr> base;
    -    if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
    -            .Match(index_expr)) {
    -      is_contiguous = true;
    -    }
    +    ICHECK_GE(value_lanes, element_dtype.lanes())
    +        << "Unsupported store: value lanes < buffer element lanes";
    +    bool is_contiguous = false;
    +    arith::PVar<PrimExpr> base;
    +    const int stride_count = value_lanes / element_dtype.lanes();
    +    if (stride_count > 0 &&
    +        arith::ramp(base, 1, stride_count).Match(index_expr)) {
    +      is_contiguous = true;
    +    }
    @@
    -    } else {
    +    } else {
           ICHECK(element_dtype.is_scalar())
               << "buffer element type for non-contiguous store must be scalar "
                  "currently";
    @@
    -      for (int i = 0; i < value_lanes; ++i) {
    -        const RampNode *ramp = index_expr.as<RampNode>();
    -        ICHECK(ramp);
    +      const RampNode *ramp = index_expr.as<RampNode>();
    +      ICHECK(ramp)
    +          << "Expected Ramp index for vectorized non-contiguous store";
    +      for (int i = 0; i < value_lanes; ++i) {
             auto idx_expr =
                 arith::Analyzer().Simplify(ramp->base + ramp->stride * i);

This keeps load/store behavior symmetric and avoids silent misbehavior on bad IR.
Also applies to: 736-789
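The non-contiguous store path discussed above boils down to expanding a ramp index into per-lane scalar indices after validating the inputs once. A Python sketch with stand-in types follows; `Ramp` here is a hypothetical dataclass, not `tir.Ramp`, and integer bases replace symbolic expressions.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Ramp:
    """Hypothetical stand-in for a ramp index: base + stride * i."""
    base: int
    stride: int
    lanes: int


def scalar_store_indices(
    index: object, value_lanes: int, element_lanes: int
) -> List[int]:
    """Expand a vector store into one scalar index per lane."""
    if value_lanes < element_lanes:
        raise ValueError("unsupported store: value lanes < buffer element lanes")
    # Validate the index kind once, before the loop, rather than per iteration.
    if not isinstance(index, Ramp):
        raise TypeError("expected a Ramp index for a non-contiguous store")
    return [index.base + index.stride * i for i in range(value_lanes)]


indices = scalar_store_indices(Ramp(base=4, stride=2, lanes=4), 4, 1)
```

Hoisting the type check makes the error message appear before any partial output is produced, which is the same ordering the suggested C++ diff aims for.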
🧹 Nitpick comments (4)
tilelang/contrib/cutedsl/__init__.py (4)
7-15: Star imports and `noqa: F401` pragmas: fine but noisy for tooling
The `from ... import *` plus `# noqa: F401` pattern works here as an aggregator, but Ruff is flagging unused `noqa` directives and can’t see that `bar_sync_ptx` comes from `reduce.py`. If you want cleaner static-analysis output, consider:
- Dropping the `# noqa: F401` comments now that F401 isn’t enabled, and/or
- Replacing `import *` with explicit re‑exports and an `__all__` in this package.
Given this is a convenience facade, it’s reasonable to leave as‑is if you’re okay with the Ruff noise.
Also applies to: 20-26, 61-63
45-49: Tighten `make_tensor_at_offset` around `div_by` and `as_numeric` usage
This helper is central to all `tl.make_tensor_at_offset(...)` codegen, so the `div_by` contract matters:
- You import `as_numeric` from `cutlass.base_dsl.typing` but call `cutlass.as_numeric(offset)` instead. If these are aliases, using the imported `as_numeric` makes the dependency clearer and avoids relying on a top‑level re‑export.
- It may be worth asserting `div_by >= 1` (and maybe documenting that it’s in “elements” units) to catch accidental zero/negative values before they hit `cute.assume(..., divby=div_by)`.
Both changes are small but make the semantics of the `div_by` path easier to reason about, especially since there are open questions around `div_by` usage elsewhere in this PR.
51-59: Consider guarding `shuffle_elect` against non‑warp‑multiple `thread_extent`
The implementation assumes `thread_extent` is 0 or a multiple of 32 (`thread_extent // 32` is used as the number of warps per group). If someone accidentally passes a non‑multiple (e.g. 48), this will compute `thread_extent // 32 == 1` but the semantics become unclear.
A light defensive check such as:

     def shuffle_elect(thread_extent):
    -    # thread_extent is the number of threads of a warpgroup
    +    # thread_extent is the number of threads of a warpgroup
    +    assert thread_extent == 0 or thread_extent % 32 == 0, \
    +        "shuffle_elect expects thread_extent to be 0 or a multiple of 32"

would make misuse much easier to diagnose, and it matches the intended warp‑group semantics from the C++ intrinsics.
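As a plain-Python illustration of why the check helps (this does not reproduce `shuffle_elect` itself; `warps_in_group` and `WARP_SIZE` are hypothetical names), floor division quietly hides a bad extent unless it is validated first:

```python
WARP_SIZE = 32  # assumed warp width, matching the `// 32` in the comment above


def warps_in_group(thread_extent: int) -> int:
    """Number of whole warps in a group, rejecting partial warps."""
    if thread_extent != 0 and thread_extent % WARP_SIZE != 0:
        raise ValueError(
            "thread_extent must be 0 or a multiple of 32, got "
            f"{thread_extent}"
        )
    return thread_extent // WARP_SIZE


# Without validation, 48 threads would silently be treated as one warp.
silent_result = 48 // WARP_SIZE
valid_result = warps_in_group(128)
```

A raised `ValueError` (or the `assert` in the suggested diff) surfaces the mistake at the call site instead of producing an ambiguous election result.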
66-100: `pack_half2`/`AtomicAdd` implementation looks correct; only minor polish possible
- `pack_half2` correctly bitcasts fp16 values to i16 and uses a single `mov.b32` to pack them. Since it’s a pure bit‑manipulation op, you could set `has_side_effects=False` on the inline asm to give NVVM more freedom to optimize, but functionally this is fine.
- `AtomicAdd` correctly routes to `nvvm.atomicrmw` for `Float32`/`Int32` with relaxed GPU scope and returns the updated value wrapped in the pointer’s dtype. The explicit `ValueError` for unsupported dtypes is good fail‑fast behavior; if you expect bf16 or other integer widths in the near future, this is the obvious extension point.
No blocking issues here; these helpers look ready for use as‑is.
Also applies to: 103-127
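For readers unfamiliar with the packing step, the bit-level effect can be emulated on the host with the standard `struct` module. This is a sketch under an assumption: it places the first value in the low 16 bits and the second in the high 16 bits, which is the usual convention for a packed half2 but is not confirmed by the PR text. The function names are hypothetical host-side helpers.

```python
import struct
from typing import Tuple


def pack_half2_host(x: float, y: float) -> int:
    """Pack two fp16 values into one 32-bit word (x low, y high; assumed order)."""
    lo = struct.unpack("<H", struct.pack("<e", x))[0]  # fp16 bits of x
    hi = struct.unpack("<H", struct.pack("<e", y))[0]  # fp16 bits of y
    return (hi << 16) | lo


def unpack_half2_host(word: int) -> Tuple[float, float]:
    """Inverse of pack_half2_host, for checking the round trip."""
    lo = struct.unpack("<e", struct.pack("<H", word & 0xFFFF))[0]
    hi = struct.unpack("<e", struct.pack("<H", (word >> 16) & 0xFFFF))[0]
    return lo, hi


packed = pack_half2_host(1.0, 2.0)
```

Since fp16 `1.0` is `0x3C00` and `2.0` is `0x4000`, the packed word under this ordering is `0x40003C00`, and unpacking returns the original pair.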
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/target/codegen_cutedsl.cc (1 hunk)
- src/target/codegen_cutedsl.h (1 hunk)
- tilelang/contrib/cutedsl/__init__.py (1 hunk)
🧰 Additional context used
🧠 Learnings (7)
📚 Learning: 2025-12-15T07:23:46.620Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:46.620Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.
Applied to files:
tilelang/contrib/cutedsl/__init__.py
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/contrib/cutedsl/__init__.py
src/target/codegen_cutedsl.cc
📚 Learning: 2025-12-15T07:48:36.835Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:36.835Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.
Applied to files:
tilelang/contrib/cutedsl/__init__.py
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.
Applied to files:
src/target/codegen_cutedsl.cc
📚 Learning: 2025-12-15T08:56:18.649Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/reduce.py:161-184
Timestamp: 2025-12-15T08:56:18.649Z
Learning: In Tilelang's CUDA backend and CuTeDSL backend, barrier IDs 1 and 2 are reserved for internal use (such as in AllReduce operations). User-defined barriers should use IDs starting from 3 to avoid synchronization conflicts.
Applied to files:
src/target/codegen_cutedsl.cc
🧬 Code graph analysis (2)
tilelang/contrib/cutedsl/__init__.py (5)
tilelang/language/builtin.py (1)
- sync_threads (702-709)
src/tl_templates/cuda/intrin.h (2)
- warpgroup_reg_alloc (124-126)
- warpgroup_reg_dealloc (128-130)
tilelang/language/proxy.py (1)
- make_tensor (278-279)
tilelang/tileop/gemm/gemm_base.py (1)
- mbar (127-128)
tilelang/contrib/cutedsl/reduce.py (1)
- bar_sync_ptx (97-108)
src/target/codegen_cutedsl.cc (3)
tilelang/contrib/cutedsl/cpasync.py (1)
- tma_load (60-107)
tilelang/language/math_intrinsics.py (1)
- ieee_add (149-172)
tilelang/language/reduce.py (3)
- warp_reduce_sum (332-345)
- warp_reduce_max (348-361)
- warp_reduce_min (364-377)
🪛 Ruff (0.14.8)
tilelang/contrib/cutedsl/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
9-9: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
20-20: from .mbar import * used; unable to detect undefined names
(F403)
21-21: from .cpasync import * used; unable to detect undefined names
(F403)
22-22: from .gemm_V1 import * used; unable to detect undefined names
(F403)
23-23: from .reduce import * used; unable to detect undefined names
(F403)
24-24: from .ldsm import * used; unable to detect undefined names
(F403)
25-25: from .math import * used; unable to detect undefined names
(F403)
26-26: from .threadblock_swizzle import * used; unable to detect undefined names
(F403)
62-62: bar_sync_ptx may be undefined, or defined from star imports
(F405)
127-127: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
🔇 Additional comments (7)
src/target/codegen_cutedsl.cc (6)
25-39: Loop unroll threshold and `cutlass.range_constexpr` usage look good
Defining `LOOP_UNROLL_THRESHOLD = 64` in the anonymous namespace and using it to decide between `cutlass.range_constexpr` vs `cutlass.range(..., unroll_full=True)` is a sensible trade-off: it avoids the CuTe DSL compile-time explosion for long static loops while still fully unrolling shorter ones.
The `ForNode` handling correctly:
- Simplifies `min`/`extent` with `arith::Analyzer`,
- Respects explicit `pragma_unroll_factor` when present,
- Falls back to the Python codegen for non-unrolled loops.
No issues here from a correctness or maintainability standpoint.
Also applies to: 894-931
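The threshold decision described in this comment can be sketched as a small string-emitting helper. This is a guess at the shape of the logic, not the actual C++ codegen: `unrolled_loop_iterable` is a hypothetical name, and treating an extent equal to the threshold as "short" is an assumption.

```python
LOOP_UNROLL_THRESHOLD = 64  # value quoted in the review comment


def unrolled_loop_iterable(extent: int) -> str:
    """Pick the CuTeDSL iterable text for a loop that should be fully unrolled.

    Short static loops are unrolled while the DSL traces the Python code
    (range_constexpr); longer ones defer unrolling to the compiler so the
    trace itself does not blow up.
    """
    if extent < 0:
        raise ValueError("extent must be non-negative")
    if extent <= LOOP_UNROLL_THRESHOLD:  # assumption: boundary is inclusive
        return f"cutlass.range_constexpr({extent})"
    return f"cutlass.range({extent}, unroll_full=True)"
```

A generated loop header would then look like `for i in cutlass.range_constexpr(8):` for an extent of 8.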
81-147: DType translation and float literal emission are robust
- `DTypeToString` covers the supported CuTeDSL scalar dtypes (floats, the various float8/float6/float4 encodings, ints/uints, bool) and fails fast with `LOG(FATAL)` for anything unsupported, which is appropriate at codegen time.
- `PrintType` enforces scalar-only printing for CuTeDSL types, catching accidental vector types early.
- `VisitExpr_(FloatImmNode*)` now emits infinities/NaNs via `float('inf')`/`float('nan')` and finite values via `float.fromhex(...)` with `std::hexfloat`, which gives full-precision round-trip behavior and matches the earlier review feedback.
This section looks correct and matches CuTeDSL expectations.
Also applies to: 150-156, 163-199
227-247: Integer div / fast-math mapping is acceptable under current assumptions
The combination of:
- `CanonicalizeFastmathFunctionName_` mapping common C math names to `tl.*` intrinsics, and
- `VisitExpr_(DivNode*)` using:
  - `//` for integer types, and
  - `tl.divf(..., fastmath=True)` vs `tl.divf(...)` depending on `enable_fastmath_`
is consistent with the rest of the backend and PassContext-controlled fast-math policy.
Given earlier clarification that integer division here is only used for non-negative indices, Python `//` semantics are fine; if that ever changes, we'd need a truncating-toward-zero path. For now, this looks good.
Also applies to: 52-67
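The caveat about `//` is easy to see with a negative operand: Python floors the quotient, while C and C++ truncate toward zero. The sketch below shows a truncating division of the kind the comment alludes to; `trunc_div` is a hypothetical helper, not something the PR adds.

```python
def trunc_div(a: int, b: int) -> int:
    """Integer division that rounds toward zero, as C/C++ `/` does."""
    if b == 0:
        raise ZeroDivisionError("integer division by zero")
    quotient = abs(a) // abs(b)
    # The result is negative only when exactly one operand is negative.
    return quotient if (a < 0) == (b < 0) else -quotient


# For non-negative operands the two agree, which is why `//` is safe today.
assert 7 // 2 == trunc_div(7, 2) == 3
# With a negative operand they differ: floor gives -4, truncation gives -3.
assert -7 // 2 == -4
assert trunc_div(-7, 2) == -3
```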
280-491: CallNode CuTeDSL lowering is comprehensive; `tl_shuffle_elect` wiring matches intent
The CallNode visitor:
- Cleanly handles the cp.async/mbarrier/TMA family with argument validation (e.g., eviction policy checks, `tma_store` `IntImm` guards) and clear `LOG(FATAL)` paths for unsupported variants.
- Routes `tl.tl_gemm` through `PrintCallExtern_` and adds the correct `tl.gemm_*` keyword arguments.
- Maps `tl.pack_b16` to `tl.pack_half2` and uses `tl.shuffle_elect` in conjunction with `IfThenElse`'s `with cute.arch.elect_one():` to mimic the C++ `tl_shuffle_elect` semantics (one elected lane per warp group).
The number of `LOG(FATAL)` branches is expected at this stage of CuTeDSL support and will make missing intrinsics obvious at codegen time. No functional issues stand out in this block.
Also applies to: 611-612, 933-947
791-839: Dynamic/shared allocation and buffer pointer/ref helpers look consistent
- The `AllocateNode` handler's `shared.dyn` branch emits a "fake" 1-element tensor via `tl.make_tensor(tl.get_dyn_smem(dtype, alignment=1024), (1,))`, which matches the intent documented in the comment ("no bound check; just get its iterator").
- `GetBufferPtr_` and `GetBufferRef_` consistently use `.iterator` and, when handle types don't match, wrap with `tl.recast_ptr(..., dtype=...)`, before building `tl.make_tensor_at_offset(...)` views. The special-case for scalar element types in `local`/`shared`/`shared.dyn`/`shared.barrier` scopes falling back to `vid[index]` matches CuTeDSL's layout expectations.
Given the earlier clarification that `tl.make_rmem_tensor` and related APIs accept both integer and tuple shapes, the generated `(buffer_size,)` and `(1,)` shapes are valid. No issues here.
Also applies to: 1209-1297
1314-1352: `PrintStorageSync_` improvements are good; only string-literal assumption remains
`PrintStorageSync_` now:
- Distinguishes `"warp"` vs `"shared"`/`"shared.dyn"` vs `"global"`,
- Validates `barrier_id`/`thread_count` as `IntImmNode` and surfaces clear error messages, and
- Fails fast on unknown scopes with `LOG(FATAL)`.
The remaining assumption is that `args[0]` is always a `StringImmNode`; if TIR ever relaxes that, you'd need a defensive check there as well, but as long as `tvm_storage_sync` is only emitted with literal scopes this is fine.
src/target/codegen_cutedsl.h (1)
21-97: Header/API surface for `CodeGenTileLangCuTeDSL` matches the implementation
The header cleanly exposes:
- The CuTeDSL-specific overrides for expr/stmt visitors and type printing,
- The helper hooks (`PrintVec*`, `PrintBinary*`, `PrintCallExtern_`, `GetBufferPtr_`/`GetBufferRef_`, `BindThreadIndex_`, `PrintStorageSync_`), and
- Internal state (mbarrier name, unroll factors, eviction policy table, fast-math flag).
Signatures and visibility match the `.cc` definitions, and marking the class `final` is appropriate given it's not meant for further extension. No changes needed here.
|
Ready for merge now. |
|
@codex review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/jit/adapter/cutedsl/adapter.py (2)
148-153: Consider assigning `pass_configs` in `from_database`.
The `from_database` classmethod receives `pass_configs` but doesn't assign it to the adapter instance (unlike the constructor at line 74). While the library generator receives it via `assign_pass_configs` (line 149), storing it on the adapter would maintain consistency with other adapters and support future use cases.
Consider adding: `adapter.pass_configs = pass_configs`
     adapter.lib_generator = CuTeDSLLibraryGenerator(adapter.target, adapter.verbose)
    +adapter.pass_configs = pass_configs
     adapter.lib_generator.assign_compile_flags(compile_flags)
206-214: Document why `kernel_only` parameter is unused.
The `kernel_only` parameter exists for interface consistency with other adapters but isn't used in the CuTeDSL implementation since there's only device source (no separate host source in the same sense as other backends).
Consider adding a brief comment or docstring note explaining this behavior:
    def get_kernel_source(self, kernel_only: bool = True) -> str | None:
        """Get the CUDA kernel source code.
        Note: kernel_only parameter is accepted for interface consistency but ignored since CuTeDSL only provides device source.
        """
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- .github/workflows/ci.yml (1 hunks)
- CMakeLists.txt (1 hunks)
- tilelang/cache/kernel_cache.py (10 hunks)
- tilelang/engine/lower.py (2 hunks)
- tilelang/jit/__init__.py (7 hunks)
- tilelang/jit/adapter/cutedsl/adapter.py (1 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (1 hunks)
- tilelang/jit/execution_backend.py (4 hunks)
- tilelang/jit/kernel.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tilelang/engine/lower.py
- CMakeLists.txt
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
.github/workflows/ci.yml
🧬 Code graph analysis (3)
tilelang/jit/kernel.py (2)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target (114-115)
- tilelang/jit/adapter/cutedsl/adapter.py (2): from_database (102-156), get_kernel_source (206-214)
tilelang/jit/adapter/nvrtc/adapter.py (1)
- tilelang/jit/adapter/nvrtc/wrapper.py (2): host_func (260-264), host_func (267-269)
tilelang/jit/execution_backend.py (3)
- tilelang/jit/adapter/utils.py (1): is_cutedsl_target (114-115)
- tilelang/env.py (1): use_gemm_v1 (284-290)
- tilelang/jit/adapter/cutedsl/checks.py (1): check_cutedsl_available (33-79)
🪛 Ruff (0.14.8)
tilelang/jit/execution_backend.py
81-83: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/adapter/cutedsl/adapter.py
48-48: Avoid specifying long messages outside the exception class
(TRY003)
112-112: Unused class method argument: pass_configs
(ARG003)
124-124: Avoid specifying long messages outside the exception class
(TRY003)
206-206: Unused method argument: kernel_only
(ARG002)
262-262: Do not catch blind exception: Exception
(BLE001)
282-284: Avoid specifying long messages outside the exception class
(TRY003)
297-297: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Avoid specifying long messages outside the exception class
(TRY003)
319-319: Avoid specifying long messages outside the exception class
(TRY003)
333-333: Avoid specifying long messages outside the exception class
(TRY003)
339-339: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.8
🔇 Additional comments (17)
.github/workflows/ci.yml (2)
373-373: LGTM! Proper test isolation.
Excluding the CuTeDSL test file from the standard CUDA test suite is the right approach to prevent the GEMM v1 requirement from affecting other tests.
376-392: Excellent test isolation strategy!
The dedicated CuTeDSL test step with GEMM v1 environment variable and single-process execution is well-designed:
- Setting `TILELANG_USE_GEMM_V1=1` at the step level ensures proper isolation from GEMM v2 tests
- Using `--numprocesses=1` appropriately avoids GPU contention for these tests
- Clear comments explain the rationale for both decisions
This approach cleanly separates CuTeDSL tests while maintaining consistency with the existing CI patterns.
tilelang/cache/kernel_cache.py (3)
298-344: CuTeDSL-specific save logic looks correct.
The branching for CuTeDSL launcher artifacts (lines 327-343) correctly handles the optional C++ launcher library and source. The conditional logic appropriately separates CuTeDSL from other backends.
As noted in the existing comments, consolidating backend-specific cache logic into the adapters themselves would improve maintainability in future iterations.
421-461: CuTeDSL load logic handles empty sources correctly.
The code correctly skips loading device/host kernel sources for CuTeDSL (lines 443-461) and initializes them as empty strings instead. This aligns with CuTeDSL's approach of reading kernel content directly from the Python module path.
The validation at line 472 properly accounts for CuTeDSL's alternative contract (empty sources are acceptable).
225-230: Cache path propagation enables post-compilation artifact saving.
Setting `_cache_path` on the CuTeDSL adapter (lines 228-229) allows the adapter to save the cubin after first execution, a necessary deferred step since CuTeDSL generates the cubin during the first run rather than at compile time.
This is a clean integration point between the cache and adapter layers.
tilelang/jit/adapter/nvrtc/adapter.py (1)
79-81: Wrapper API updated to dictionary-based return.
The change from direct tuple unpacking to dictionary access aligns with the expanded wrapper contract that now returns additional keys (e.g., `tma_cpp_init_code`, `launcher_cpp_code`) for CuTeDSL support. The NVRTC adapter correctly extracts only the keys it needs while ignoring the extras.
This is a clean, backward-compatible refactoring.
tilelang/jit/execution_backend.py (3)
35-36: CuTeDSL backend properly restricted to CuTeDSL targets.
The short-circuit at lines 35-36 ensures that `"cutedsl"` is only available when the target explicitly includes the `cutedsl` key. Plain CUDA targets (without the key) will fall through to line 37 and receive the standard backend list without `"cutedsl"`.
This addresses the past review concern about preventing users from selecting `"cutedsl"` on non-CuTeDSL targets: they'll now get a clear validation error at line 112 instead of an assertion failure later.
79-93: GEMM v1 validation provides actionable errors.
The `_require_gemm_v1_for_cutedsl()` helper cleanly validates prerequisites:
- Checks `TILELANG_USE_GEMM_V1` environment variable (lines 80-83)
- Lazily imports and runs `check_cutedsl_available()` (lines 85-92)
- Converts `ImportError` to `ValueError` while preserving the actionable message
This gives users clear feedback about missing dependencies or misconfiguration before compilation begins.
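The three validation steps listed above can be sketched roughly as follows. This is an approximation rather than the code under review: `require_gemm_v1` is a hypothetical stand-in, the availability check is passed in as a callable instead of being lazily imported, and the set of accepted truthy values for the environment variable is an assumption.

```python
import os
from typing import Callable, Mapping, Optional

_TRUTHY = ("1", "true", "yes", "on")  # assumption: accepted spellings


def require_gemm_v1(
    check_available: Callable[[], None],
    env: Optional[Mapping[str, str]] = None,
) -> None:
    """Raise ValueError unless the CuTeDSL prerequisites are satisfied."""
    env = os.environ if env is None else env
    if env.get("TILELANG_USE_GEMM_V1", "0").strip().lower() not in _TRUTHY:
        raise ValueError(
            "The CuTeDSL backend requires GEMM v1; "
            "set TILELANG_USE_GEMM_V1=1 before importing tilelang."
        )
    try:
        check_available()  # expected to raise ImportError with install hints
    except ImportError as exc:
        # Re-raise as ValueError but keep the original actionable message.
        raise ValueError(str(exc)) from exc
```

Passing the environment and the checker as parameters is a choice made here for testability; it lets both failure paths be exercised without touching the real process environment.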
96-98: Auto and explicit backend resolution both enforce CuTeDSL prerequisites.
Both the auto-resolution path (lines 96-98) and the explicit `"cutedsl"` request path (lines 126-127) call `_require_gemm_v1_for_cutedsl()` to validate prerequisites before proceeding. This ensures consistent enforcement regardless of how the user selects the backend.
Also applies to: 126-127
tilelang/jit/kernel.py (3)
327-340: CuTeDSL adapter creation follows established pattern.
The new branch for `execution_backend == "cutedsl"` (lines 327-340) correctly:
- Validates the target has the `cutedsl` key (line 328)
- Passes all required parameters matching the adapter signature
- Follows the same structure as other backends (tvm_ffi, ctypes, nvrtc, etc.)
This maintains consistency across the codebase.
412-423: Database restoration path includes CuTeDSL.
The `from_database` path (lines 412-423) mirrors the construction path, ensuring cached CuTeDSL kernels can be properly restored. This is essential for the kernel cache integration demonstrated in `kernel_cache.py`.
474-474: Source access methods recognize CuTeDSL adapter.
Including `"cutedsl"` in the backend checks at lines 474 and 482 ensures `get_kernel_source()` and `get_host_source()` will delegate to the CuTeDSL adapter's methods rather than falling back to artifact sources. This is the correct behavior for adapters that manage their own source representations.
Also applies to: 482-482
tilelang/jit/__init__.py (1)
52-52: Execution backend type annotations consistently updated.
All public API entry points (`compile`, `par_compile`, `JITImpl`, `jit`, `lazy_jit`) and the `ExecutionBackend` type alias now include `"cutedsl"` in their type signatures. This provides proper type checking and IDE support for the new backend option.
The updates are consistent and complete across the module.
Also applies to: 67-67, 121-121, 138-138, 259-259, 427-427, 476-476
tilelang/jit/adapter/cutedsl/adapter.py (4)
85-96: Library generation and source reading correct for CuTeDSL.
The sequence at lines 85-96 properly:
- Generates the CuTeDSL Python module via `lib_generator.compile_lib()` and `load_lib()`
- Reads the generated Python module content into `device_kernel_source` (line 96)
For CuTeDSL, the Python module IS the kernel representation, so storing it in `device_kernel_source` maintains consistency with the adapter interface while reflecting CuTeDSL's Python-based approach.
158-204: Dynamic symbol processing correctly handles shapes and strides.
The `_process_dynamic_symbolic` method (lines 158-204) properly:
- Follows CuTeDSL's ordering semantics (shapes first, then strides)
- Uses tuple encoding `(id, param_index, dim_index)` where id=0 for shapes, id=1 for strides
- Handles None strides safely (lines 198-199)
- Maintains insertion order via `unique_push_back`
This implementation aligns with the wrapper's `get_dynamic_symbolic_set()` and ensures correct argument passing during execution.
265-353: Input marshalling correctly handles scalars and dynamic shapes.
The `_wrap_forward_from_prebuild_lib` method properly:
- Materializes all parameters in PrimFunc order (lines 286-293)
- Allocates output tensors using the first tensor's device (lines 295-297, 322)
- Resolves dynamic shapes by indexing `param_values` (not `ins`) to avoid mismatched indices (lines 309-319)
- Appends dynamic symbolic values in the correct order (lines 329-339)
This fixes the critical issue from past reviews where indexing into `ins` would fail with scalars or outputs present.
225-263: Post-execution cubin caching is a clever deferred-save pattern.
The `_save_cubin_to_cache_if_needed` method handles CuTeDSL's unique requirement: the cubin is generated during first execution rather than at compile time. Using the `_cache_path` set by `kernel_cache.py` (line 236), it copies the cubin to the cache directory after the first run.
The one-time flag (line 231) and existence checks (lines 249, 255) prevent redundant copies. The `exc_info=True` logging (line 263) aids debugging if the copy fails.
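The deferred-save pattern described in this comment (one-time flag, existence checks, logged failure) can be sketched in isolation. The class below is hypothetical and much simpler than the adapter: the name `DeferredCubinSaver`, its attributes, and the choice to set the flag before attempting the copy are all assumptions, and it catches `OSError` where the reviewed code reportedly catches a broad `Exception`.

```python
import logging
import os
import shutil
from typing import Optional

logger = logging.getLogger(__name__)


class DeferredCubinSaver:
    """Copy a cubin into a cache directory once, after it first exists."""

    def __init__(self, cubin_path: str, cache_path: Optional[str] = None):
        self.cubin_path = cubin_path
        self.cache_path = cache_path  # set later by the cache layer, if any
        self._attempted = False       # one-time flag

    def save_if_needed(self) -> bool:
        """Return True only when a copy was actually performed."""
        if self._attempted or self.cache_path is None:
            return False
        self._attempted = True  # assumption: never retry, even after a failure
        try:
            if not os.path.exists(self.cubin_path):
                return False  # nothing generated yet
            dest = os.path.join(self.cache_path, os.path.basename(self.cubin_path))
            if os.path.exists(dest):
                return False  # already cached
            os.makedirs(self.cache_path, exist_ok=True)
            shutil.copy2(self.cubin_path, dest)
            return True
        except OSError:
            # Caching is best-effort; log with traceback and keep running.
            logger.warning("Failed to save cubin to cache", exc_info=True)
            return False
```

Calling `save_if_needed()` after each kernel launch is cheap once the flag is set, so the check can sit on the hot path.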
|
@codex review |
💡 Codex Review
Here are some automated review suggestions for this pull request.
| kernel_lib_path = os.path.join(cache_path, kernel_lib_path) | ||
|
|
||
| # Save an extra Python file for NVRTC | ||
| if self.execution_backend == "nvrtc": |
This feels a bit hacky now that we have so many backends. Could we refactor this part to make it more maintainable? If you don’t have time right now, we can spin it off into a separate thread/issue—otherwise, LGTM.
PR Description
Summary
This PR introduces CuTeDSL as a new TileLang target, providing an end-to-end path to lower → generate CuTeDSL Python → build/load runtime module → run via JIT, plus a dedicated test suite and CI wiring.
User-facing changes
- `target="cutedsl"` (implemented as a CUDA target with an extra `cutedsl` key).
- `execution_backend="auto"` now resolves to `cutedsl` when the target is CuTeDSL; existing CUDA targets keep their previous defaults.
What’s included
- `tilelang.utils.target.determine_target()` recognizes `cutedsl*` targets and injects the `cutedsl` key.
- `tilelang.engine.lower.device_codegen*()` dispatches to `target.build.tilelang_cutedsl*` when the target has the `cutedsl` key.
- `src/target/codegen_cutedsl.*`) built on top of the Python codegen infrastructure (`src/target/codegen_py.*`).
- `src/target/rt_mod_cutedsl.cc`).
- `tilelang/jit/adapter/cutedsl/` and integration into `tilelang.jit`.
- `testing/python/jit/test_tilelang_jit_cutedsl.py`.
- `TILELANG_USE_GEMM_V1=1` (kept isolated to avoid changing default GEMM selection for other CUDA tests).
Requirements / constraints
- `TILELANG_USE_GEMM_V1=1` before importing tilelang.
Backward compatibility
- `target="cutedsl"` (or request `execution_backend="cutedsl"`).
Testing
- `testing/python/jit/test_tilelang_jit_cutedsl.py` (correctness vs PyTorch matmul, profiler do_bench, multi-stream execution, dynamic shapes)
- `TILELANG_USE_GEMM_V1=1` and `-n 1` to avoid xdist contention on a single GPU.
Notes for reviewers
- `tilelang/utils/target.py` (target parsing + key injection)
- `tilelang/engine/lower.py` (codegen dispatch)
- `tilelang/jit/execution_backend.py` (auto backend selection + GEMM v1 requirement)
- `tilelang/cache/kernel_cache.py` (CuTeDSL artifact persistence)
Follow-ups (not in this PR)
Summary by CodeRabbit
New Features
Kernel / Packaging
JIT / Adapters
Codegen
Tests
Chores