Merged
69 changes: 36 additions & 33 deletions aiter/jit/utils/torch_guard.py
@@ -78,23 +78,19 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
"qr_get_handle",
]

# By default all args are treated as in-place; in-place args can be defined per op below
SPECIAL_OPS_MUTATES_ARGS = {}


def generate_schema(func) -> str:
def generate_schema(func, mutates_args: Union[list[str], str] = "unknown") -> str:
import inspect

import torch

sig = inspect.signature(func)
parameters = []
mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(func.__name__, [])
for idx, (name, param) in enumerate(sig.parameters.items()):
param_type = param.annotation
flag = True
is_mutates = True
if len(mutates_args) > 0 and name not in mutates_args:
if mutates_args != "unknown" and name not in mutates_args:
is_mutates = False

if param_type is torch.Tensor:
@@ -188,7 +184,7 @@ def generate_schema(func) -> str:


def torch_compile_guard(
mutates_args: list[str] = [],
mutates_args: Union[list[str], str] = "unknown",
device: str = "cpu",
calling_func_: Optional[Callable[..., Any]] = None,
gen_fake: Optional[Callable[..., Any]] = None,
@@ -224,11 +220,8 @@ def wrapper_register(calling_func):
schema = generate_schema(calling_func)
else:
sig = inspect.signature(calling_func)
mutates_args = SPECIAL_OPS_MUTATES_ARGS.get(
calling_func.__name__, "unknown"
)
if hasattr(torch.library, "infer_schema"):
sig = torch.library.infer_schema(
schema = torch.library.infer_schema(
calling_func, mutates_args=mutates_args
)
else:
@@ -237,14 +230,15 @@ def wrapper_register(calling_func):

# torch 2.4 does not support mutates_args="unknown" to mark all params as in-place
if mutates_args == "unknown":
mutates_args = []
mutates_args_custom = []

for param_name, param in sig.parameters.items():
if param.annotation == torch.Tensor:
mutates_args.append(param_name)
mutates_args_custom.append(param_name)

sig = torch._custom_op.impl.infer_schema(calling_func, mutates_args)
schema = f"{sig}"
schema = torch._custom_op.impl.infer_schema(
calling_func, mutates_args_custom
)
return schema

schema = wrapper_register(calling_func)
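As context for the fallback above, here is a minimal sketch of what schema inference yields on builds that ship torch.library.infer_schema (hypothetical prototype function; the exact alias annotations vary by PyTorch version):

import torch

def scale_(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Prototype only: infer_schema reads the annotations, not the body.
    return x

# Marking "x" as mutated produces a schema string roughly like
#   "(Tensor(a0!) x, float factor) -> Tensor"
print(torch.library.infer_schema(scale_, mutates_args=["x"]))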
@@ -280,11 +274,27 @@ def wrapper_register(calling_func):

loadName = calling_func.__name__

def abstract_impl(*args, custom_build_args={}, **kwargs):
if return_non_tensor:
return torch.empty(1, device=device), 1
def wrapper_custom(*args, **kwargs):
result = (
getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
if input_is_tensor
else getattr(torch.ops.aiter, f"{loadName}")(
torch.empty(1, device=device), *args, **kwargs
)
)
return result[1] if return_non_tensor else result

if hasattr(torch.ops.aiter, loadName):
return wrapper_custom

def abstract_impl(*args, **kwargs):
if gen_fake is not None:
return gen_fake(*args, **kwargs)
if return_non_tensor:
return torch.empty(1, device=device), gen_fake(*args, **kwargs)
else:
return gen_fake(*args, **kwargs)
if return_non_tensor:
return torch.empty(1, device=device), calling_func(*args, **kwargs)
return calling_func(*args, **kwargs)

def outer_wrapper(*args, **kwargs):
@@ -294,11 +304,14 @@ def outer_wrapper(*args, **kwargs):
else (torch.empty(1, device=device), wrapper(*args, **kwargs))
)

def abstract_impl_dummy(dummy, *args, custom_build_args={}, **kwargs):
if return_non_tensor:
return torch.empty(1, device=device), 1
def abstract_impl_dummy(dummy, *args, **kwargs):
if gen_fake is not None:
return gen_fake(*args, **kwargs)
if return_non_tensor:
return torch.empty(1, device=device), gen_fake(*args, **kwargs)
else:
return gen_fake(*args, **kwargs)
if return_non_tensor:
return torch.empty(1, device=device), calling_func(*args, **kwargs)
return calling_func(*args, **kwargs)

def outer_wrapper_dummy(dummy, *args, **kwargs):
@@ -325,16 +338,6 @@ def outer_wrapper_dummy(dummy, *args, **kwargs):
aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CPU")
aiter_lib._register_fake(f"{loadName}", fake_func)

def wrapper_custom(*args, custom_build_args={}, **kwargs):
result = (
getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
if input_is_tensor
else getattr(torch.ops.aiter, f"{loadName}")(
torch.empty(1, device=device), *args, **kwargs
)
)
return result[1] if return_non_tensor else result

return wrapper_custom

return decorator
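Taken together, the decorator changes replace the removed SPECIAL_OPS_MUTATES_ARGS lookup with explicit per-op arguments: callers can now pass mutates_args and a gen_fake meta implementation directly. A minimal usage sketch with a hypothetical op (not part of this PR):

import torch
from aiter.jit.utils.torch_guard import torch_compile_guard

def rowsum_fake(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Fake/meta impl: only shape, dtype, and device matter under tracing.
    return out

@torch_compile_guard(mutates_args=["out"], gen_fake=rowsum_fake)
def rowsum(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Eager implementation; the guard registers it under the torch.ops.aiter namespace.
    out.copy_(x.sum(dim=-1))
    return out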
18 changes: 16 additions & 2 deletions aiter/ops/gemm_op_a4w4.py
@@ -2,9 +2,9 @@
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import functools
import os
from typing import Optional

from aiter.jit.utils.torch_guard import torch_compile_guard
import pandas as pd
import torch
from torch import Tensor
@@ -14,7 +14,6 @@
from ..jit.core import (
AITER_CONFIG_GEMM_A4W4_FILE,
AITER_LOG_TUNED_CONFIG,
AITER_ROOT_DIR,
compile_ops,
)
from ..jit.utils.chip_info import get_cu_num, get_gfx
@@ -60,6 +59,21 @@ def get_GEMM_config(M: int, N: int, K: int):
return config


def gemm_a4w4_fake(
A: Tensor, # A:[M, K/2] f4x2
B: Tensor, # B:[N, K/2] f4x2
A_scale: Tensor, # A_scale:[M, K/32] e8m0 padded
B_scale: Tensor, # B_scale:[N, K/32] e8m0 padded
out: Tensor, # Out:[M, N] bf16
bias: Optional[Tensor] = None, # bias:[1, N] f32
alpha: Optional[float] = 1.0,
beta: Optional[float] = 0.0,
bpreshuffle: Optional[bool] = True,
) -> torch.Tensor:
return out


@torch_compile_guard(gen_fake=gemm_a4w4_fake)
def gemm_a4w4(
A: Tensor, # A:[M, K/2] f4x2
B: Tensor, # B:[N, K/2] f4x2
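Because gemm_a4w4 writes into a caller-provided out buffer, its fake only needs to return that buffer: under FakeTensor tracing no kernel runs, and only out's shape, dtype, and device propagate. A rough illustration of the shape contract (placeholder dtypes; real inputs are packed fp4x2 / e8m0 buffers produced by aiter's quantization helpers, and the scale tensors may carry extra padding):

import torch

M, N, K = 128, 256, 512
A = torch.empty(M, K // 2, dtype=torch.uint8, device="cuda")         # packed f4x2 (placeholder dtype)
B = torch.empty(N, K // 2, dtype=torch.uint8, device="cuda")         # packed f4x2 (placeholder dtype)
A_scale = torch.empty(M, K // 32, dtype=torch.uint8, device="cuda")  # e8m0 scales (padding omitted)
B_scale = torch.empty(N, K // 32, dtype=torch.uint8, device="cuda")  # e8m0 scales (padding omitted)
out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")

# gemm_a4w4_fake(A, B, A_scale, B_scale, out) simply returns out, so the
# traced result has the same [M, N] bf16 layout as the real kernel's output.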
72 changes: 53 additions & 19 deletions aiter/ops/gemm_op_a8w8.py
@@ -207,8 +207,8 @@ def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k
_CKGEMM_CONFIG_CACHE = None


@torch_compile_guard()
def get_CKGEMM_config_(tuned_file: str = None) -> None:
@functools.lru_cache(maxsize=1024)
def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"):
if tuned_file is None:
tuned_file = "a8w8_tuned_gemm.csv"
global _CKGEMM_CONFIG_CACHE
@@ -221,13 +221,6 @@ def get_CKGEMM_config_(tuned_file: str = None) -> None:
["cu_num", "M", "N", "K"]
).to_dict("index")

return None


@functools.lru_cache(maxsize=1024)
def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"):
get_CKGEMM_config_(tuned_file)

cu_num = get_cu_num()

padded_M = M
@@ -277,15 +270,28 @@ def get_bpreshuffle_GEMM_config(
return config


def gemm_a8w8_fake(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype: torch.dtype = dtypes.bf16,
splitK: Optional[int] = None,
) -> Tensor:
return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)


@torch_compile_guard(gen_fake=gemm_a8w8_fake)
def gemm_a8w8(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype=dtypes.bf16,
dtype: torch.dtype = dtypes.bf16,
splitK: Optional[int] = None,
):
) -> Tensor:
# assert dtype in [
# dtypes.bf16,
# dtypes.fp16,
@@ -350,9 +356,9 @@ def gemm_a8w8_CK(
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype=dtypes.bf16,
dtype: torch.dtype = dtypes.bf16,
splitK: Optional[int] = None,
):
) -> Tensor:
# assert dtype in [
# dtypes.bf16,
# dtypes.fp16,
@@ -370,15 +376,28 @@ def gemm_a8w8_CK(
return gemm_a8w8_ck(XQ, WQ, x_scale, w_scale, Y, bias, splitK)


def gemm_a8w8_bpreshuffle_fake(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype: torch.dtype = dtypes.bf16,
check: bool = False,
) -> Tensor:
return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)


@torch_compile_guard(gen_fake=gemm_a8w8_bpreshuffle_fake)
def gemm_a8w8_bpreshuffle(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype=torch.float16,
check=False,
):
dtype: torch.dtype = dtypes.bf16,
check: bool = False,
) -> Tensor:
assert dtype in [
torch.bfloat16,
torch.float16,
@@ -410,7 +429,7 @@ def gemm_a8w8_blockscale_fake(
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
dtype=dtypes.bf16,
dtype: torch.dtype = dtypes.bf16,
isBpreshuffled=False,
) -> torch.Tensor:
m = XQ.shape[0]
@@ -465,9 +484,24 @@ def flatmm_a8w8_blockscale_ASM(
return flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, Y)


def gemm_a8w8_blockscale_bpreshuffle_fake(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
dtype: torch.dtype = dtypes.bf16,
) -> Tensor:
return torch.empty(XQ.shape[0], WQ.shape[0], dtype=dtype, device=XQ.device)


@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_bpreshuffle_fake)
def gemm_a8w8_blockscale_bpreshuffle(
XQ: Tensor, WQ: Tensor, x_scale: Tensor, w_scale: Tensor, dtype=dtypes.bf16
):
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
dtype: torch.dtype = dtypes.bf16,
) -> Tensor:
assert dtype in [
dtypes.bf16,
dtypes.fp16,
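The fakes added in this file all follow one convention: allocate an [M, N] placeholder from XQ.shape[0] x WQ.shape[0] on XQ's device, so Dynamo can trace through the op without launching a kernel. A sketch of the scenario this unblocks (import path and quantized inputs are illustrative, not prescribed by the PR):

import torch
from aiter.ops.gemm_op_a8w8 import gemm_a8w8

@torch.compile
def quantized_linear(XQ, WQ, x_scale, w_scale, bias):
    # While tracing, gemm_a8w8_fake supplies the [M, N] bf16 placeholder;
    # at runtime the tuned CK/ASM kernel path runs as before.
    return gemm_a8w8(XQ, WQ, x_scale, w_scale, bias=bias)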
38 changes: 37 additions & 1 deletion aiter/ops/mha.py
@@ -1257,7 +1257,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
return out, softmax_lse, S_dmask, rng_state


@torch_compile_guard()
# @torch_compile_guard(mutates_args=[])
def can_impl_fmha_v3_bwd(
dout: torch.Tensor,
q: torch.Tensor,
@@ -1436,6 +1436,42 @@ def psskddv():
return ret


def _flash_attn_backward_fake(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: Optional[torch.Tensor],
dk: Optional[torch.Tensor],
dv: Optional[torch.Tensor],
dbias: Optional[torch.Tensor],
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
bias: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
) -> torch.Tensor:
batch_size = q.size(0)
seqlen_q = q.size(1)
num_heads = q.size(2)

softmax_d = torch.empty(
(batch_size, num_heads, seqlen_q), # {batch_size, num_heads, seqlen_q}
dtype=torch.float32,
device=q.device,
)
return softmax_d


@torch_compile_guard(gen_fake=_flash_attn_backward_fake)
def _flash_attn_backward(
dout: torch.Tensor,
q: torch.Tensor,
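The fake mirrors the layout the backward kernel expects: q is [batch, seqlen_q, num_heads, head_dim], and softmax_d comes back as (batch, num_heads, seqlen_q) in fp32 on q's device. A quick shape check on meta tensors (illustrative sizes only):

import torch

# q layout assumed by the fake: [batch, seqlen_q, num_heads, head_dim]
q = torch.empty(2, 1024, 16, 128, device="meta")
softmax_d = torch.empty(q.size(0), q.size(2), q.size(1), dtype=torch.float32, device=q.device)
assert softmax_d.shape == (2, 16, 1024)  # (batch, num_heads, seqlen_q)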