Conversation

@mqhc2020 (Contributor)

Motivation

The GEMM fake function should follow the same code flow as the actual function.

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI left a comment

Pull request overview

This PR updates the gemm_afp4wfp4 operator's fake tensor function to follow the same code flow as the actual implementation, ensuring consistency in how configurations are handled and tensors are allocated.

Key Changes

  • Refactored gemm_afp4wfp4_fake_tensor to mirror the actual function's logic for config handling, splitk computation, and tensor allocation
  • Added skip_reduce parameter to gemm_afp4wfp4 wrapper function to support split-k reduction control
  • Updated config handling in the wrapper to compute default config when not provided

```python
N, _ = w.shape
config = _get_config(M, N, K)
config_hashable = serialize_dict(config)
return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable)
```

Copilot AI Dec 30, 2025

The skip_reduce parameter is added to the function signature but is not passed to the inner gemm_afp4wfp4_ function. This means the skip_reduce functionality will not work as intended when calling this wrapper function. The call should include skip_reduce as an argument.

Suggested change:

```diff
-return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable)
+return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable, skip_reduce=skip_reduce)
```

Comment on lines +85 to +90
```python
_, block_size_k, num_ksplit = get_splitk(
    K, config["BLOCK_SIZE_K"], num_ksplit
)

if block_size_k >= 2 * K:
    num_ksplit = 1
```

Copilot AI Dec 30, 2025

The fake function uses block_size_k (the lowercase local variable) but does not update the config dictionary after the check. In the actual function (lines 169-173), when BLOCK_SIZE_K >= 2 * K, it updates config["BLOCK_SIZE_K"], config["SPLITK_BLOCK_SIZE"], and config["NUM_KSPLIT"], and also applies max(config["BLOCK_SIZE_K"], 128). The fake function only sets the local num_ksplit variable; it neither updates block_size_k with triton.next_power_of_2(2 * K) nor applies the 128 minimum. This inconsistency could lead to different behavior between the fake and actual functions.

Suggested change:

```diff
-_, block_size_k, num_ksplit = get_splitk(
-    K, config["BLOCK_SIZE_K"], num_ksplit
-)
-if block_size_k >= 2 * K:
-    num_ksplit = 1
+splitk_block_size, block_size_k, num_ksplit = get_splitk(
+    K, config["BLOCK_SIZE_K"], num_ksplit
+)
+config["SPLITK_BLOCK_SIZE"] = splitk_block_size
+config["BLOCK_SIZE_K"] = block_size_k
+config["NUM_KSPLIT"] = num_ksplit
+if block_size_k >= 2 * K:
+    block_size_k = max(triton.next_power_of_2(2 * K), 128)
+    num_ksplit = 1
+    config["BLOCK_SIZE_K"] = block_size_k
+    config["SPLITK_BLOCK_SIZE"] = block_size_k
+    config["NUM_KSPLIT"] = num_ksplit
```
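
The 2 * K fallback described in this comment can be sketched in plain Python; `next_power_of_2` below is a stand-in whose semantics (smallest power of two >= n) are assumed to match `triton.next_power_of_2`:

```python
def next_power_of_2(n):
    # Smallest power of two >= n (stand-in for triton.next_power_of_2).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def fallback_block_size_k(K):
    # When one split already covers all of K, round the block size up
    # to a power of two and enforce the 128-element minimum.
    return max(next_power_of_2(2 * K), 128)
```

For example, `fallback_block_size_k(48)` yields 128 (already at the floor), while `fallback_block_size_k(100)` yields 256.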

Comment on lines +101 to +102
```python
else:
    y_pp = None
```

Copilot AI Dec 30, 2025

The fake function does not set SPLITK_BLOCK_SIZE when num_ksplit == 1. In the actual function (line 185), when NUM_KSPLIT == 1, it sets config["SPLITK_BLOCK_SIZE"] = 2 * K. While this might not affect the fake function's tensor allocation behavior, it represents an inconsistency in following the actual function flow, which is the stated goal of this PR.
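
A minimal sketch of the missing branch (the helper name is invented; the 2 * K value mirrors the actual function's line 185 as described above):

```python
def sync_splitk_block_size(config, K, num_ksplit):
    # With no split along K, record a SPLITK_BLOCK_SIZE spanning 2 * K,
    # matching the actual function's behavior when NUM_KSPLIT == 1.
    if num_ksplit == 1:
        config["SPLITK_BLOCK_SIZE"] = 2 * K
    return config
```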


```python
M, K = x.shape
N, _ = w.shape
```

Copilot AI Dec 30, 2025

The config parameter can be None according to the function signature, but deserialize_str is called on it without checking. This will cause an AttributeError or similar exception when config is None. The actual function gemm_afp4wfp4_ handles this case by checking if config is None and calling _get_config if needed (lines 155-158).

Suggested change:

```python
if config is None:
    raise ValueError("gemm_afp4wfp4_fake_tensor requires a non-None serialized config.")
```

```python
if num_ksplit > 1:
    if _USE_GEMM_SPLITK_BF16:
        y_pp = torch.empty(
            (num_ksplit, M, N), dtype=y.dtype, device=x.device
```

Copilot AI Dec 30, 2025

When _USE_GEMM_SPLITK_BF16 is True, this accesses y.dtype but y can be None at this point (line 104 shows y may be None). This will cause an AttributeError. In the actual function (lines 177-179), y is guaranteed to be non-None at this point because it's either passed in or allocated earlier. Consider using dtype parameter instead of y.dtype to match the actual function behavior.

Suggested change:

```diff
-            (num_ksplit, M, N), dtype=y.dtype, device=x.device
+            (num_ksplit, M, N), dtype=dtype, device=x.device
```

Comment on lines +593 to +597
```python
config_hashable = None
M, K = x.shape
N, _ = w.shape
config = _get_config(M, N, K)
config_hashable = serialize_dict(config)
```

Copilot AI Dec 30, 2025

The config_hashable variable is set to None when config is None, but then it's unconditionally assigned on line 597, losing the None value. This means config_hashable will never be None when passed to gemm_afp4wfp4_. The logic should either check if config_hashable is None before line 597, or only set config_hashable once after potentially calling _get_config.

Suggested change:

```diff
-config_hashable = None
-M, K = x.shape
-N, _ = w.shape
-config = _get_config(M, N, K)
-config_hashable = serialize_dict(config)
+M, K = x.shape
+N, _ = w.shape
+config = _get_config(M, N, K)
+config_hashable = serialize_dict(config) if config is not None else None
```
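
The conditional-serialization pattern in this suggestion can be illustrated with a JSON-based stub; the real `serialize_dict` is assumed to produce some stable hashable string, but its exact format is not shown in the PR:

```python
import json

def serialize_dict(config):
    # Illustrative stub: a stable, hashable string form of the config.
    return json.dumps(config, sort_keys=True)

def make_config_hashable(config):
    # Serialize only when a config was actually resolved; propagate None
    # so the inner function can fall back to its own default handling.
    return serialize_dict(config) if config is not None else None
```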

```python
            (num_ksplit, M, N), dtype=torch.float32, device=x.device
        )
    else:
        y_pp = None
```

Copilot AI Dec 30, 2025

Variable y_pp is not used.
