49 changes: 43 additions & 6 deletions aiter/ops/triton/gemm_afp4wfp4.py
@@ -73,10 +73,41 @@ def gemm_afp4wfp4_fake_tensor(
    config: Optional[str] = None,
    skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
-    if y is None:
-        M, _ = x.shape
-        N, _ = w.shape
-        return torch.empty((M, N), dtype=dtype, device=x.device)

    M, K = x.shape
    N, _ = w.shape


Copilot AI Dec 30, 2025

The config parameter can be None according to the function signature, but deserialize_str is called on it without checking. This will cause an AttributeError or similar exception when config is None. The actual function gemm_afp4wfp4_ handles this case by checking if config is None and calling _get_config if needed (lines 155-158).

Suggested change
    if config is None:
        raise ValueError("gemm_afp4wfp4_fake_tensor requires a non-None serialized config.")

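
An alternative to raising, closer to the eager path the comment cites, would be to fall back to _get_config when no serialized config is supplied, in place of the unconditional deserialize_str call shown next. A minimal sketch, assuming _get_config(M, N, K) returns the same config dict that the non-fake wrapper further down in this file builds:

    # Sketch only: mirror gemm_afp4wfp4_'s fallback instead of raising.
    if config is None:
        config = _get_config(M, N, K)     # dict, as in the eager wrapper below
    else:
        config = deserialize_str(config)  # serialized str -> dict
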
    config = deserialize_str(config)
    num_ksplit = config["NUM_KSPLIT"]
    block_size_k = config["BLOCK_SIZE_K"]

    if num_ksplit > 1:
        _, block_size_k, num_ksplit = get_splitk(
            K, config["BLOCK_SIZE_K"], num_ksplit
        )

        if block_size_k >= 2 * K:
            num_ksplit = 1
Comment on lines +85 to +90

Copilot AI Dec 30, 2025

The fake function uses block_size_k (lowercase variable) but does not update the config dictionary after the check. In the actual function (lines 169-173), when BLOCK_SIZE_K >= 2 * K, it updates config["BLOCK_SIZE_K"], config["SPLITK_BLOCK_SIZE"], and config["NUM_KSPLIT"], and also applies max(config["BLOCK_SIZE_K"], 128). The fake function only sets the local num_ksplit variable but doesn't update block_size_k with triton.next_power_of_2(2 * K) or apply the 128 minimum. This inconsistency could lead to different behavior between the fake and actual functions.

Suggested change
-        _, block_size_k, num_ksplit = get_splitk(
-            K, config["BLOCK_SIZE_K"], num_ksplit
-        )
-        if block_size_k >= 2 * K:
-            num_ksplit = 1
+        splitk_block_size, block_size_k, num_ksplit = get_splitk(
+            K, config["BLOCK_SIZE_K"], num_ksplit
+        )
+        config["SPLITK_BLOCK_SIZE"] = splitk_block_size
+        config["BLOCK_SIZE_K"] = block_size_k
+        config["NUM_KSPLIT"] = num_ksplit
+        if block_size_k >= 2 * K:
+            block_size_k = max(triton.next_power_of_2(2 * K), 128)
+            num_ksplit = 1
+            config["BLOCK_SIZE_K"] = block_size_k
+            config["SPLITK_BLOCK_SIZE"] = block_size_k
+            config["NUM_KSPLIT"] = num_ksplit


    if num_ksplit > 1:
        if _USE_GEMM_SPLITK_BF16:
            y_pp = torch.empty(
                (num_ksplit, M, N), dtype=y.dtype, device=x.device

Copilot AI Dec 30, 2025

When _USE_GEMM_SPLITK_BF16 is True, this accesses y.dtype but y can be None at this point (line 104 shows y may be None). This will cause an AttributeError. In the actual function (lines 177-179), y is guaranteed to be non-None at this point because it's either passed in or allocated earlier. Consider using dtype parameter instead of y.dtype to match the actual function behavior.

Suggested change
-                (num_ksplit, M, N), dtype=y.dtype, device=x.device
+                (num_ksplit, M, N), dtype=dtype, device=x.device

            )
        else:
            y_pp = torch.empty(
                (num_ksplit, M, N), dtype=torch.float32, device=x.device
            )
    else:
        y_pp = None
Comment on lines +101 to +102

Copilot AI Dec 30, 2025

The fake function does not set SPLITK_BLOCK_SIZE when num_ksplit == 1. In the actual function (line 185), when NUM_KSPLIT == 1, it sets config["SPLITK_BLOCK_SIZE"] = 2 * K. While this might not affect the fake function's tensor allocation behavior, it represents an inconsistency in following the actual function flow, which is the stated goal of this PR.

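
If the fake path is meant to follow the eager flow here as well, the missing piece is small. A sketch based on the line the comment quotes; the exact placement alongside the existing split-K handling above is a judgment call:

    if num_ksplit == 1:
        # Mirror the eager path (actual function, line 185): when NUM_KSPLIT == 1,
        # SPLITK_BLOCK_SIZE is set to 2 * K.
        config["SPLITK_BLOCK_SIZE"] = 2 * K
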

Copilot AI Dec 30, 2025

Variable y_pp is not used.


    if y is None and (num_ksplit == 1 or not skip_reduce):
        y = torch.empty((M, N), dtype=dtype, device=x.device)

    if num_ksplit > 1:
        if skip_reduce:
            return y_pp

    return y
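
For intuition about what the fake path now reports, a hypothetical usage sketch (serialized_cfg and the input tensors are placeholder names; the shapes follow the return logic above):

    out = gemm_afp4wfp4_fake_tensor(
        x, w, x_scales, w_scales,
        dtype=torch.bfloat16,
        config=serialized_cfg,   # placeholder for a string built with serialize_dict
        skip_reduce=True,
    )
    # out.shape == (num_ksplit, M, N) when the resolved num_ksplit > 1,
    # otherwise (M, N)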


@@ -556,6 +587,12 @@ def gemm_afp4wfp4(
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
    skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
-    config_hashable = serialize_dict(config) if config else None
-    return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable)
    if config is None:
        config_hashable = None
        M, K = x.shape
        N, _ = w.shape
        config = _get_config(M, N, K)
    config_hashable = serialize_dict(config)
Comment on lines +593 to +597

Copilot AI Dec 30, 2025

The config_hashable variable is set to None when config is None, but then it's unconditionally assigned on line 597, losing the None value. This means config_hashable will never be None when passed to gemm_afp4wfp4_. The logic should either check if config_hashable is None before line 597, or only set config_hashable once after potentially calling _get_config.

Suggested change
-        config_hashable = None
-        M, K = x.shape
-        N, _ = w.shape
-        config = _get_config(M, N, K)
-    config_hashable = serialize_dict(config)
+        M, K = x.shape
+        N, _ = w.shape
+        config = _get_config(M, N, K)
+    config_hashable = serialize_dict(config) if config is not None else None

    return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable, skip_reduce)
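
For reference, the tail of gemm_afp4wfp4 with that last suggestion applied would read roughly as follows (a sketch only; the trailing is-not-None guard is kept from the suggestion, though it only matters if _get_config can itself return None):

    if config is None:
        M, K = x.shape
        N, _ = w.shape
        config = _get_config(M, N, K)
    config_hashable = serialize_dict(config) if config is not None else None
    return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable, skip_reduce)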