fix: gemm operator's fake shall follow actual function flow #1757
base: main
Conversation
Pull request overview
This PR updates the gemm_afp4wfp4 operator's fake tensor function to follow the same code flow as the actual implementation, ensuring consistency in how configurations are handled and tensors are allocated.
Key Changes
- Refactored `gemm_afp4wfp4_fake_tensor` to mirror the actual function's logic for config handling, split-k computation, and tensor allocation
- Added a `skip_reduce` parameter to the `gemm_afp4wfp4` wrapper function to support split-k reduction control
- Updated config handling in the wrapper to compute the default config when not provided (a consolidated sketch of this flow follows below)
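For orientation, here is a rough, consolidated sketch of what "following the actual flow" could look like in the fake-tensor path, pieced together from the review comments below. The signature, the initial `NUM_KSPLIT` lookup, and the use of the module helpers (`_get_config`, `deserialize_str`, `get_splitk`, `_USE_GEMM_SPLITK_BF16`) are assumptions for illustration, not the merged implementation.

```python
import torch
import triton

# Sketch only: assumes the module's existing helpers (_get_config, get_splitk,
# deserialize_str) and the _USE_GEMM_SPLITK_BF16 flag are in scope; the exact
# signature and the NUM_KSPLIT seed value are guesses, not the merged code.
def gemm_afp4wfp4_fake_tensor_sketch(x, w, x_scales, w_scales, dtype, y=None, config=None):
    M, K = x.shape
    N, _ = w.shape

    # Fall back to the default config when none is given, as the actual
    # gemm_afp4wfp4_ reportedly does.
    cfg = _get_config(M, N, K) if config is None else deserialize_str(config)

    # Recompute the split-k parameters and write them back into the config so
    # the fake path sees the same values as the real kernel launch.
    splitk_block_size, block_size_k, num_ksplit = get_splitk(
        K, cfg["BLOCK_SIZE_K"], cfg["NUM_KSPLIT"]
    )
    cfg["SPLITK_BLOCK_SIZE"] = splitk_block_size
    cfg["BLOCK_SIZE_K"] = block_size_k
    cfg["NUM_KSPLIT"] = num_ksplit

    if block_size_k >= 2 * K:
        block_size_k = max(triton.next_power_of_2(2 * K), 128)
        num_ksplit = 1
        cfg["BLOCK_SIZE_K"] = block_size_k
        cfg["SPLITK_BLOCK_SIZE"] = block_size_k
        cfg["NUM_KSPLIT"] = num_ksplit

    if num_ksplit == 1:
        cfg["SPLITK_BLOCK_SIZE"] = 2 * K

    # Allocate outputs with the same shapes and dtypes as the real wrapper.
    if y is None:
        y = torch.empty((M, N), dtype=dtype, device=x.device)
    if num_ksplit > 1:
        # y_pp mirrors the real scratch allocation; the fake only returns y.
        pp_dtype = dtype if _USE_GEMM_SPLITK_BF16 else torch.float32
        y_pp = torch.empty((num_ksplit, M, N), dtype=pp_dtype, device=x.device)
    return y
```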
aiter/ops/triton/gemm_afp4wfp4.py (outdated)

    N, _ = w.shape
    config = _get_config(M, N, K)
    config_hashable = serialize_dict(config)
    return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable)
Copilot AI (Dec 30, 2025)
The skip_reduce parameter is added to the function signature but is not passed to the inner gemm_afp4wfp4_ function. This means the skip_reduce functionality will not work as intended when calling this wrapper function. The call should include skip_reduce as an argument.
Suggested change:

    - return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable)
    + return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable, skip_reduce=skip_reduce)
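For completeness, the wrapper would then just thread the flag through to the inner call; a minimal sketch, assuming the signature and default shown here (they are not taken from the diff):

```python
# Hypothetical wrapper shape; only the skip_reduce pass-through is the point.
def gemm_afp4wfp4(x, w, x_scales, w_scales, dtype, y=None, config=None, skip_reduce=False):
    M, K = x.shape
    N, _ = w.shape
    if config is None:
        config = _get_config(M, N, K)
    config_hashable = serialize_dict(config)
    return gemm_afp4wfp4_(
        x, w, x_scales, w_scales, dtype, y, config_hashable, skip_reduce=skip_reduce
    )
```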

    _, block_size_k, num_ksplit = get_splitk(
        K, config["BLOCK_SIZE_K"], num_ksplit
    )

    if block_size_k >= 2 * K:
        num_ksplit= 1
Copilot AI (Dec 30, 2025)
The fake function uses block_size_k (lowercase variable) but does not update the config dictionary after the check. In the actual function (lines 169-173), when BLOCK_SIZE_K >= 2 * K, it updates config["BLOCK_SIZE_K"], config["SPLITK_BLOCK_SIZE"], and config["NUM_KSPLIT"], and also applies max(config["BLOCK_SIZE_K"], 128). The fake function only sets the local num_ksplit variable but doesn't update block_size_k with triton.next_power_of_2(2 * K) or apply the 128 minimum. This inconsistency could lead to different behavior between fake and actual functions.
Suggested change:

    - _, block_size_k, num_ksplit = get_splitk(
    -     K, config["BLOCK_SIZE_K"], num_ksplit
    - )
    - if block_size_k >= 2 * K:
    -     num_ksplit= 1
    + splitk_block_size, block_size_k, num_ksplit = get_splitk(
    +     K, config["BLOCK_SIZE_K"], num_ksplit
    + )
    + config["SPLITK_BLOCK_SIZE"] = splitk_block_size
    + config["BLOCK_SIZE_K"] = block_size_k
    + config["NUM_KSPLIT"] = num_ksplit
    + if block_size_k >= 2 * K:
    +     block_size_k = max(triton.next_power_of_2(2 * K), 128)
    +     num_ksplit = 1
    +     config["BLOCK_SIZE_K"] = block_size_k
    +     config["SPLITK_BLOCK_SIZE"] = block_size_k
    +     config["NUM_KSPLIT"] = num_ksplit

    else:
        y_pp = None
Copilot AI (Dec 30, 2025)
The fake function does not set SPLITK_BLOCK_SIZE when num_ksplit == 1. In the actual function (line 185), when NUM_KSPLIT == 1, it sets config["SPLITK_BLOCK_SIZE"] = 2 * K. While this might not affect the fake function's tensor allocation behavior, it represents an inconsistency in following the actual function flow, which is the stated goal of this PR.
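A minimal sketch of mirroring that branch in the fake path, assuming the same local names used in the snippets above:

```python
# Mirror the actual function's NUM_KSPLIT == 1 branch (its line 185, per the
# comment above) so the fake config stays consistent with the real one.
if num_ksplit == 1:
    config["SPLITK_BLOCK_SIZE"] = 2 * K
```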

    M, K = x.shape
    N, _ = w.shape
Copilot AI (Dec 30, 2025)
The config parameter can be None according to the function signature, but deserialize_str is called on it without checking. This will cause an AttributeError or similar exception when config is None. The actual function gemm_afp4wfp4_ handles this case by checking if config is None and calling _get_config if needed (lines 155-158).
Suggested change:

    + if config is None:
    +     raise ValueError("gemm_afp4wfp4_fake_tensor requires a non-None serialized config.")
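Alternatively, if the goal is to follow the actual flow rather than fail fast, the fake could fall back to the default config the same way the comment says `gemm_afp4wfp4_` does; a sketch, assuming `_get_config` and `deserialize_str` are in scope:

```python
# Follow the actual function's None handling instead of raising.
if config is None:
    config = _get_config(M, N, K)
else:
    config = deserialize_str(config)
```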

    if num_ksplit > 1:
        if _USE_GEMM_SPLITK_BF16:
            y_pp = torch.empty(
                (num_ksplit, M, N), dtype=y.dtype, device=x.device
Copilot AI (Dec 30, 2025)
When _USE_GEMM_SPLITK_BF16 is True, this accesses y.dtype but y can be None at this point (line 104 shows y may be None). This will cause an AttributeError. In the actual function (lines 177-179), y is guaranteed to be non-None at this point because it's either passed in or allocated earlier. Consider using dtype parameter instead of y.dtype to match the actual function behavior.
Suggested change:

    - (num_ksplit, M, N), dtype=y.dtype, device=x.device
    + (num_ksplit, M, N), dtype=dtype, device=x.device

    config_hashable = None
    M, K = x.shape
    N, _ = w.shape
    config = _get_config(M, N, K)
    config_hashable = serialize_dict(config)
Copilot AI (Dec 30, 2025)
The config_hashable variable is set to None when config is None, but then it's unconditionally assigned on line 597, losing the None value. This means config_hashable will never be None when passed to gemm_afp4wfp4_. The logic should either check if config_hashable is None before line 597, or only set config_hashable once after potentially calling _get_config.
Suggested change:

    - config_hashable = None
    - M, K = x.shape
    - N, _ = w.shape
    - config = _get_config(M, N, K)
    - config_hashable = serialize_dict(config)
    + M, K = x.shape
    + N, _ = w.shape
    + config = _get_config(M, N, K)
    + config_hashable = serialize_dict(config) if config is not None else None

            (num_ksplit, M, N), dtype=torch.float32, device=x.device
        )
    else:
        y_pp = None
Copilot AI (Dec 30, 2025)
Variable y_pp is not used.
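If the allocation is kept only to mirror the real function's flow, one option (an assumption, not part of the diff) is to make the unused binding explicit so linters stay quiet:

```python
# Allocation kept solely to mirror the real kernel's scratch buffer; the fake
# path never reads it, so bind it to an underscore name.
_y_pp = torch.empty((num_ksplit, M, N), dtype=torch.float32, device=x.device)
```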
Motivation
The GEMM fake should follow the same code flow as the actual function.
Technical Details
Test Plan
Test Result
Submission Checklist