fix: gemm operator's fake shall follow actual function flow #1757
base: main
@@ -73,10 +73,41 @@ def gemm_afp4wfp4_fake_tensor(
    config: Optional[str] = None,
    skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
    if y is None:
        M, _ = x.shape
        N, _ = w.shape
        return torch.empty((M, N), dtype=dtype, device=x.device)

    M, K = x.shape
    N, _ = w.shape

    config = deserialize_str(config)
    num_ksplit = config["NUM_KSPLIT"]
    block_size_k = config["BLOCK_SIZE_K"]

    if num_ksplit > 1:
        _, block_size_k, num_ksplit = get_splitk(
            K, config["BLOCK_SIZE_K"], num_ksplit
        )

        if block_size_k >= 2 * K:
            num_ksplit= 1
Comment on lines +85 to +90
-        _, block_size_k, num_ksplit = get_splitk(
-            K, config["BLOCK_SIZE_K"], num_ksplit
-        )
-        if block_size_k >= 2 * K:
-            num_ksplit= 1
+        splitk_block_size, block_size_k, num_ksplit = get_splitk(
+            K, config["BLOCK_SIZE_K"], num_ksplit
+        )
+        config["SPLITK_BLOCK_SIZE"] = splitk_block_size
+        config["BLOCK_SIZE_K"] = block_size_k
+        config["NUM_KSPLIT"] = num_ksplit
+        if block_size_k >= 2 * K:
+            block_size_k = max(triton.next_power_of_2(2 * K), 128)
+            num_ksplit = 1
+            config["BLOCK_SIZE_K"] = block_size_k
+            config["SPLITK_BLOCK_SIZE"] = block_size_k
+            config["NUM_KSPLIT"] = num_ksplit
Copilot AI, Dec 30, 2025
When _USE_GEMM_SPLITK_BF16 is True, this accesses y.dtype, but y can be None at this point (line 104 shows y may be None), which raises an AttributeError. In the actual function (lines 177-179), y is guaranteed to be non-None here because it is either passed in or allocated earlier. Consider using the dtype parameter instead of y.dtype to match the actual function's behavior.
- (num_ksplit, M, N), dtype=y.dtype, device=x.device
+ (num_ksplit, M, N), dtype=dtype, device=x.device
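To make the y.dtype point concrete, here is a minimal plain-Python sketch of how a fake function could decide the split-K partial buffer's shape and dtype without ever touching y. The helper name splitk_partial_spec, the string dtype names, and the float32 fallback when the flag is off are assumptions for illustration, not code from this PR.

```python
from typing import Optional, Tuple


def splitk_partial_spec(
    M: int,
    N: int,
    num_ksplit: int,
    dtype: str,
    use_splitk_bf16: bool,
) -> Optional[Tuple[Tuple[int, int, int], str]]:
    """Return (shape, dtype) of the split-K partial buffer, or None if unused.

    Hypothetical helper: it depends only on the requested output dtype,
    so it still works when no output tensor y was passed in.
    """
    if num_ksplit <= 1:
        return None  # no partial buffer is needed without split-K
    # Assumption: partials use the output dtype when the BF16 split-K flag
    # is set, and float32 otherwise.
    partial_dtype = dtype if use_splitk_bf16 else "float32"
    return (num_ksplit, M, N), partial_dtype
```

Because the helper never reads y, the same call is valid whether or not the caller supplied an output tensor, which is the property the review comment asks for.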
Copilot AI, Dec 30, 2025
The fake function does not set SPLITK_BLOCK_SIZE when num_ksplit == 1. In the actual function (line 185), when NUM_KSPLIT == 1, it sets config["SPLITK_BLOCK_SIZE"] = 2 * K. While this might not affect the fake function's tensor allocation behavior, it represents an inconsistency in following the actual function flow, which is the stated goal of this PR.
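One way to keep the fake and actual functions from drifting apart is to route both through a single config-normalizing helper. The sketch below is an assumption-laden illustration: resolve_splitk and next_power_of_2 are hypothetical names, get_splitk is injected as a callable because its real implementation is not shown here, and the nesting of the branches is inferred from the snippets in this review rather than confirmed against the source file.

```python
from typing import Callable, Dict, Tuple


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (pure-Python stand-in for triton.next_power_of_2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def resolve_splitk(
    K: int,
    config: Dict[str, int],
    get_splitk: Callable[[int, int, int], Tuple[int, int, int]],
) -> Dict[str, int]:
    """Return a copy of config with the split-K fields normalized.

    Hypothetical helper: calling it from both the fake and the actual
    function would make them agree on every split-K field.
    """
    cfg = dict(config)  # do not mutate the caller's config
    if cfg["NUM_KSPLIT"] > 1:
        splitk_block_size, block_size_k, num_ksplit = get_splitk(
            K, cfg["BLOCK_SIZE_K"], cfg["NUM_KSPLIT"]
        )
        cfg["SPLITK_BLOCK_SIZE"] = splitk_block_size
        cfg["BLOCK_SIZE_K"] = block_size_k
        cfg["NUM_KSPLIT"] = num_ksplit
        if block_size_k >= 2 * K:
            # One block already covers the reduction: collapse to no split,
            # using the values from the suggested change above.
            block_size_k = max(next_power_of_2(2 * K), 128)
            cfg["BLOCK_SIZE_K"] = block_size_k
            cfg["SPLITK_BLOCK_SIZE"] = block_size_k
            cfg["NUM_KSPLIT"] = 1
    else:
        # The case this comment flags: no split requested.
        cfg["SPLITK_BLOCK_SIZE"] = 2 * K
    return cfg
```

Returning a copy rather than mutating the input keeps the helper safe to call from a fake function, where side effects on a shared config would be surprising.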
Copilot AI, Dec 30, 2025
Variable y_pp is not used.
Copilot AI, Dec 30, 2025
The config_hashable variable is set to None when config is None, but then it's unconditionally assigned on line 597, losing the None value. This means config_hashable will never be None when passed to gemm_afp4wfp4_. The logic should either check if config_hashable is None before line 597, or only set config_hashable once after potentially calling _get_config.
- config_hashable = None
- M, K = x.shape
- N, _ = w.shape
- config = _get_config(M, N, K)
- config_hashable = serialize_dict(config)
+ M, K = x.shape
+ N, _ = w.shape
+ config = _get_config(M, N, K)
+ config_hashable = serialize_dict(config) if config is not None else None
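The "set config_hashable once" option mentioned in this comment can be sketched as follows. This is an illustration only: resolve_config_hashable and _serialize are hypothetical names, the JSON-based serializer is a stand-in for the project's serialize_dict, and get_config is injected so the sketch does not depend on the real _get_config.

```python
import json
from typing import Callable, Dict, Optional


def _serialize(d: Dict[str, int]) -> str:
    """Stand-in serializer (assumption); the project's own helper may differ."""
    return json.dumps(d, sort_keys=True)


def resolve_config_hashable(
    config: Optional[Dict[str, int]],
    M: int,
    N: int,
    K: int,
    get_config: Callable[[int, int, int], Optional[Dict[str, int]]],
) -> Optional[str]:
    """Look up a config only when none was given, then serialize exactly once."""
    if config is None:
        config = get_config(M, N, K)
    # Single assignment point: None is preserved instead of being overwritten.
    return _serialize(config) if config is not None else None
```

With one assignment point there is no earlier None value to lose, so the two branches of the original logic cannot contradict each other.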
The config parameter can be None according to the function signature, but deserialize_str is called on it without checking. This will cause an AttributeError or similar exception when config is None. The actual function gemm_afp4wfp4_ handles this case by checking if config is None and calling _get_config if needed (lines 155-158).
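A guard of the kind described here might look like the sketch below. The name load_config is hypothetical, json.loads stands in for deserialize_str, and get_config is injected in place of the real _get_config; none of these reflect the project's actual signatures.

```python
import json
from typing import Callable, Dict, Optional


def load_config(
    config_str: Optional[str],
    M: int,
    N: int,
    K: int,
    get_config: Callable[[int, int, int], Dict[str, int]],
    deserialize: Callable[[str], Dict[str, int]] = json.loads,
) -> Dict[str, int]:
    """Return a config dict, falling back to a shape-based lookup.

    Hypothetical helper: the None check happens before deserialization,
    mirroring what the comment says the actual function does.
    """
    if config_str is None:
        # Copy so later edits to the result do not touch the looked-up config.
        return dict(get_config(M, N, K))
    return deserialize(config_str)
```

Checking for None before deserializing means the fake function accepts exactly the inputs its signature advertises.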