
Conversation

@willzhou-amd
Contributor

@willzhou-amd willzhou-amd commented Jul 1, 2025

Changes:

  • Fix a .assume() bug that causes M=1 cases to fail in some kernels
  • Add M = 1 test cases to all tests
  • Standardize weight shapes to (N, K) (see the sketch after this list)
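
For reference, a minimal PyTorch sketch (names hypothetical) of the (N, K) weight convention: the weight keeps K as its contiguous axis, matching the row-major activation (the TN layout named in the PR title), and the GEMM contracts over K through the weight's transpose.

    import torch

    M, N, K = 4, 8, 16
    x = torch.randn(M, K)   # activations: (M, K), K-contiguous
    w = torch.randn(N, K)   # weights stored as (N, K), also K-contiguous
    out = x @ w.T           # contract over K: (M, K) @ (K, N) -> (M, N)
    assert out.shape == (M, N)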

@willzhou-amd willzhou-amd self-assigned this Jul 1, 2025
@willzhou-amd willzhou-amd requested a review from scxiao July 2, 2025 16:49
@rahulbatra85 rahulbatra85 changed the title Fix .assume() bug that causes M=1 cases to fail in some kernels [TRITON]: Fix .assume() bug that causes M=1 cases to fail in some kernels Jul 2, 2025
@willzhou-amd willzhou-amd changed the title [TRITON]: Fix .assume() bug that causes M=1 cases to fail in some kernels [TRITON]: Standardize weight shapes to (N, K) and TN memory layout (by default) Jul 2, 2025
@willzhou-amd willzhou-amd changed the title [TRITON]: Standardize weight shapes to (N, K) and TN memory layout (by default) [TRITON]: Standardize GEMM weight shape to (N, K) and TN memory layout (by default) Jul 2, 2025
@willzhou-amd
Contributor Author

There's a gnarly issue where Triton implicitly promotes small integer arguments that it believes to be compile-time constants to tl.constexpr. When the dimensions are small (e.g., a stride of 1), typecasting such an argument inside the kernel raises a cast error:

    batch_id = batch_id.to(tl.int64)
    stride_ab = stride_ab.to(tl.int64)
                ^
AttributeError("'constexpr' object has no attribute 'to'")

Replacing .to() with tl.cast() sidesteps the problem: the argument is still promoted, but tl.cast accepts constexpr inputs, whereas a constexpr has no .to() method. I'm not sure whether this happens on other versions of Triton, but it is reliably reproducible on Triton 3.3.1.

tl;dr: if you're casting strides to int64 to prevent integer overflow, use tl.cast rather than .to().
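
A minimal sketch of the failure mode and the fix (hypothetical kernel, not from this PR; assumes Triton's JIT specializes integer arguments equal to 1 into tl.constexpr):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, stride, n, BLOCK: tl.constexpr):
        # Triton specializes integer arguments equal to 1 into tl.constexpr,
        # and a constexpr object has no .to() method:
        #   stride = stride.to(tl.int64)   # AttributeError when stride == 1
        stride = tl.cast(stride, tl.int64)  # works whether or not promoted
        offs = tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(src_ptr + offs * stride, mask=mask)
        tl.store(dst_ptr + offs * stride, x, mask=mask)

    src = torch.arange(8, device="cuda", dtype=torch.float32)
    dst = torch.empty_like(src)
    # Passing stride=1 (as happens when M == 1) triggers the promotion.
    copy_kernel[(1,)](src, dst, 1, src.numel(), BLOCK=8)

With the commented-out .to() line instead, compilation should fail with the same AttributeError as in the traceback above.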

@rahulbatra85 rahulbatra85 self-requested a review July 8, 2025 18:52
rahulbatra85 previously approved these changes Jul 8, 2025
@rahulbatra85 rahulbatra85 merged commit e7570ed into main Jul 10, 2025
13 checks passed
@rahulbatra85 rahulbatra85 deleted the willz/weight-shape-debug branch July 10, 2025 01:35
fsx950223 pushed a commit that referenced this pull request Jul 11, 2025
[TRITON]: Standardize GEMM weight shape to (N, K) and TN memory layout (by default) (#597)

* Fix .assume() bug that causes M=1 cases to fail in some kernels

* Add minimal test cases (M, N, K) = (1, 1, 1)

* Fix bug where stride_.. becomes a tl.constant and raises a cast error

* Add weight shape changes for a8w8, a8w8 blockscale, a16w16

* Add weight shape changes for the rest of the GEMMs. FP4 kernels not yet validated

* Fix tensor ops for the afp4wfp4 GEMM

* Fix tensor ops for the afp4wfp4 pre-quant GEMM

* Add weight shape changes for gemm_afp4wfp4 kernels (atomic & standard)

* Formatting changes

* Fix bug where stride_.. becomes a tl.constant and raises a cast error

* Add stride int64 casts back

* Fix bug where stride_.. becomes implicitly promoted to a tl.constant and raises a cast error. Solution involves using tl.cast

* Add cast debug comment

* Temp change to use int64 strides
cagrikymk pushed a commit that referenced this pull request Jul 30, 2025
[TRITON]: Standardize GEMM weight shape to (N, K) and TN memory layout (by default) (#597)
