[TRITON]: Standardize GEMM weight shape to (N, K) and TN memory layout (by default) #597
Conversation
There's a gnarly issue where Triton will implicitly promote small integers that it believes to be compile-time constants to `tl.constexpr`. Calling `.to()` on such a value then fails:

```
batch_id = batch_id.to(tl.int64)
stride_ab = stride_ab.to(tl.int64)
            ^
AttributeError("'constexpr' object has no attribute 'to'")
```

tl;dr: If you're attempting to cast strides to int64 to prevent integer overflow, use `tl.cast` instead of `.to()`, since `tl.cast` handles `constexpr` values as well as runtime values.
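To illustrate the fix, here is a minimal sketch of a Triton kernel that casts a stride argument with `tl.cast`. The kernel name, arguments, and shapes are hypothetical (not from this PR); the point is only that `tl.cast(stride, tl.int64)` works even when Triton has promoted the argument to `tl.constexpr`, where `stride.to(tl.int64)` would raise the AttributeError above. This is a sketch and has not been run (Triton kernels require a GPU).

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, out_ptr, stride_x,  # stride_x may be promoted to tl.constexpr
                BLOCK: tl.constexpr):
    # tl.cast works on both runtime scalars and constexpr values,
    # whereas stride_x.to(tl.int64) fails if stride_x is a constexpr.
    stride_x = tl.cast(stride_x, tl.int64)
    offs = tl.arange(0, BLOCK)
    # Use the widened stride so offset arithmetic happens in int64,
    # avoiding int32 overflow for large tensors.
    x = tl.load(x_ptr + offs * stride_x)
    tl.store(out_ptr + offs, x)
```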
…t (by default) (#597)
* Fix .assume() bug that causes M=1 cases to fail in some kernels
* Add minimal test cases (M, N, K) = (1, 1, 1)
* Fix bug where stride_.. becomes a tl.constant and raises a cast error
* Add weight shape changes for a8w8, a8w8 blockscale, a16w16
* Add weight shape changes for the rest of the GEMMs. FP4 kernels not yet validated
* Fix tensor ops for the afp4wfp4 GEMM
* Fix tensor ops for the afp4wfp4 pre-quant GEMM
* Add weight shape changes for gemm_afp4wfp4 kernels (atomic & standard)
* Formatting changes
* Fix bug where stride_.. becomes a tl.constant and raises a cast error
* Add stride int64 casts back
* Fix bug where stride_.. becomes implicitly promoted to a tl.constant and raises a cast error. Solution involves using tl.cast
* Add cast debug comment
* Temp change to use int64 strides
Changes: