diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 8f8a73fd01..09104e7aa5 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -483,14 +483,142 @@ def convert(d_ops: dict): ) +MANUAL_SCHEMA_OPS = [ + "register_graph_buffers", + "module_moe_ck2stages", + "mha_fwd", + "fmha_v3_fwd", + "mha_varlen_fwd", + "mha_bwd", + "fmha_v3_bwd", + "mha_varlen_bwd", + "fmha_v3_varlen_bwd", + "mha_batch_prefill", + "hipb_findallsols", + "rocb_findallsols", + "_ActivationType", + "_QuantType", + "init_custom_ar", +] + +NONE_WRAPPED_OP = [ + "hipb_create_extension", + "hipb_destroy_extension", + "getHipblasltKernelName", + "rocb_create_extension", + "rocb_destroy_extension", + "get_meta_buffer_ipc_handle", + "get_graph_buffer_ipc_meta", + "_ActivationType", + "_QuantType", + "allocate_meta_buffer", + "dispose", + "meta_size", + "get_padded_m", +] + + +def generate_schema(func) -> str: + import inspect + import torch + from typing import Optional, Union, List, get_origin, get_args + + sig = inspect.signature(func) + parameters = [] + + for idx, (name, param) in enumerate(sig.parameters.items()): + param_type = param.annotation + flag = True + + if param_type is torch.Tensor: + type_str = f"Tensor(a{idx}!)" + elif param_type == Optional[torch.Tensor]: + type_str = f"Tensor(a{idx}!)?" + elif get_origin(param_type) is Union and torch.Tensor in get_args(param_type): + type_str = f"Tensor(a{idx}!)?" + elif param_type in (torch.SymInt, int): + type_str = "SymInt" + elif param_type in (float, bool, str): + type_str = param_type.__name__ + elif param_type == Optional[torch.Generator]: + type_str = "Generator?" + elif ( + get_origin(param_type) in (list, List) + and get_args(param_type)[0] is torch.Tensor + ): + type_str = f"Tensor(a{idx}!)[]" + elif get_origin(param_type) in (list, List) and get_args(param_type)[0] is int: + type_str = "int[]" + else: + type_str = "*" + flag = False + if flag: + param_str = f"{type_str} {name}" + + if param.default != inspect.Parameter.empty: + if param.default is None: + param_str += "=None" + else: + param_str += f"={param.default}" + else: + param_str = f"{type_str} " + + parameters.append(param_str) + return_annotation = sig.return_annotation + return_type = "" + if return_annotation is type(None) or return_annotation is None: + return_type = "()" + elif return_annotation is torch.Tensor: + return_type = "Tensor" + elif ( + get_origin(return_annotation) is list and get_args(return_annotation)[0] is int + ): + return_type = "int[]" + elif return_annotation is int: + return_type = "int" + elif return_annotation is float: + return_type = "float" + elif return_annotation is bool: + return_type = "bool" + elif ( + get_origin(return_annotation) is list + and get_args(return_annotation)[0] is torch.Tensor + ): + return_type = "Tensor[]" + + schema = f"({', '.join(parameters)}) -> {return_type}" + + return schema + + def compile_ops( _md_name: str, fc_name: Optional[str] = None, gen_func: Optional[Callable[..., dict[str, Any]]] = None, + gen_fake: Optional[Callable[..., Any]] = None, ): + def decorator(func): + import torch + from csrc.cpp_itfs.torch_utils import aiter_lib + import torch.library + import inspect + func.arg_checked = False + schema = "" + if func.__name__ in MANUAL_SCHEMA_OPS: + schema = generate_schema(func) + else: + sig = inspect.signature(func) + mutates_args = [] + for name, param in sig.parameters.items(): + if param.annotation is torch.Tensor: + mutates_args.append(name) + sig = torch.library.infer_schema(func, mutates_args="unknown") + schema = f"{sig}" + loadName = func.__name__ + @functools.wraps(func) def wrapper(*args, custom_build_args={}, **kwargs): loadName = fc_name @@ -565,6 +693,16 @@ def wrapper(*args, custom_build_args={}, **kwargs): op = getattr(module, loadName) else: return None + activation_index = 0 + quant_index = 0 + activation_list = [ + "fmoe_g1u1", + "fmoe_int8_g1u0", + "fmoe_g1u1_tkw1", + "fmoe_fp8_blockscale_g1u1", + "moe_stage1_g1u1", + ] + quant_list = ["moe_stage1_g1u1"] def check_args(): get_asm_dir() @@ -587,7 +725,10 @@ def check_args(): func.__signature__ = sig ann = {k: v.annotation for k, v in sig.parameters.items()} ann["return"] = sig.return_annotation - + if loadName in activation_list: + return True + if loadName in quant_list: + return True callargs = inspect.getcallargs(func, *args, **kwargs) for el, arg in callargs.items(): expected_type = ann[el] @@ -632,8 +773,51 @@ def check_args(): log_args(func, *args, **kwargs) + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + if loadName in activation_list: + activation_index = params.index("activation") + args_list = list(args) + from aiter import ActivationType, QuantType + + if len(args_list) > activation_index and isinstance( + args_list[activation_index], int + ): + args_list[activation_index] = ActivationType( + args_list[activation_index] + ) + args = tuple(args_list) + + if loadName in quant_list: + quant_index = params.index("quant_type") + args_list = list(args) + from aiter import ActivationType, QuantType + + if len(args_list) > quant_index and isinstance( + args_list[quant_index], int + ): + args_list[quant_index] = QuantType(args_list[quant_index]) + args = tuple(args_list) return op(*args, **kwargs) - return wrapper + def abstract_impl(*args, custom_build_args={}, **kwargs): + if gen_fake is not None: + return gen_fake(*args, **kwargs) + return func(*args, **kwargs) + + if loadName in NONE_WRAPPED_OP: + return wrapper + + if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"): + op_schema = f"aiter::wrapper_{loadName}" + schema + aiter_lib.define(op_schema) + aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CUDA") + aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CPU") + aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl) + + def wrapper_return(*args, **kwargs): + return getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs) + + return wrapper_return return decorator diff --git a/aiter/ops/activation.py b/aiter/ops/activation.py index 227fce3fdd..fc70b0aa9b 100644 --- a/aiter/ops/activation.py +++ b/aiter/ops/activation.py @@ -9,16 +9,16 @@ @compile_ops("module_activation") -def silu_and_mul(out: Tensor, input: Tensor): ... +def silu_and_mul(out: Tensor, input: Tensor) -> None: ... @compile_ops("module_activation") -def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor): ... +def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor) -> None: ... @compile_ops("module_activation") -def gelu_and_mul(out: Tensor, input: Tensor): ... +def gelu_and_mul(out: Tensor, input: Tensor) -> None: ... @compile_ops("module_activation") -def gelu_tanh_and_mul(out: Tensor, input: Tensor): ... +def gelu_tanh_and_mul(out: Tensor, input: Tensor) -> None: ... diff --git a/aiter/ops/aiter_operator.py b/aiter/ops/aiter_operator.py index ab4d0b4f99..aab26fbaac 100644 --- a/aiter/ops/aiter_operator.py +++ b/aiter/ops/aiter_operator.py @@ -5,6 +5,7 @@ from ..jit.core import compile_ops, AITER_CSRC_DIR from functools import partial from typing import Any +import torch MD_NAME = "module_aiter_operator" @@ -20,47 +21,99 @@ def cmdGenFunc(op_name: str, input: Tensor, other: Tensor) -> dict[str, Any]: } +def binary_fake_shape(input: Tensor, other: Tensor) -> Tensor: + shape1 = list(input.shape) + shape2 = list(other.shape) + + max_dim = max(len(shape1), len(shape2)) + shape1 = [1] * (max_dim - len(shape1)) + shape1 + shape2 = [1] * (max_dim - len(shape2)) + shape2 + + result_shape = [] + for dim1, dim2 in zip(shape1, shape2): + if dim1 == 1: + result_shape.append(dim2) + elif dim2 == 1: + result_shape.append(dim1) + elif dim1 == dim2: + result_shape.append(dim1) + else: + raise RuntimeError( + f"Incompatible shapes for binary operator: {input.shape} and {other.shape}" + ) + + return torch.empty( + size=result_shape, + dtype=input.dtype, + device=input.device, + ) + + +def sigmoid_fake_shape(input: torch.Tensor) -> torch.Tensor: + return torch.empty( + size=input.shape, + dtype=input.dtype, + device=input.device, + ) + + binary_add_build_args = partial(cmdGenFunc, "add") binary_sub_build_args = partial(cmdGenFunc, "sub") binary_mul_build_args = partial(cmdGenFunc, "mul") binary_div_build_args = partial(cmdGenFunc, "div") -@compile_ops("module_aiter_operator", gen_func=binary_add_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_add_build_args, gen_fake=binary_fake_shape +) def add(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_sub_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_sub_build_args, gen_fake=binary_fake_shape +) def sub(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_mul_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_mul_build_args, gen_fake=binary_fake_shape +) def mul(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_div_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_div_build_args, gen_fake=binary_fake_shape +) def div(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_add_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_add_build_args, gen_fake=binary_fake_shape +) def add_(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_sub_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_sub_build_args, gen_fake=binary_fake_shape +) def sub_(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_mul_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_mul_build_args, gen_fake=binary_fake_shape +) def mul_(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_operator", gen_func=binary_div_build_args) +@compile_ops( + "module_aiter_operator", gen_func=binary_div_build_args, gen_fake=binary_fake_shape +) def div_(input: Tensor, other: Tensor) -> Tensor: ... -@compile_ops("module_aiter_unary") +@compile_ops("module_aiter_unary", gen_fake=sigmoid_fake_shape) def sigmoid(input: Tensor) -> Tensor: ... -@compile_ops("module_aiter_unary") +@compile_ops("module_aiter_unary", gen_fake=sigmoid_fake_shape) def tanh(input: Tensor) -> Tensor: ... diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index febfe127d9..b9fa38a65f 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -16,7 +16,58 @@ MD_NAME = "module_attention" -@compile_ops("module_attention") +def gen_pa_fwd_native_fake( + # [num_seqs, num_heads, head_size] + query: torch.Tensor, + # [num_blocks, num_kv_heads, head_size/x, block_size, x] + key_cache: torch.Tensor, + # [num_blocks, num_kv_heads, head_size, block_size] + value_cache: torch.Tensor, + # [num_seqs, max_num_blocks_per_seq] + block_tables: torch.Tensor, + # [num_seqs] + context_lens: torch.Tensor, + k_dequant_scales: torch.Tensor, + v_dequant_scales: torch.Tensor, + max_seq_len: int, + num_kv_heads: int, + scale_s: float, + scale_k: float, + scale_v: float, + block_size: int, + quant_algo: int, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is not None: + return out + else: + return torch.empty_like(query) + + +def gen_pa_fwd_asm( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_num_blocks: int, + max_qlen: int = 1, + K_QScale: Optional[torch.Tensor] = None, + V_QScale: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + qo_indptr: Optional[torch.Tensor] = None, + high_precision: Optional[ + int + ] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache + kernelName: str = "", +): + if out_ is not None: + return out_ + else: + return torch.empty_like(query) + + +@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake) def pa_fwd_naive( # [num_seqs, num_heads, head_size] query: torch.Tensor, @@ -41,7 +92,7 @@ def pa_fwd_naive( ) -> torch.Tensor: ... -@compile_ops("module_attention_asm") +@compile_ops("module_attention_asm", gen_fake=gen_pa_fwd_asm) def pa_fwd_asm( query: torch.Tensor, key_cache: torch.Tensor, @@ -246,7 +297,7 @@ def mla_decode_stage1_asm_fwd( splitData: torch.Tensor, # [batch_size, num_kv_splits, num_heads, 1] splitLse: torch.Tensor, -): ... +) -> None: ... @compile_ops(MD_NAME) @@ -269,4 +320,4 @@ def mla_prefill_asm_fwd( splitData: torch.Tensor, # [batch_size, num_kv_splits, num_heads, 1] splitLse: torch.Tensor, -): ... +) -> None: ... diff --git a/aiter/ops/batched_gemm_op_a8w8.py b/aiter/ops/batched_gemm_op_a8w8.py index 76062ea6c6..6f8c37cbe9 100644 --- a/aiter/ops/batched_gemm_op_a8w8.py +++ b/aiter/ops/batched_gemm_op_a8w8.py @@ -14,7 +14,23 @@ from ..jit.utils.chip_info import get_cu_num -@compile_ops("module_batched_gemm_a8w8", fc_name="batched_gemm_a8w8") +def gen_batched_gemm_a8w8_fake_tensors( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, + bias: Optional[Tensor] = None, + splitK: int = 0, +) -> Tensor: + return out + + +@compile_ops( + "module_batched_gemm_a8w8", + fc_name="batched_gemm_a8w8", + gen_fake=gen_batched_gemm_a8w8_fake_tensors, +) def batched_gemm_a8w8( XQ: Tensor, WQ: Tensor, @@ -22,8 +38,8 @@ def batched_gemm_a8w8( w_scale: Tensor, out: Tensor, bias: Optional[Tensor] = None, - splitK=0, -): ... + splitK: int = 0, +) -> Tensor: ... @functools.lru_cache(maxsize=1024) @@ -90,7 +106,23 @@ def batched_gemm_a8w8_CK( return batched_gemm_a8w8(XQ, WQ, x_scale, w_scale, Y, bias, splitK) -@compile_ops("module_batched_gemm_a8w8_tune", fc_name="batched_gemm_a8w8_tune") +def gen_batched_gemm_a8w8_tune_fake_tensors( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, + kernelId: int, + splitK: int = 0, +) -> Tensor: + return out + + +@compile_ops( + "module_batched_gemm_a8w8_tune", + fc_name="batched_gemm_a8w8_tune", + gen_fake=gen_batched_gemm_a8w8_tune_fake_tensors, +) def batched_gemm_a8w8_tune( XQ: Tensor, WQ: Tensor, @@ -98,5 +130,5 @@ def batched_gemm_a8w8_tune( w_scale: Tensor, out: Tensor, kernelId: int, - splitK=0, -): ... + splitK: int = 0, +) -> Tensor: ... diff --git a/aiter/ops/batched_gemm_op_bf16.py b/aiter/ops/batched_gemm_op_bf16.py index cb312da212..e82905684d 100644 --- a/aiter/ops/batched_gemm_op_bf16.py +++ b/aiter/ops/batched_gemm_op_bf16.py @@ -16,8 +16,8 @@ @compile_ops("module_batched_gemm_bf16", fc_name="batched_gemm_bf16") def batched_gemm_bf16( - XQ: Tensor, WQ: Tensor, out: Tensor, bias: Optional[Tensor] = None, splitK=0 -): ... + XQ: Tensor, WQ: Tensor, out: Tensor, bias: Optional[Tensor] = None, splitK: int = 0 +) -> None: ... @functools.lru_cache(maxsize=1024) @@ -80,10 +80,11 @@ def batched_gemm_bf16_CK( else: splitK = 0 Y = torch.empty(b, m, n, dtype=dtype, device=XQ.device) - return batched_gemm_bf16(XQ, WQ, Y, bias, splitK) + batched_gemm_bf16(XQ, WQ, Y, bias, splitK) + return Y @compile_ops("module_batched_gemm_bf16_tune", fc_name="batched_gemm_bf16_tune") def batched_gemm_bf16_tune( - XQ: Tensor, WQ: Tensor, out: Tensor, kernelId: int, splitK=0 -): ... + XQ: Tensor, WQ: Tensor, out: Tensor, kernelId: int, splitK: int = 0 +) -> None: ... diff --git a/aiter/ops/cache.py b/aiter/ops/cache.py index fb041abfbc..bd6f881bf1 100644 --- a/aiter/ops/cache.py +++ b/aiter/ops/cache.py @@ -10,11 +10,13 @@ @compile_ops("module_cache") -def swap_blocks(src: Tensor, dst: Tensor, block_mapping: Tensor): ... +def swap_blocks(src: Tensor, dst: Tensor, block_mapping: Tensor) -> None: ... @compile_ops("module_cache") -def copy_blocks(key_caches: Tensor, value_caches: Tensor, block_mapping: Tensor): ... +def copy_blocks( + key_caches: Tensor, value_caches: Tensor, block_mapping: Tensor +) -> None: ... @compile_ops("module_cache") @@ -41,7 +43,7 @@ def reshape_and_cache_flash( kv_cache_dtype: str, k_scale: Tensor, v_scale: Tensor, -): ... +) -> None: ... @compile_ops("module_cache") @@ -54,7 +56,7 @@ def reshape_and_cache_with_pertoken_quant( v_dequant_scales: Tensor, slot_mapping: Tensor, asm_layout: bool, -): ... +) -> None: ... @compile_ops("module_cache") @@ -67,7 +69,7 @@ def reshape_and_cache_with_block_quant( v_dequant_scales: Tensor, slot_mapping: Tensor, asm_layout: bool, -): ... +) -> None: ... @compile_ops("module_cache") @@ -81,4 +83,4 @@ def reshape_and_cache_with_block_quant_for_asm_pa( slot_mapping: Tensor, asm_layout: bool, ori_block_size: int = 128, # [128/256] -): ... +) -> None: ... diff --git a/aiter/ops/communication.py b/aiter/ops/communication.py index 329ae05a72..c884cacfba 100644 --- a/aiter/ops/communication.py +++ b/aiter/ops/communication.py @@ -51,7 +51,6 @@ def destroy_dist_env(): def all_reduce_asm(inp: torch.Tensor): tp_grp = get_tp_group() ca = tp_grp.ca_comm - if ca._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): return aiter.all_reduce_asm_( diff --git a/aiter/ops/custom.py b/aiter/ops/custom.py index 5453c94b00..b5e0b47a5e 100644 --- a/aiter/ops/custom.py +++ b/aiter/ops/custom.py @@ -8,17 +8,19 @@ @compile_ops("module_custom") -def wvSpltK(in_a: Tensor, in_b: Tensor, out_c: Tensor, N_in: int, CuCount: int): ... +def wvSpltK( + in_a: Tensor, in_b: Tensor, out_c: Tensor, N_in: int, CuCount: int +) -> None: ... @compile_ops("module_custom") def wv_splitk_small_fp16_bf16( in_a: Tensor, in_b: Tensor, out_c: Tensor, N_in: int, CuCount: int -): ... +) -> None: ... @compile_ops("module_custom") -def LLMM1(in_a: Tensor, in_b: Tensor, out_c: Tensor, rows_per_block: int): ... +def LLMM1(in_a: Tensor, in_b: Tensor, out_c: Tensor, rows_per_block: int) -> None: ... @compile_ops("module_custom") @@ -29,4 +31,4 @@ def wvSplitKQ( scale_a: Tensor, scale_b: Tensor, CuCount: int, -): ... +) -> None: ... diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index 6817a93c23..64c41aca0c 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -14,8 +14,8 @@ def init_custom_ar( meta: torch.Tensor, rank_data: torch.Tensor, - handles: list[torch.Tensor], - offsets: list[int], + handles: List[torch.Tensor], + offsets: List[int], rank: int, full_nvlink: bool, ) -> int: ... @@ -24,16 +24,31 @@ def init_custom_ar( @compile_ops("module_custom_all_reduce") def all_reduce_reg( _fa: int, inp: torch.Tensor, out: torch.Tensor, open_fp8_quant: bool -): ... +) -> None: ... @compile_ops("module_custom_all_reduce") def all_reduce_unreg( _fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor -): ... +) -> None: ... -@compile_ops("module_custom_all_reduce") +def all_reduce_asm_fake_tensor( + inp: torch.Tensor, + ca: int, + reg_sig: torch.Tensor, + reg_buffer: torch.Tensor, + isGraph: bool, +) -> torch.Tensor: + + return torch.empty_like( + inp, + dtype=inp.dtype, + device=inp.device, + ) + + +@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_asm_fake_tensor) def all_reduce_asm_( inp: torch.Tensor, ca: int, @@ -43,7 +58,30 @@ def all_reduce_asm_( ) -> torch.Tensor: ... -@compile_ops("module_custom_all_reduce") +def all_reduce_rmsnorm_fake_tensors( + input: torch.Tensor, + residual_in: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + epsilon: float, + ca: int, + reg_sig: torch.Tensor, + reg_buffer: torch.Tensor, + isGraph: bool, +) -> List[torch.Tensor]: + + output = torch.empty_like( + input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad + ) + + residual_out = torch.empty_like( + input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad + ) + + return [output, residual_out] + + +@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_fake_tensors) def all_reduce_rmsnorm_( input: torch.Tensor, residual_in: torch.Tensor, @@ -57,7 +95,36 @@ def all_reduce_rmsnorm_( ) -> List[torch.Tensor]: ... -@compile_ops("module_custom_all_reduce") +def all_reduce_rmsnorm_quant_fake_tensors( + input: torch.Tensor, + residual_in: torch.Tensor, + weight: torch.Tensor, + xscale: torch.Tensor, + bias: torch.Tensor, + epsilon: float, + ca: int, + reg_sig: torch.Tensor, + reg_buffer: torch.Tensor, + isGraph: bool, +) -> List[torch.Tensor]: + + N = input.size(-1) + M = input.numel() // N + + output = torch.empty_like( + input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad + ) + + residual_out = torch.empty_like( + input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad + ) + + y_scale = torch.empty((M, 1), dtype=torch.float32, device=input.device) + + return [output, residual_out, y_scale] + + +@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_quant_fake_tensors) def all_reduce_rmsnorm_quant_( input: torch.Tensor, residual_in: torch.Tensor, @@ -73,7 +140,7 @@ def all_reduce_rmsnorm_quant_( @compile_ops("module_custom_all_reduce") -def dispose(_fa: int): ... +def dispose(_fa: int) -> None: ... @compile_ops("module_custom_all_reduce") @@ -82,8 +149,19 @@ def meta_size() -> int: ... @compile_ops("module_custom_all_reduce") def register_buffer( - _fa: int, t: torch.Tensor, handles: list[torch.Tensor], offsets: list[int] -): ... + _fa: int, t: torch.Tensor, handles: List[torch.Tensor], offsets: List[int] +) -> None: ... + + +# def gen_get_graph_buffer_ipc_meta_fake_tensors(_fa: int) -> List[torch.Tensor]: + +# handle_sz = 64 # sizeof(cudaIpcMemHandle_t) is 64 byte +# num_buffers = 4 # ??? +# handles = torch.empty((handle_sz * num_buffers,), dtype=torch.uint8, device="cuda") + +# offset_tensor = torch.empty((num_buffers,), dtype=torch.int64, device="cuda") + +# return [handles, offset_tensor] @compile_ops("module_custom_all_reduce") @@ -93,12 +171,20 @@ def get_graph_buffer_ipc_meta(_fa: int) -> tuple[torch.Tensor, torch.Tensor]: .. @compile_ops("module_custom_all_reduce") def register_graph_buffers( _fa: int, handles: list[torch.Tensor], offsets: list[torch.Tensor] -): ... +) -> None: ... @compile_ops("module_custom_all_reduce") def allocate_meta_buffer(size: int) -> torch.Tensor: ... +# def get_meta_buffer_ipc_handle_fake(inp: torch.Tensor) -> torch.Tensor: +# handle_size = 64 +# if not inp.is_cuda: +# raise RuntimeError("Input tensor must be on CUDA device") + +# return torch.empty(handle_size, dtype=torch.uint8, device=inp.device) + + @compile_ops("module_custom_all_reduce") def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: ... diff --git a/aiter/ops/enum.py b/aiter/ops/enum.py index 9268fa31d4..edc1ddd671 100644 --- a/aiter/ops/enum.py +++ b/aiter/ops/enum.py @@ -1,5 +1,7 @@ from ..jit.core import compile_ops -from enum import Enum as Enum + +# from enum import Enum as Enum +Enum = int @compile_ops("module_aiter_enum", "ActivationType") diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py index 5b8ab45443..54cac550d6 100644 --- a/aiter/ops/gemm_op_a4w4.py +++ b/aiter/ops/gemm_op_a4w4.py @@ -122,7 +122,7 @@ def gemm_a4w4_asm( beta: Optional[float] = 0.0, bpreshuffle: Optional[bool] = True, log2_k_split: Optional[int] = None, -) -> torch.Tensor: ... +) -> None: ... @compile_ops("module_gemm_a4w4_blockscale") @@ -133,7 +133,7 @@ def gemm_a4w4_blockscale( w_scale: torch.Tensor, Out: torch.Tensor, splitK: int = 0, -) -> torch.Tensor: ... +) -> None: ... @compile_ops("module_gemm_a4w4_blockscale_tune", fc_name="gemm_a4w4_blockscale_tune") @@ -145,4 +145,4 @@ def gemm_a4w4_blockscale_tune( Out: torch.Tensor, kernelId: int, splitK: int = 0, -) -> torch.Tensor: ... +) -> None: ... diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 1fe6f2b3c9..2570a047c4 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -13,10 +13,27 @@ ) from ..utility import dtypes from ..jit.utils.chip_info import get_cu_num +from torch.library import Library + +aiter_lib = Library("aiter", "FRAGMENT") from ..ops.gemm_op_common import get_padded_m -@compile_ops("module_gemm_a8w8", fc_name="gemm_a8w8") +def gen_gemm_a8w8_ck_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, + bias: Optional[torch.Tensor] = None, + splitK: int = 0, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8", fc_name="gemm_a8w8", gen_fake=gen_gemm_a8w8_ck_fake_tensors +) def gemm_a8w8_ck( XQ: torch.Tensor, WQ: torch.Tensor, @@ -28,7 +45,21 @@ def gemm_a8w8_ck( ) -> torch.Tensor: ... -@compile_ops("module_gemm_a8w8_bpreshuffle", fc_name="gemm_a8w8_bpreshuffle") +def gen_gemm_a8w8_bpreshuffle_ck_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_bpreshuffle", + fc_name="gemm_a8w8_bpreshuffle", + gen_fake=gen_gemm_a8w8_bpreshuffle_ck_fake_tensors, +) def gemm_a8w8_bpreshuffle_ck( XQ: torch.Tensor, WQ: torch.Tensor, @@ -38,7 +69,28 @@ def gemm_a8w8_bpreshuffle_ck( ) -> torch.Tensor: ... -@compile_ops("module_gemm_a8w8_asm", fc_name="gemm_a8w8_asm") +def gen_gemm_a8w8_asm_fake_tensors( + XQ: Tensor, # A:[M, K] i8 + WQ: Tensor, # B:[N, K] i8 -> shuffle layout(32,16) + x_scale: Tensor, # A_scale:[M, 1] f32 + w_scale: Tensor, # B_scale:[1, N] f32 + Out: Tensor, # Out:[M, N] bf16 + bias: Tensor, # bias:[1, N] f32 + sub_m: Optional[int] = 128, + sub_n: Optional[int] = 128, + pad_a: Optional[int] = 0, + pad_b: Optional[int] = 0, + pad_c: Optional[int] = 0, + splitK: Optional[int] = 0, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_asm", + fc_name="gemm_a8w8_asm", + gen_fake=gen_gemm_a8w8_asm_fake_tensors, +) def gemm_a8w8_asm( XQ: Tensor, # A:[M, K] i8 WQ: Tensor, # B:[N, K] i8 -> shuffle layout(32,16) @@ -55,7 +107,21 @@ def gemm_a8w8_asm( ) -> torch.Tensor: ... -@compile_ops("module_gemm_a8w8_blockscale", fc_name="gemm_a8w8_blockscale") +def gen_gemm_a8w8_blockscale_ck_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, +) -> Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_blockscale", + fc_name="gemm_a8w8_blockscale", + gen_fake=gen_gemm_a8w8_blockscale_ck_fake_tensors, +) def gemm_a8w8_blockscale_ck( XQ: torch.Tensor, WQ: torch.Tensor, @@ -65,14 +131,28 @@ def gemm_a8w8_blockscale_ck( ) -> torch.Tensor: ... -@compile_ops("module_gemm_a8w8_blockscale_asm", fc_name="flatmm_a8w8_blockscale_asm") +def gen_flatmm_a8w8_blockscale_asm_fake_tensors( + XQ: Tensor, + WQ: Tensor, + x_scale: Tensor, + w_scale: Tensor, + out: Tensor, +) -> Tensor: + return out + + +@compile_ops( + "module_gemm_a8w8_blockscale_asm", + fc_name="flatmm_a8w8_blockscale_asm", + gen_fake=gen_flatmm_a8w8_blockscale_asm_fake_tensors, +) def flatmm_a8w8_blockscale_asm( XQ: Tensor, WQ: Tensor, x_scale: Tensor, w_scale: Tensor, out: Tensor, -): ... +) -> Tensor: ... @functools.lru_cache(maxsize=1024) @@ -86,26 +166,50 @@ def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k return splitK -@functools.lru_cache(maxsize=1024) -def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"): - if not hasattr(get_CKGEMM_config, "ckgemm_dict"): - get_CKGEMM_config.ckgemm_dict = {} - if tuned_file not in get_CKGEMM_config.ckgemm_dict: +_CKGEMM_CONFIG_CACHE = None + + +def get_CKGEMM_config_(X: Tensor, tuned_file: str = "a8w8_tuned_gemm.csv") -> None: + global _CKGEMM_CONFIG_CACHE + + if _CKGEMM_CONFIG_CACHE is None: + _CKGEMM_CONFIG_CACHE = {} + if tuned_file not in _CKGEMM_CONFIG_CACHE: ckgemm_dict = pd.read_csv( f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file}" ).drop_duplicates() - get_CKGEMM_config.ckgemm_dict[tuned_file] = ckgemm_dict.set_index( + _CKGEMM_CONFIG_CACHE[tuned_file] = ckgemm_dict.set_index( ["cu_num", "M", "N", "K"] ).to_dict("index") + + return None + + +def get_CKGEMM_config_fake( + X: Tensor, +) -> None: + return None + + +op_name = "aiter::get_CKGEMM_config_" + +schema_str = torch.library.infer_schema(get_CKGEMM_config_, mutates_args=()) +torch.library.define(op_name, schema_str, lib=aiter_lib) +torch.library.impl(op_name, "cuda", get_CKGEMM_config_, lib=aiter_lib) +torch.library.register_fake(op_name, get_CKGEMM_config_fake, lib=aiter_lib) + + +@functools.lru_cache(maxsize=1024) +def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"): + torch.ops.aiter.get_CKGEMM_config_(torch.empty(1, device="cuda"), tuned_file) + cu_num = get_cu_num() padded_M = M config = None for gl in [None, 0, 1]: padded_M = M if gl is None else get_padded_m(M, N, K, gl) - config = get_CKGEMM_config.ckgemm_dict[tuned_file].get( - (cu_num, padded_M, N, K), None - ) + config = _CKGEMM_CONFIG_CACHE[tuned_file].get((cu_num, padded_M, N, K), None) if config is not None: logger.info( f"shape is M:{M}, N:{N}, K:{K}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in CKGEMM , kernel name is {config['kernelName']}!" @@ -287,7 +391,23 @@ def flatmm_a8w8_blockscale_ASM( return flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, Y) -@compile_ops("module_gemm_a8w8_tune", fc_name="gemm_a8w8_tune") +def gen_gemm_a8w8_tune_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, + kernelId: int = 0, + splitK: int = 0, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_tune", + fc_name="gemm_a8w8_tune", + gen_fake=gen_gemm_a8w8_tune_fake_tensors, +) def gemm_a8w8_tune( XQ: torch.Tensor, WQ: torch.Tensor, @@ -299,7 +419,23 @@ def gemm_a8w8_tune( ) -> torch.Tensor: ... -@compile_ops("module_gemm_a8w8_blockscale_tune", fc_name="gemm_a8w8_blockscale_tune") +def gen_gemm_a8w8_blockscale_tune_fake_tensors( + XQ: torch.Tensor, + WQ: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + Out: torch.Tensor, + kernelId: int = 0, + splitK: int = 0, +) -> torch.Tensor: + return Out + + +@compile_ops( + "module_gemm_a8w8_blockscale_tune", + fc_name="gemm_a8w8_blockscale_tune", + gen_fake=gen_gemm_a8w8_blockscale_tune_fake_tensors, +) def gemm_a8w8_blockscale_tune( XQ: torch.Tensor, WQ: torch.Tensor, @@ -318,4 +454,4 @@ def gemm_a8w8_bpreshuffle_tune( Out: torch.Tensor, kernelId: int = 0, splitK: int = 0, -) -> torch.Tensor: ... +) -> None: ... diff --git a/aiter/ops/gradlib.py b/aiter/ops/gradlib.py index 2fcdd6b04e..1c19662bd9 100644 --- a/aiter/ops/gradlib.py +++ b/aiter/ops/gradlib.py @@ -7,20 +7,41 @@ @compile_ops("module_hipbsolgemm") -def hipb_create_extension(): ... +def hipb_create_extension() -> None: ... @compile_ops("module_hipbsolgemm") -def hipb_destroy_extension(): ... +def hipb_destroy_extension() -> None: ... -@compile_ops("module_hipbsolgemm") +def gen_hipb_mm_fake_tensor( + mat1: torch.Tensor, + mat2: torch.Tensor, + solution_index: int, + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + scaleA: Optional[torch.Tensor] = None, + scaleB: Optional[torch.Tensor] = None, + scaleOut: Optional[torch.Tensor] = None, +): + mat1_sizes = mat1.size() + mat2_sizes = mat2.size() + in_dtype = mat1.dtype + out_dtype = out_dtype if out_dtype is not None else in_dtype + result = torch.empty( + (mat1_sizes[0], mat2_sizes[1]), dtype=out_dtype, device=mat1.device + ) + + return result + + +@compile_ops("module_hipbsolgemm", gen_fake=gen_hipb_mm_fake_tensor) def hipb_mm( mat1: torch.Tensor, mat2: torch.Tensor, solution_index: int, bias: Optional[torch.Tensor] = None, - out_dtype: Optional[object] = None, + out_dtype: Optional[torch.dtype] = None, scaleA: Optional[torch.Tensor] = None, scaleB: Optional[torch.Tensor] = None, scaleOut: Optional[torch.Tensor] = None, @@ -32,7 +53,7 @@ def hipb_findallsols( mat1: torch.Tensor, mat2: torch.Tensor, bias: Optional[torch.Tensor] = None, - out_dtype: Optional[object] = None, + out_dtype: Optional[torch.dtype] = None, scaleA: Optional[torch.Tensor] = None, scaleB: Optional[torch.Tensor] = None, scaleC: Optional[torch.Tensor] = None, @@ -40,18 +61,31 @@ def hipb_findallsols( @compile_ops("module_hipbsolgemm") -def getHipblasltKernelName(): ... +def getHipblasltKernelName() -> None: ... @compile_ops("module_rocsolgemm") -def rocb_create_extension(): ... +def rocb_create_extension() -> None: ... @compile_ops("module_rocsolgemm") -def rocb_destroy_extension(): ... +def rocb_destroy_extension() -> None: ... -@compile_ops("module_rocsolgemm") +def gen_rocb_mm_fake_tensor( + arg0: torch.Tensor, arg1: torch.Tensor, arg2: int +) -> torch.Tensor: + mat1_sizes = arg0.size() + mat2_sizes = arg0.size() + in_dtype = arg0.dtype + result = torch.empty( + (mat1_sizes[0], mat2_sizes[1]), dtype=in_dtype, device=arg0.device + ) + + return result + + +@compile_ops("module_rocsolgemm", gen_fake=gen_rocb_mm_fake_tensor) def rocb_mm(arg0: torch.Tensor, arg1: torch.Tensor, arg2: int) -> torch.Tensor: ... diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 44bbf60cd6..470812962e 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -2,14 +2,159 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. from torch import Tensor, Generator -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Any from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, logger from ..jit.utils.chip_info import get_gfx, get_cu_num from ..utility import dtypes import torch -@compile_ops("module_mha_fwd", fc_name="mha_fwd") +def cmdGenFunc_mha_fwd( + q: Tensor, + k: Tensor, + v: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + gen: Optional[Generator] = None, +): + (_, seqlen_q, _, _) = q.shape + # causal=true is the same as causal=false in this case + causal = is_causal + if seqlen_q == 1 and alibi_slopes is None: + causal = False + + md_name = "mha_fwd" + filter = "*" + if q.dtype == dtypes.fp16: + md_name += "_fp16" + filter += "fp16*" + elif q.dtype == dtypes.bf16: + md_name += "_bf16" + filter += "bf16*" + if bias is not None: + md_name += "_bias" + filter += "_bias*" + elif alibi_slopes is not None: + md_name += "_alibi" + filter += "_alibi*" + else: + md_name += "_nbias" + filter += "_nbias*" + if not causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter += "_nmask*" + else: + md_name += "_mask" + filter += "_mask*" + if return_softmax_lse: + md_name += "_lse" + filter += "_lse*" + else: + md_name += "_nlse" + filter += "_nlse*" + if dropout_p == 0: + md_name += "_ndropout" + filter += "_ndropout*" + else: + md_name += "_dropout" + filter += "_dropout*" + + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " + "--receipt 100 --filter {} --output_dir {{}}".format(filter), + f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 2 --output_dir {{}}", + ] + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +def common_mha_fwd_fake_tensors( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[torch.Tensor] = None, +): + batch_size = q.size(0) + seqlen_q = q.size(1) + num_heads = q.size(2) + head_size_v = v.size(3) + seqlen_k = k.size(1) + + if out is not None: + assert out.dtype == q.dtype, "Output must have the same dtype as inputs" + assert out.device == q.device, "Output must be on the same device as inputs" + assert out.stride(-1) == 1, "Output tensor must have contiguous last dimension" + assert out.shape == ( + batch_size, + seqlen_q, + num_heads, + head_size_v, + ), "Output tensor has incorrect shape" + else: + out = torch.empty( + (batch_size, seqlen_q, num_heads, head_size_v), + dtype=q.dtype, + device=q.device, + requires_grad=q.requires_grad, + ) + + if return_softmax_lse: + softmax_lse = torch.empty( + (batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device + ) + else: + softmax_lse = torch.empty((0,), dtype=torch.float32, device=q.device) + + if return_dropout_randval: + assert dropout_p > 0, "return_dropout_randval requires p_dropout > 0" + p = torch.empty( + (batch_size, num_heads, seqlen_q, seqlen_k), + dtype=torch.uint8, + device=q.device, + ) + else: + p = torch.empty((0,), device=q.device) + + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + + return out, softmax_lse, p, rng_state + + +def gen_mha_fwd_fake_tensors( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + gen: Optional[torch.Generator] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return common_mha_fwd_fake_tensors( + q, k, v, dropout_p, return_softmax_lse, return_dropout_randval, out + ) + + +@compile_ops("module_mha_fwd", fc_name="mha_fwd", gen_fake=gen_mha_fwd_fake_tensors) def mha_fwd( q: Tensor, k: Tensor, @@ -25,10 +170,33 @@ def mha_fwd( bias: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> List[Tensor]: ... -@compile_ops("module_fmha_v3_fwd", fc_name="fmha_v3_fwd") +def gen_fmha_v3_fwd_fake_tensors( + q: Tensor, + k: Tensor, + v: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + gen: Optional[Generator] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return common_mha_fwd_fake_tensors( + q, k, v, dropout_p, return_softmax_lse, return_dropout_randval, out + ) + + +@compile_ops( + "module_fmha_v3_fwd", fc_name="fmha_v3_fwd", gen_fake=gen_fmha_v3_fwd_fake_tensors +) def fmha_v3_fwd( q: Tensor, k: Tensor, @@ -44,10 +212,191 @@ def fmha_v3_fwd( bias: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> List[Tensor]: ... + + +def cmdGenFunc_mha_varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + dropout_p: float, + softmax_scale: float, + logits_soft_cap: float, + zero_tensors: bool, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + gen: Optional[torch.Generator] = None, +): + # causal=true is the same as causal=false in this case + causal = is_causal + if max_seqlen_q == 1 and alibi_slopes is None: + causal = False + md_name = "mha_varlen_fwd" + if block_table is None: + filter_fwd = "*" # get_fwd_blobs() + if q.dtype == dtypes.fp16: + md_name += "_fp16" + filter_fwd += "fp16*" + elif q.dtype == dtypes.bf16: + md_name += "_bf16" + filter_fwd += "bf16*" + if 0.0 < logits_soft_cap: + md_name += "_logits" + filter_fwd += "_logits*" + else: + md_name += "_nlogits" + filter_fwd += "_nlogits*" + if bias is not None: + md_name += "_bias" + filter_fwd += "_bias*" + elif alibi_slopes is not None: + md_name += "_alibi" + filter_fwd += "_alibi*" + else: + md_name += "_nbias" + filter_fwd += "_nbias*" + if not causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter_fwd += "_nmask*" + else: + md_name += "_mask" + filter_fwd += "_mask*" + if return_softmax_lse: + md_name += "_lse" + filter_fwd += "_lse*" + else: + md_name += "_nlse" + filter_fwd += "_nlse*" + if dropout_p == 0: + md_name += "_ndropout" + filter_fwd += "_ndropout*" + else: + md_name += "_dropout" + filter_fwd += "_dropout*" + if min_seqlen_q == 0: + md_name += "_nskip" + filter_fwd += "_nskip*" + else: + md_name += "_skip" + filter_fwd += "_skip*" + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " + "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) + ] + blob_gen_cmd.append( + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv " + "--receipt 200 --filter {} --output_dir {{}}".format('" @ "') + ) + blob_gen_cmd.append( + f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}" + ) + else: + filter_fwd_splitkv1 = "*" # get_fwd_splitkv_combine_blobs() + filter_fwd_splitkv2 = "*" # get_fwd_splitkv_blobs() + if q.dtype == dtypes.fp16: + md_name += "_fp16" + filter_fwd_splitkv1 += "fp16*" + filter_fwd_splitkv2 += "fp16*" + elif q.dtype == dtypes.bf16: + md_name += "_bf16" + filter_fwd_splitkv1 += "bf16*" + filter_fwd_splitkv2 += "bf16*" + if 0.0 < logits_soft_cap: + md_name += "_logits" + filter_fwd += "_logits*" + else: + md_name += "_nlogits" + filter_fwd += "_nlogits*" + if bias is not None: + md_name += "_bias" + filter_fwd_splitkv2 += "_bias*" + elif alibi_slopes is not None: + md_name += "_alibi" + filter_fwd_splitkv2 += "_alibi*" + else: + md_name += "_nbias" + filter_fwd_splitkv2 += "_nbias*" + if not is_causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter_fwd_splitkv2 += "_nmask*" + else: + md_name += "_mask" + filter_fwd_splitkv2 += "_mask*" + if return_softmax_lse: + md_name += "_lse" + filter_fwd_splitkv1 += "_lse*" + filter_fwd_splitkv2 += "_lse*" + else: + md_name += "_nlse" + filter_fwd_splitkv1 += "_nlse*" + filter_fwd_splitkv2 += "_nlse*" + md_name += "_pagedkv" + filter_fwd_splitkv2 += "_pagedkv*" + filter_fwd_splitkv = f"{filter_fwd_splitkv1}@{filter_fwd_splitkv2}" + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " + "--receipt 200 --filter {} --output_dir {{}}".format('" "') + ] + blob_gen_cmd.append( + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv " + "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd_splitkv) + ) + blob_gen_cmd.append( + f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}" + ) + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +def gen_mha_varlen_fwd_fake_tensor( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + dropout_p: float, + softmax_scale: float, + logits_soft_cap: float, + zero_tensors: bool, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + gen: Optional[torch.Generator] = None, +) -> List[torch.Tensor]: + return common_mha_fwd_fake_tensors( + q, k, v, dropout_p, return_softmax_lse, return_dropout_randval, out + ) -@compile_ops("module_mha_varlen_fwd", fc_name="mha_varlen_fwd") +@compile_ops( + "module_mha_varlen_fwd", + fc_name="mha_varlen_fwd", + gen_func=cmdGenFunc_mha_varlen_fwd, + gen_fake=gen_mha_fwd_fake_tensors, +) def mha_varlen_fwd( q: torch.Tensor, k: torch.Tensor, @@ -71,19 +420,279 @@ def mha_varlen_fwd( bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, gen: Optional[torch.Generator] = None, -) -> list[torch.Tensor]: ... +) -> List[torch.Tensor]: ... + + +def cmdGenFunc_mha_bwd( + dout: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + softmax_lse: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + deterministic: bool, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, + dbias: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + rng_state: Optional[Tensor] = None, + gen: Optional[Generator] = None, +): + md_name = "mha_bwd" + filter1 = "*" # get_bwd_dot_do_o_blobs() + filter2 = "*" # get_bwd_convert_dq_blobs() + filter3 = "*" # get_bwd_dq_dk_dv_blobs() + if q.dtype == dtypes.fp16: + md_name += "_fp16" + filter1 += "fp16*" + filter2 += "fp16*" + filter3 += "fp16*" + elif q.dtype == dtypes.bf16: + md_name += "_bf16" + filter1 += "bf16*" + filter2 += "bf16*" + filter3 += "bf16*" + if bias is not None: + md_name += "_bias" + filter3 += "_bias*" + elif alibi_slopes is not None: + md_name += "_alibi" + filter3 += "_alibi*" + else: + md_name += "_nbias" + filter3 += "_nbias*" + if dbias is not None: + md_name += "_dbias" + filter3 += "_dbias*" + else: + md_name += "_ndbias" + filter3 += "_ndbias*" + if not is_causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter3 += "_nmask*" + else: + md_name += "_mask" + filter3 += "_mask*" + if dropout_p == 0: + md_name += "_ndropout" + filter3 += "_ndropout*" + else: + md_name += "_dropout" + filter3 += "_dropout*" + if deterministic: + md_name += "_deterministic" + filter2 += "_deterministic*" + filter3 += "_deterministic*" + else: + md_name += "_ndeterministic" + filter2 += "_ndeterministic*" + filter3 += "_ndeterministic*" + + filter = f"{filter1}@{filter2}@{filter3}" + + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd " + "--receipt 300 --filter {} --output_dir {{}}".format(filter), + f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}", + ] + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +def common_mha_bwd_fake_tensors( + q: Tensor, + k: Tensor, + v: Tensor, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, +): + batch_size = q.size(0) + seqlen_q = q.size(1) + num_heads = q.size(2) + head_size_q = q.size(3) + head_size_v = v.size(3) + seqlen_k = k.size(1) + num_heads_k = k.size(2) + + if dq is None: + dq = torch.empty_like(q) # (batch_size, seqlen_q, num_heads, head_size_q) + else: + assert dq.dtype == q.dtype, "dq must have the same dtype as q" + assert dq.device == q.device, "dq must be on the same device as q" + assert dq.stride(-1) == 1, "dq must have contiguous last dimension" + assert dq.shape == ( + batch_size, + seqlen_q, + num_heads, + head_size_q, + ), "dq has incorrect shape" + + if dk is None: + dk = torch.empty_like(k) # (batch_size, seqlen_k, num_heads_k, head_size_q) + else: + assert dk.dtype == q.dtype, "dk must have the same dtype as q" + assert dk.device == q.device, "dk must be on the same device as q" + assert dk.stride(-1) == 1, "dk must have contiguous last dimension" + assert dk.shape == ( + batch_size, + seqlen_k, + num_heads_k, + head_size_q, + ), "dk has incorrect shape" + + if dv is None: + dv = torch.empty_like(v) # (batch_size, seqlen_k, num_heads_k, head_size_v) + else: + assert dv.dtype == q.dtype, "dv must have the same dtype as q" + assert dv.device == q.device, "dv must be on the same device as q" + assert dv.stride(-1) == 1, "dv must have contiguous last dimension" + assert dv.shape == ( + batch_size, + seqlen_k, + num_heads_k, + head_size_v, + ), "dv has incorrect shape" + + softmax_d = torch.empty( + (batch_size, num_heads, seqlen_q), # {batch_size, num_heads, seqlen_q} + dtype=torch.float32, + device=q.device, + ) + + return [dq, dk, dv, softmax_d] + + +def gen_mha_bwd_fake_tensors( + dout: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + softmax_lse: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + deterministic: bool, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, + dbias: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + rng_state: Optional[Tensor] = None, + gen: Optional[Generator] = None, +) -> List[Tensor]: + return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) + + +@compile_ops( + "module_mha_bwd", + fc_name="mha_bwd", + gen_func=cmdGenFunc_mha_bwd, + gen_fake=gen_mha_bwd_fake_tensors, +) +def mha_bwd( + dout: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + softmax_lse: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + deterministic: bool, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, + dbias: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + rng_state: Optional[Tensor] = None, + gen: Optional[Generator] = None, +) -> List[Tensor]: ... + + +def gen_fmha_v3_bwd_fake_tensors( + dout: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + softmax_lse: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + deterministic: bool, + is_v3_atomic_fp32: bool, + how_v3_bf16_cvt: int, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + rng_state: Optional[Tensor] = None, + gen: Optional[Generator] = None, +) -> List[Tensor]: + return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) + + +@compile_ops( + "module_fmha_v3_bwd", fc_name="fmha_v3_bwd", gen_fake=gen_fmha_v3_bwd_fake_tensors +) +def fmha_v3_bwd( + dout: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + out: Tensor, + softmax_lse: Tensor, + dropout_p: float, + softmax_scale: float, + is_causal: bool, + window_size_left: int, + window_size_right: int, + deterministic: bool, + is_v3_atomic_fp32: bool, + how_v3_bf16_cvt: int, + dq: Optional[Tensor] = None, + dk: Optional[Tensor] = None, + dv: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + rng_state: Optional[Tensor] = None, + gen: Optional[Generator] = None, +) -> List[Tensor]: ... -@compile_ops("module_mha_bwd", fc_name="mha_bwd") -def mha_bwd( +def cmdGenFunc_mha_varlen_bwd( dout: Tensor, q: Tensor, k: Tensor, v: Tensor, out: Tensor, softmax_lse: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + max_seqlen_q: int, + max_seqlen_k: int, dropout_p: float, softmax_scale: float, + zero_tensors: bool, is_causal: bool, window_size_left: int, window_size_right: int, @@ -91,41 +700,174 @@ def mha_bwd( dq: Optional[Tensor] = None, dk: Optional[Tensor] = None, dv: Optional[Tensor] = None, - dbias: Optional[Tensor] = None, - bias: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> dict[str, Any]: + md_name = "mha_varlen_bwd" + filter1 = "*" # get_bwd_dot_do_o_blobs() + filter2 = "*" # get_bwd_convert_dq_blobs() + filter3 = "*" # get_bwd_dq_dk_dv_blobs() + if q.dtype == dtypes.fp16: + md_name += "_fp16" + filter1 += "fp16*" + filter2 += "fp16*" + filter3 += "fp16*" + elif q.dtype == dtypes.bf16: + md_name += "_bf16" + filter1 += "bf16*" + filter2 += "bf16*" + filter3 += "bf16*" + if alibi_slopes is None: + md_name += "_nbias" + filter3 += "_nbias*" + else: + md_name += "_alibi" + filter3 += "_alibi*" + if not is_causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter3 += "_nmask*" + else: + md_name += "_mask" + filter3 += "_mask*" + if dropout_p == 0: + md_name += "_ndropout" + filter3 += "_ndropout*" + else: + md_name += "_dropout" + filter3 += "_dropout*" + if deterministic: + md_name += "_deterministic" + filter2 += "_deterministic*" + filter3 += "_deterministic*" + else: + md_name += "_ndeterministic" + filter2 += "_ndeterministic*" + filter3 += "_ndeterministic*" + filter = f"{filter1}@{filter2}@{filter3}" + + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd " + "--receipt 400 --filter {} --output_dir {{}}".format(filter), + f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}", + ] + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } -@compile_ops("module_fmha_v3_bwd", fc_name="fmha_v3_bwd") -def fmha_v3_bwd( +def cmdGenFunc_mha_batch_prefill( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + kv_indptr: Tensor, + kv_page_indices: Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + logits_soft_cap: float, + zero_tensors: bool, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, + gen: Optional[Generator] = None, +): + # causal=true is the same as causal=false in this case + causal = is_causal + if max_seqlen_q == 1 and alibi_slopes is None: + causal = False + md_name = "mha_batch_prefill" + filter_fwd = "*" # get_fwd_blobs() + if q.dtype == torch.float16: + md_name += "_fp16" + filter_fwd += "fp16*" + elif q.dtype == torch.bfloat16: + md_name += "_bf16" + filter_fwd += "bf16*" + if 0.0 < logits_soft_cap: + md_name += "_logits" + filter_fwd += "_logits*" + else: + md_name += "_nlogits" + filter_fwd += "_nlogits*" + if alibi_slopes is None: + md_name += "_nbias" + filter_fwd += "_nbias*" + else: + md_name += "_alibi" + filter_fwd += "_alibi*" + if not causal and window_size_left == -1 and window_size_right == -1: + md_name += "_nmask" + filter_fwd += "_nmask*" + else: + md_name += "_mask" + filter_fwd += "_mask*" + if return_softmax_lse: + md_name += "_lse" + filter_fwd += "_lse*" + else: + md_name += "_nlse" + filter_fwd += "_nlse*" + if dropout_p == 0: + md_name += "_ndropout" + filter_fwd += "_ndropout*" + else: + md_name += "_dropout" + filter_fwd += "_dropout*" + blob_gen_cmd = [ + f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill " + "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) + ] + blob_gen_cmd.append( + f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 4 --output_dir {{}}" + ) + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +@compile_ops( + "module_mha_varlen_bwd", + fc_name="mha_varlen_bwd", + gen_func=cmdGenFunc_mha_varlen_bwd, + gen_fake=gen_mha_bwd_fake_tensors, +) +def mha_varlen_bwd( dout: Tensor, q: Tensor, k: Tensor, v: Tensor, out: Tensor, softmax_lse: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + max_seqlen_q: int, + max_seqlen_k: int, dropout_p: float, softmax_scale: float, + zero_tensors: bool, is_causal: bool, window_size_left: int, window_size_right: int, deterministic: bool, - is_v3_atomic_fp32: bool, - how_v3_bf16_cvt: int, dq: Optional[Tensor] = None, dk: Optional[Tensor] = None, dv: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> List[Tensor]: ... -@compile_ops("module_mha_varlen_bwd", fc_name="mha_varlen_bwd") -def mha_varlen_bwd( +def gen_fmha_v3_varlen_bwd_fake_tensor( dout: Tensor, q: Tensor, k: Tensor, @@ -143,17 +885,23 @@ def mha_varlen_bwd( window_size_left: int, window_size_right: int, deterministic: bool, + is_v3_atomic_fp32: bool, + how_v3_bf16_cvt: int, dq: Optional[Tensor] = None, dk: Optional[Tensor] = None, dv: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - custom_build_args: Optional[dict] = None, -): ... +): + return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) -@compile_ops("module_fmha_v3_varlen_bwd", fc_name="fmha_v3_varlen_bwd") +@compile_ops( + "module_fmha_v3_varlen_bwd", + fc_name="fmha_v3_varlen_bwd", + gen_fake=gen_fmha_v3_varlen_bwd_fake_tensor, +) def fmha_v3_varlen_bwd( dout: Tensor, q: Tensor, @@ -180,7 +928,7 @@ def fmha_v3_varlen_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> None: ... def maybe_contiguous(x): @@ -201,52 +949,6 @@ def _flash_attn_forward( return_lse: bool, return_softmax: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - (_, seqlen_q, _, _) = q.shape - # causal=true is the same as causal=false in this case - if seqlen_q == 1 and alibi_slopes is None: - causal = False - - md_name = "mha_fwd" - filter = "*" - if q.dtype == dtypes.fp16: - md_name += "_fp16" - filter += "fp16*" - elif q.dtype == dtypes.bf16: - md_name += "_bf16" - filter += "bf16*" - if bias is not None: - md_name += "_bias" - filter += "_bias*" - elif alibi_slopes is not None: - md_name += "_alibi" - filter += "_alibi*" - else: - md_name += "_nbias" - filter += "_nbias*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter += "_nmask*" - else: - md_name += "_mask" - filter += "_mask*" - if return_lse: - md_name += "_lse" - filter += "_lse*" - else: - md_name += "_nlse" - filter += "_nlse*" - if dropout_p == 0: - md_name += "_ndropout" - filter += "_ndropout*" - else: - md_name += "_dropout" - filter += "_dropout*" - - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " - "--receipt 100 --filter {} --output_dir {{}}".format(filter), - f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 2 --output_dir {{}}", - ] (_, seqlen_q, nhead_q, hdim_q) = q.shape (_, seqlen_k, nhead_k, hdim_v) = v.shape @@ -308,7 +1010,7 @@ def can_impl_fmha_v3_fwd(): bias, alibi_slopes, None, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return out, softmax_lse, S_dmask, rng_state @@ -327,76 +1029,19 @@ def _flash_attn_backward( dropout_p: float, softmax_scale: float, causal: bool, - window_size_left: int, - window_size_right: int, - bias: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, - is_v3_atomic_fp32: Optional[bool] = True, - how_v3_bf16_cvt: Optional[int] = 1, -) -> torch.Tensor: - if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0: - logger.warning( - "Rounding mode RTNA & RTZ are deprecated in gfx950, ignore option `how_v3_bf16_cvt`" - ) - md_name = "mha_bwd" - filter1 = "*" # get_bwd_dot_do_o_blobs() - filter2 = "*" # get_bwd_convert_dq_blobs() - filter3 = "*" # get_bwd_dq_dk_dv_blobs() - if q.dtype == dtypes.fp16: - md_name += "_fp16" - filter1 += "fp16*" - filter2 += "fp16*" - filter3 += "fp16*" - elif q.dtype == dtypes.bf16: - md_name += "_bf16" - filter1 += "bf16*" - filter2 += "bf16*" - filter3 += "bf16*" - if bias is not None: - md_name += "_bias" - filter3 += "_bias*" - elif alibi_slopes is not None: - md_name += "_alibi" - filter3 += "_alibi*" - else: - md_name += "_nbias" - filter3 += "_nbias*" - if dbias is not None: - md_name += "_dbias" - filter3 += "_dbias*" - else: - md_name += "_ndbias" - filter3 += "_ndbias*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter3 += "_nmask*" - else: - md_name += "_mask" - filter3 += "_mask*" - if dropout_p == 0: - md_name += "_ndropout" - filter3 += "_ndropout*" - else: - md_name += "_dropout" - filter3 += "_dropout*" - if deterministic: - md_name += "_deterministic" - filter2 += "_deterministic*" - filter3 += "_deterministic*" - else: - md_name += "_ndeterministic" - filter2 += "_ndeterministic*" - filter3 += "_ndeterministic*" - - filter = f"{filter1}@{filter2}@{filter3}" - - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd " - "--receipt 300 --filter {} --output_dir {{}}".format(filter), - f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}", - ] + window_size_left: int, + window_size_right: int, + bias: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + is_v3_atomic_fp32: Optional[bool] = True, + how_v3_bf16_cvt: Optional[int] = 1, +) -> torch.Tensor: + if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0: + logger.warning( + "Rounding mode RTNA & RTZ are deprecated in gfx950, ignore option `how_v3_bf16_cvt`" + ) (_, seqlen_q, nhead_q, hdim_q) = q.shape (_, seqlen_k, nhead_k, hdim_v) = v.shape @@ -624,7 +1269,7 @@ def can_impl_fmha_v3_bwd(): alibi_slopes, rng_state, None, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -838,123 +1483,6 @@ def _flash_attn_varlen_forward( out: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # causal=true is the same as causal=false in this case - if max_seqlen_q == 1 and alibi_slopes is None: - causal = False - - md_name = "mha_varlen_fwd" - if block_table is None: - filter_fwd = "*" # get_fwd_blobs() - if q.dtype == dtypes.fp16: - md_name += "_fp16" - filter_fwd += "fp16*" - elif q.dtype == dtypes.bf16: - md_name += "_bf16" - filter_fwd += "bf16*" - if 0.0 < logits_soft_cap: - md_name += "_logits" - filter_fwd += "_logits*" - else: - md_name += "_nlogits" - filter_fwd += "_nlogits*" - if bias is not None: - md_name += "_bias" - filter_fwd += "_bias*" - elif alibi_slopes is not None: - md_name += "_alibi" - filter_fwd += "_alibi*" - else: - md_name += "_nbias" - filter_fwd += "_nbias*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter_fwd += "_nmask*" - else: - md_name += "_mask" - filter_fwd += "_mask*" - if return_lse: - md_name += "_lse" - filter_fwd += "_lse*" - else: - md_name += "_nlse" - filter_fwd += "_nlse*" - if dropout_p == 0: - md_name += "_ndropout" - filter_fwd += "_ndropout*" - else: - md_name += "_dropout" - filter_fwd += "_dropout*" - if min_seqlen_q == 0: - md_name += "_nskip" - filter_fwd += "_nskip*" - else: - md_name += "_skip" - filter_fwd += "_skip*" - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " - "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) - ] - blob_gen_cmd.append( - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv " - "--receipt 200 --filter {} --output_dir {{}}".format('" @ "') - ) - blob_gen_cmd.append( - f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}" - ) - else: - filter_fwd_splitkv1 = "*" # get_fwd_splitkv_combine_blobs() - filter_fwd_splitkv2 = "*" # get_fwd_splitkv_blobs() - if q.dtype == dtypes.fp16: - md_name += "_fp16" - filter_fwd_splitkv1 += "fp16*" - filter_fwd_splitkv2 += "fp16*" - elif q.dtype == dtypes.bf16: - md_name += "_bf16" - filter_fwd_splitkv1 += "bf16*" - filter_fwd_splitkv2 += "bf16*" - if 0.0 < logits_soft_cap: - md_name += "_logits" - filter_fwd += "_logits*" - else: - md_name += "_nlogits" - filter_fwd += "_nlogits*" - if bias is not None: - md_name += "_bias" - filter_fwd_splitkv2 += "_bias*" - elif alibi_slopes is not None: - md_name += "_alibi" - filter_fwd_splitkv2 += "_alibi*" - else: - md_name += "_nbias" - filter_fwd_splitkv2 += "_nbias*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter_fwd_splitkv2 += "_nmask*" - else: - md_name += "_mask" - filter_fwd_splitkv2 += "_mask*" - if return_lse: - md_name += "_lse" - filter_fwd_splitkv1 += "_lse*" - filter_fwd_splitkv2 += "_lse*" - else: - md_name += "_nlse" - filter_fwd_splitkv1 += "_nlse*" - filter_fwd_splitkv2 += "_nlse*" - md_name += "_pagedkv" - filter_fwd_splitkv2 += "_pagedkv*" - filter_fwd_splitkv = f"{filter_fwd_splitkv1}@{filter_fwd_splitkv2}" - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " - "--receipt 200 --filter {} --output_dir {{}}".format('" "') - ] - blob_gen_cmd.append( - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv " - "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd_splitkv) - ) - blob_gen_cmd.append( - f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}" - ) q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd( @@ -980,7 +1508,7 @@ def _flash_attn_varlen_forward( bias, alibi_slopes, None, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return out, softmax_lse, S_dmask, rng_state @@ -1011,53 +1539,6 @@ def _flash_attn_varlen_backward( how_v3_bf16_cvt: Optional[int] = 1, zero_tensors: bool = False, ) -> torch.Tensor: - md_name = "mha_varlen_bwd" - filter1 = "*" # get_bwd_dot_do_o_blobs() - filter2 = "*" # get_bwd_convert_dq_blobs() - filter3 = "*" # get_bwd_dq_dk_dv_blobs() - if q.dtype == dtypes.fp16: - md_name += "_fp16" - filter1 += "fp16*" - filter2 += "fp16*" - filter3 += "fp16*" - elif q.dtype == dtypes.bf16: - md_name += "_bf16" - filter1 += "bf16*" - filter2 += "bf16*" - filter3 += "bf16*" - if alibi_slopes is None: - md_name += "_nbias" - filter3 += "_nbias*" - else: - md_name += "_alibi" - filter3 += "_alibi*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter3 += "_nmask*" - else: - md_name += "_mask" - filter3 += "_mask*" - if dropout_p == 0: - md_name += "_ndropout" - filter3 += "_ndropout*" - else: - md_name += "_dropout" - filter3 += "_dropout*" - if deterministic: - md_name += "_deterministic" - filter2 += "_deterministic*" - filter3 += "_deterministic*" - else: - md_name += "_ndeterministic" - filter2 += "_ndeterministic*" - filter3 += "_ndeterministic*" - filter = f"{filter1}@{filter2}@{filter3}" - - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd " - "--receipt 400 --filter {} --output_dir {{}}".format(filter), - f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}", - ] (_, nhead_q, hdim_q) = q.shape @@ -1198,7 +1679,7 @@ def can_impl_fmha_v3_bwd(): alibi_slopes, rng_state, None, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -1467,7 +1948,71 @@ def flash_attn_varlen_func( ) -@compile_ops("module_mha_batch_prefill", fc_name="mha_batch_prefill") +def mha_batch_prefill_fake_tensors( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + kv_indptr: torch.Tensor, + kv_page_indices: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + logits_soft_cap: float, + zero_tensors: bool, + is_causal: bool, + window_size_left: int, + window_size_right: int, + return_softmax_lse: bool, + return_dropout_randval: bool, + out: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + gen: Optional[Generator] = None, +) -> List[Tensor]: + # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + num_heads = q.size(1) # num_heads = q.sizes()[1] + head_size_v = v.size(2) # head_size_v = v.size(2) + total_q = q.size(0) # total_q = q.size(0) + + if out is None: + out = torch.empty( + (total_q, num_heads, head_size_v), # {total_q, num_heads, head_size_v} + dtype=q.dtype, + device=q.device, + requires_grad=q.requires_grad, + ) + + if return_softmax_lse: + softmax_lse = torch.empty( + (num_heads, total_q), # {num_heads, total_q} + dtype=torch.float32, + device=q.device, + ) + else: + softmax_lse = torch.empty((0,), dtype=torch.float32, device=q.device) + + if return_dropout_randval: + assert dropout_p > 0, "return_dropout_randval requires p_dropout > 0" + p = torch.empty( + (num_heads, total_q, max_seqlen_k), # {num_heads, total_q, max_seqlen_k} + dtype=torch.uint8, + device=q.device, + ) + else: + p = torch.empty((0,), device=q.device) + + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + + return (out, softmax_lse, p, rng_state) + + +@compile_ops( + "module_mha_batch_prefill", + fc_name="mha_batch_prefill", + gen_func=cmdGenFunc_mha_batch_prefill, + gen_fake=mha_batch_prefill_fake_tensors, +) def mha_batch_prefill( q: Tensor, k: Tensor, @@ -1489,7 +2034,7 @@ def mha_batch_prefill( out: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, gen: Optional[Generator] = None, -): ... +) -> List[Tensor]: ... def _mha_batch_prefill( @@ -1513,55 +2058,6 @@ def _mha_batch_prefill( zero_tensors: bool = False, out: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # causal=true is the same as causal=false in this case - if max_seqlen_q == 1 and alibi_slopes is None: - causal = False - - md_name = "mha_batch_prefill" - filter_fwd = "*" # get_fwd_blobs() - if q.dtype == torch.float16: - md_name += "_fp16" - filter_fwd += "fp16*" - elif q.dtype == torch.bfloat16: - md_name += "_bf16" - filter_fwd += "bf16*" - if 0.0 < logits_soft_cap: - md_name += "_logits" - filter_fwd += "_logits*" - else: - md_name += "_nlogits" - filter_fwd += "_nlogits*" - if alibi_slopes is None: - md_name += "_nbias" - filter_fwd += "_nbias*" - else: - md_name += "_alibi" - filter_fwd += "_alibi*" - if not causal and window_size_left == -1 and window_size_right == -1: - md_name += "_nmask" - filter_fwd += "_nmask*" - else: - md_name += "_mask" - filter_fwd += "_mask*" - if return_lse: - md_name += "_lse" - filter_fwd += "_lse*" - else: - md_name += "_nlse" - filter_fwd += "_nlse*" - if dropout_p == 0: - md_name += "_ndropout" - filter_fwd += "_ndropout*" - else: - md_name += "_dropout" - filter_fwd += "_dropout*" - blob_gen_cmd = [ - f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill " - "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) - ] - blob_gen_cmd.append( - f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 4 --output_dir {{}}" - ) q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = mha_batch_prefill( @@ -1585,7 +2081,7 @@ def _mha_batch_prefill( out, alibi_slopes, None, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return out, softmax_lse, S_dmask, rng_state diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index 1bba436fd7..391676760d 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -23,7 +23,7 @@ def topk_softmax( @compile_ops("module_moe_asm") -def moe_sum(input: Tensor, output: Tensor): ... +def moe_sum(input: Tensor, output: Tensor) -> None: ... @compile_ops("module_moe_asm") @@ -35,7 +35,7 @@ def moe_align_block_size( experts_ids: Tensor, token_nums: Tensor, num_tokens_post_pad: Tensor, -): ... +) -> None: ... @compile_ops("module_moe_asm") @@ -49,7 +49,7 @@ def fmoe( sorted_expert_ids: Tensor, num_valid_ids: Tensor, topk: int, -): ... +) -> None: ... @compile_ops("module_moe_asm") @@ -67,8 +67,8 @@ def fmoe_int8_g1u0( fc1_scale: Tensor, fc2_scale: Tensor, fc2_smooth_scale: Tensor, - activation: Optional[Enum] = ActivationType.Silu, -): ... + activation: Optional[Enum] = ActivationType.Silu.value, +) -> None: ... @compile_ops("module_moe_asm") @@ -86,8 +86,8 @@ def fmoe_g1u1( fc1_scale: Tensor, fc2_scale: Tensor, fc2_smooth_scale: Optional[Tensor] = None, - activation: Optional[Enum] = ActivationType.Silu, -): ... + activation: Optional[Enum] = ActivationType.Silu.value, +) -> None: ... @compile_ops("module_moe_asm") @@ -105,8 +105,8 @@ def fmoe_g1u1_tkw1( fc1_scale: Tensor, fc2_scale: Tensor, fc2_smooth_scale: Optional[Tensor] = None, - activation: Optional[Enum] = ActivationType.Silu, -): ... + activation: Optional[Enum] = ActivationType.Silu.value, +) -> None: ... @compile_ops("module_moe_asm") @@ -124,7 +124,7 @@ def fmoe_int8_g1u0_a16( fc2_scale: Tensor, fc1_smooth_scale: Tensor, fc2_smooth_scale: Tensor, -): ... +) -> None: ... @compile_ops("module_moe_asm") @@ -142,8 +142,8 @@ def fmoe_g1u1_a16( fc2_scale: Tensor, fc1_smooth_scale: Tensor, fc2_smooth_scale: Tensor, - activation: ActivationType = ActivationType.Silu, -): ... + activation: Optional[Enum] = ActivationType.Silu.value, +) -> None: ... @compile_ops("module_moe_asm") @@ -163,8 +163,8 @@ def fmoe_fp8_blockscale_g1u1( fc_scale_blkn: int = 128, fc_scale_blkk: int = 128, fc2_smooth_scale: Optional[Tensor] = None, - activation: ActivationType = ActivationType.Silu, -): ... + activation: Optional[Enum] = ActivationType.Silu.value, +) -> None: ... @compile_ops("module_moe_asm") @@ -180,15 +180,81 @@ def moe_stage1_g1u1( kernelName: str, block_m: int, ksplit: int = 0, - activation: ActivationType = ActivationType.Silu, - quant_type: QuantType = QuantType.No, + activation: Optional[Enum] = ActivationType.Silu.value, + quant_type: Optional[Enum] = QuantType.No.value, a1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, sorted_weights: Optional[torch.Tensor] = None, ) -> None: ... -@compile_ops("module_moe_ck2stages") +def cmdGenFunc_ck_moe_stage( + hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + sorted_token_ids: Tensor, + sorted_expert_ids: Tensor, + num_valid_ids: Tensor, + out: Tensor, + topk: int, + kernelName: str = "", + w1_scale: Optional[Tensor] = None, + a1_scale: Optional[Tensor] = None, + block_m: Optional[int] = 32, + sorted_weights: Optional[Tensor] = None, + quant_type: int = 0, + activation: int = 0, +): + + mul_routed_weight_stage = 2 if sorted_weights is None else 1 + md_name, blob_gen_cmd = get_moe_stage_module( + hidden_states.dtype, + w1.dtype, + out.dtype, + activation, + quant_type, + mul_routed_weight_stage, + ) + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +def cmdGenFunc_ck_moe_stage2( + hidden_states: Tensor, + w1: Tensor, + w2: Tensor, + sorted_token_ids: Tensor, + sorted_expert_ids: Tensor, + num_valid_ids: Tensor, + out: Tensor, + topk: int, + kernelName: str = "", + w1_scale: Optional[Tensor] = None, + a1_scale: Optional[Tensor] = None, + block_m: Optional[int] = 32, + sorted_weights: Optional[Tensor] = None, + quant_type: int = 0, + activation: int = 0, +): + + mul_routed_weight_stage = 1 if sorted_weights is None else 2 + md_name, blob_gen_cmd = get_moe_stage_module( + hidden_states.dtype, + w1.dtype, + out.dtype, + activation, + quant_type, + mul_routed_weight_stage, + ) + return { + "md_name": md_name, + "blob_gen_cmd": blob_gen_cmd, + } + + +@compile_ops("module_moe_ck2stages", gen_func=cmdGenFunc_ck_moe_stage) def ck_moe_stage1( hidden_states: Tensor, w1: Tensor, @@ -203,10 +269,12 @@ def ck_moe_stage1( a1_scale: Optional[Tensor] = None, block_m: Optional[int] = 32, sorted_weights: Optional[Tensor] = None, -): ... + quant_type: int = 0, + activation: int = 0, +) -> None: ... -@compile_ops("module_moe_ck2stages") +@compile_ops("module_moe_ck2stages", gen_func=cmdGenFunc_ck_moe_stage2) def ck_moe_stage2( inter_states: Tensor, w1: Tensor, @@ -221,7 +289,9 @@ def ck_moe_stage2( a2_scale: Optional[Tensor] = None, block_m: Optional[int] = 32, sorted_weights: Optional[Tensor] = None, -): ... + quant_type: int = 0, + activation: int = 0, +) -> None: ... @compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1") @@ -316,6 +386,11 @@ def get_moe_stage_module( quant_type, mul_routed_weight_stage, ): + if isinstance(activation, int): + activation = ActivationType(activation) + if isinstance(quant_type, int): + quant_type = QuantType(quant_type) + Adtype = dtype2str_dict[input_dtype] Bdtype = dtype2str_dict[weight_dtype] Cdtype = dtype2str_dict[output_dtype] @@ -362,16 +437,6 @@ def ck_moe_stage1_fwd( quant_type: QuantType = QuantType.No, activation: ActivationType = ActivationType.Silu, ): - mul_routed_weight_stage = 2 if sorted_weights is None else 1 - md_name, blob_gen_cmd = get_moe_stage_module( - hidden_states.dtype, - w1.dtype, - out.dtype, - activation, - quant_type, - mul_routed_weight_stage, - ) - ck_moe_stage1( hidden_states, w1, @@ -386,7 +451,8 @@ def ck_moe_stage1_fwd( a1_scale, block_m, sorted_weights, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + quant_type.value, + activation.value, ) return out @@ -408,16 +474,6 @@ def ck_moe_stage2_fwd( quant_type: QuantType = QuantType.No, activation: ActivationType = ActivationType.Silu, ): - mul_routed_weight_stage = 1 if sorted_weights is None else 2 - - md_name, blob_gen_cmd = get_moe_stage_module( - inter_states.dtype, - w1.dtype, - out.dtype, - activation, - quant_type, - mul_routed_weight_stage, - ) ck_moe_stage2( inter_states, @@ -433,6 +489,7 @@ def ck_moe_stage2_fwd( a2_scale, block_m, sorted_weights, - custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, + quant_type.value, + activation.value, ) return out diff --git a/aiter/ops/moe_sorting.py b/aiter/ops/moe_sorting.py index 8ef3d40153..466a8fd94e 100644 --- a/aiter/ops/moe_sorting.py +++ b/aiter/ops/moe_sorting.py @@ -22,4 +22,4 @@ def moe_sorting_fwd( local_expert_mask: Optional[torch.Tensor] = None, num_local_tokens: Optional[torch.Tensor] = None, dispatch_policy: int = 0, -): ... +) -> None: ... diff --git a/aiter/ops/norm.py b/aiter/ops/norm.py index 713bfd227d..2a47f35146 100644 --- a/aiter/ops/norm.py +++ b/aiter/ops/norm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +import torch from torch import Tensor from typing import Optional from ..jit.core import compile_ops @@ -8,24 +9,43 @@ MD_NAME = "module_norm" -@compile_ops("module_norm", fc_name="layernorm2d_fwd") -def layer_norm( +def gen_layer_norm_fake_tensors( input: Tensor, # normalized_shape: List[int], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5, x_bias: Optional[Tensor] = None, +) -> Tensor: + return torch.empty_like( + input, + dtype=input.dtype, + device=input.device, + ) + + +@compile_ops( + "module_norm", fc_name="layernorm2d_fwd", gen_fake=gen_layer_norm_fake_tensors +) +def layer_norm( + input: Tensor, + # normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + epsilon: float = 1e-5, + x_bias: Optional[Tensor] = None, ) -> Tensor: ... -@compile_ops("module_norm", fc_name="layernorm2d_fwd") +@compile_ops( + "module_norm", fc_name="layernorm2d_fwd", gen_fake=gen_layer_norm_fake_tensors +) def layernorm2d_fwd( input: Tensor, # normalized_shape: List[int], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, - eps: float = 1e-5, + epsilon: float = 1e-5, x_bias: Optional[Tensor] = None, ) -> Tensor: ... @@ -40,7 +60,7 @@ def layernorm2d_fwd_with_add( bias: Tensor, epsilon: float, x_bias: Optional[Tensor] = None, -): ... +) -> None: ... @compile_ops("module_norm") @@ -53,7 +73,7 @@ def layernorm2d_fwd_with_smoothquant( bias: Tensor, epsilon: float, x_bias: Optional[Tensor] = None, -): ... +) -> None: ... @compile_ops("module_norm") @@ -68,7 +88,7 @@ def layernorm2d_fwd_with_add_smoothquant( bias: Tensor, epsilon: float, x_bias: Optional[Tensor] = None, -): ... +) -> None: ... # @compile_ops("module_norm") @@ -103,7 +123,9 @@ def layernorm2d_with_add_asm( bias: Tensor, epsilon: float, x_bias: Optional[Tensor] = None, -): ... +) -> None: ... + + @compile_ops("module_norm") def layernorm2d_with_add_smoothquant_asm( out: Tensor, @@ -116,4 +138,4 @@ def layernorm2d_with_add_smoothquant_asm( bias: Tensor, epsilon: float, x_bias: Optional[Tensor] = None, -): ... +) -> None: ... diff --git a/aiter/ops/pos_encoding.py b/aiter/ops/pos_encoding.py index a049aaabcc..8a0c5f3a54 100644 --- a/aiter/ops/pos_encoding.py +++ b/aiter/ops/pos_encoding.py @@ -17,7 +17,7 @@ def rotary_embedding_fwd( sin_cache: Tensor, is_neox: bool, is_nope_first: bool, -): ... +) -> None: ... @compile_ops("module_pos_encoding") @@ -32,4 +32,4 @@ def batched_rotary_embedding( is_nope_first: bool, rot_dim: int, cos_sin_cache_offsets: Tensor, -): ... +) -> None: ... diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 7049a19047..51e89edf5d 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -13,13 +13,15 @@ @compile_ops("module_smoothquant") -def smoothquant_fwd(input: Tensor, out: Tensor, x_scale: Tensor, y_scale: Tensor): ... +def smoothquant_fwd( + input: Tensor, out: Tensor, x_scale: Tensor, y_scale: Tensor +) -> None: ... @compile_ops("module_smoothquant") def moe_smoothquant_fwd( input: Tensor, out: Tensor, x_scale: Tensor, topk_ids: Tensor, y_scale: Tensor -): ... +) -> None: ... # following are pure torch implement @@ -357,11 +359,11 @@ def get_torch_act(aType): @compile_ops("module_quant") -def static_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor): ... +def static_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor) -> None: ... @compile_ops("module_quant") -def dynamic_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor): ... +def dynamic_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor) -> None: ... @compile_ops("module_quant") @@ -382,10 +384,10 @@ def dynamic_per_group_scaled_quant_fp4( input: Tensor, scales: Tensor, group_size: Optional[int] = 32, - shuffle_scale=True, + shuffle_scale: bool = True, num_rows: Optional[Tensor] = None, num_rows_factor: int = 1, -): +) -> None: """ Only support group_size in [32, 64, 128] """ @@ -397,4 +399,4 @@ def partial_transpose( out: Tensor, input: Tensor, num_rows: Tensor, -): ... +) -> None: ... diff --git a/aiter/ops/rmsnorm.py b/aiter/ops/rmsnorm.py index 988d8c2be6..63064d8e6f 100644 --- a/aiter/ops/rmsnorm.py +++ b/aiter/ops/rmsnorm.py @@ -14,7 +14,7 @@ def rms_norm_cu( input: Tensor, weight: Tensor, epsilon: float, -): +) -> None: """ Cuda version of rmsnorm """ @@ -27,33 +27,44 @@ def fused_add_rms_norm_cu( residual_in: Tensor, # residual_in/out weight: Tensor, epsilon: float, -): +) -> None: """ Cuda version of rmsnorm fused add """ ... -@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd") +def gen_rms_norm_fake_tensor( + input: Tensor, + weight: Tensor, + epsilon: float, + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: + return torch.empty_like(input, dtype=input.dtype, device=input.device) + + +@compile_ops( + "module_rmsnorm", fc_name="rmsnorm2d_fwd", gen_fake=gen_rms_norm_fake_tensor +) def rms_norm( input: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: """ CK version of rmsnorm """ ... -@compile_ops("module_rmsnorm") +@compile_ops("module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor) def rmsnorm2d_fwd( input: torch.Tensor, weight: torch.Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -) -> torch.Tensor: ... + use_model_sensitive_rmsnorm: int = 0, +) -> Tensor: ... @compile_ops("module_rmsnorm") @@ -64,8 +75,8 @@ def rmsnorm2d_fwd_with_add( residual_out: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): ... + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... @compile_ops("module_rmsnorm") @@ -76,8 +87,8 @@ def rmsnorm2d_fwd_with_smoothquant( yscale: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): ... + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... @compile_ops("module_rmsnorm") @@ -90,8 +101,8 @@ def rmsnorm2d_fwd_with_add_smoothquant( yscale: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): ... + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... @compile_ops("module_rmsnorm") @@ -101,8 +112,8 @@ def rmsnorm2d_fwd_with_dynamicquant( yscale: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): ... + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... @compile_ops("module_rmsnorm") @@ -114,5 +125,5 @@ def rmsnorm2d_fwd_with_add_dynamicquant( yscale: Tensor, weight: Tensor, epsilon: float, - use_model_sensitive_rmsnorm: int, -): ... + use_model_sensitive_rmsnorm: int = 0, +) -> None: ... diff --git a/aiter/ops/rope.py b/aiter/ops/rope.py index 39688d0fda..298beaee4f 100644 --- a/aiter/ops/rope.py +++ b/aiter/ops/rope.py @@ -17,7 +17,7 @@ def rope_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of traditional RoPE (Rotary Position Embedding). Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2] @@ -37,7 +37,7 @@ def rope_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of traditional RoPE (Rotary Position Embedding). Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2] @@ -59,7 +59,7 @@ def rope_2c_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of traditional RoPE (Rotary Position Embedding) on two channels. Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2] @@ -81,7 +81,7 @@ def rope_2c_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of traditional RoPE (Rotary Position Embedding) on two channels. Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2] @@ -102,7 +102,7 @@ def rope_cached_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin. Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2] @@ -123,7 +123,7 @@ def rope_cached_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of RoPE (Rotary Position Embedding) with cached cos and sin. Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2] @@ -146,7 +146,7 @@ def rope_cached_2c_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin on two channels. Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2] @@ -169,7 +169,7 @@ def rope_cached_2c_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of RoPE (Rotary Position Embedding) with cached cos and sin on two channels. Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2] @@ -191,7 +191,7 @@ def rope_cached_positions_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets on one channel. Offsets here is optional. Both positions and offsets should be in [s, b]. @@ -216,7 +216,7 @@ def rope_cached_positions_2c_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets on two channels. Offsets here is optional. Both positions and offsets should be in [s, b]. @@ -240,7 +240,7 @@ def rope_cached_positions_offsets_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets on one channel. Offsets here is optional. Both positions and offsets should be in [s, b]. @@ -266,7 +266,7 @@ def rope_cached_positions_offsets_2c_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets on two channels. Offsets here is optional. Both positions and offsets should be in [s, b]. @@ -288,7 +288,7 @@ def rope_thd_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with input sizes: (t, h, d). where t is cumulative sum of sequence lengths. @@ -310,7 +310,7 @@ def rope_thd_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of RoPE (Rotary Position Embedding) with input sizes: (t, h, d). where t is cumulative sum of sequence lengths. @@ -336,7 +336,7 @@ def rope_2d_fwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Forward propagation of RoPE (Rotary Position Embedding) with 2D image as input. Input and output should be in (b, s, h, d) where s = H * W. @@ -364,7 +364,7 @@ def rope_2d_bwd_impl( rotate_style: int, reuse_freqs_front_part: bool, nope_first: bool, -): +) -> None: """ Backward propagation of RoPE (Rotary Position Embedding) with 2D image as input. output_grads and input_grads should be in (b, s, h, d) where s = H * W. diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py index f3992a26f1..0fd3f51690 100644 --- a/aiter/ops/topk.py +++ b/aiter/ops/topk.py @@ -3,6 +3,7 @@ # user interface +from typing import List import torch from ..jit.core import ( compile_ops, @@ -34,10 +35,30 @@ def grouped_topk( need_renorm: bool, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, -): ... +) -> None: ... -@compile_ops("module_moe_asm") +def gen_moe_fused_gate_fake_tensor( + input: torch.Tensor, + bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + topk: int, + n_share_experts_fusion: int, + routed_scaling_factor: float = 1.0, +) -> List[torch.Tensor]: + output = torch.empty_like( + topk_weights, dtype=topk_weights.dtype, device=topk_weights.device + ) + + indices = torch.empty_like(topk_ids, dtype=topk_ids.dtype, device=topk_ids.device) + + return [output, indices] + + +@compile_ops("module_moe_asm", gen_fake=gen_moe_fused_gate_fake_tensor) def moe_fused_gate( input: torch.Tensor, bias: torch.Tensor, @@ -48,7 +69,7 @@ def moe_fused_gate( topk: int, n_share_experts_fusion: int, routed_scaling_factor: float = 1.0, -) -> list[torch.Tensor]: ... +) -> List[torch.Tensor]: ... def biased_grouped_topk( diff --git a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu index 2947b070ef..1a2afbfb62 100644 --- a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu +++ b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16.cu @@ -137,7 +137,7 @@ BatchedKernel batched_dispatch(int B, int M, int N, int K) return batched_heuristic_dispatch(B, M, N, K); } -torch::Tensor batched_gemm_bf16( +void batched_gemm_bf16( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &Y, @@ -158,5 +158,5 @@ torch::Tensor batched_gemm_bf16( batched_dispatch(B, M, N, K)(XQ, WQ, Y, bias, KBatch); - return Y; + // return Y; } diff --git a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu index 09d9960ee0..9c7b5f9612 100644 --- a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu +++ b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu @@ -50,7 +50,7 @@ BatchedKernel batched_dispatch(int id) } -torch::Tensor batched_gemm_bf16_tune( +void batched_gemm_bf16_tune( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &Y, @@ -75,5 +75,5 @@ torch::Tensor batched_gemm_bf16_tune( { TORCH_CHECK(false, "Unsupported output dtype!"); } - return Y; + // return Y; } diff --git a/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16.h b/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16.h index 87804eb59f..4a81221181 100644 --- a/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16.h +++ b/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16.h @@ -3,14 +3,14 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include #include -torch::Tensor batched_gemm_bf16( +void batched_gemm_bf16( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &Y, std::optional bias, int splitK); -torch::Tensor batched_gemm_bf16_tune( +void batched_gemm_bf16_tune( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &Y, diff --git a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu index 3fc3385871..79fba67193 100755 --- a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu +++ b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu @@ -76,7 +76,7 @@ BlockwiseKernel blockscale_dispatch(int M, int N, int K) return a4w4_blockscale_256x64x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8_2x2_intrawave_v3; } -torch::Tensor gemm_a4w4_blockscale( +void gemm_a4w4_blockscale( torch::Tensor& XQ, torch::Tensor& WQ, torch::Tensor& x_scale, @@ -104,5 +104,5 @@ torch::Tensor gemm_a4w4_blockscale( { TORCH_CHECK(false, "Unsupported scales/output dtype!"); } - return Y; + // return Y; } diff --git a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py index 4267503093..14453f6c33 100755 --- a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py +++ b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py @@ -79,10 +79,8 @@ def kernel_instance_test(x, weight, x_scale, w_scale, out, kernel_id, splitK=0): def run_gemm_a4w4_blockscale(x, weight, x_scale, w_scale, out, kernel_id, splitK): m, k = x.shape n, k = weight.shape - res = aiter.gemm_a4w4_blockscale_tune( - x, weight, x_scale, w_scale, out, kernel_id, splitK - ) - return res[:m] + aiter.gemm_a4w4_blockscale_tune(x, weight, x_scale, w_scale, out, kernel_id, splitK) + return out[:m] def run_gemm_a4w4_blockscale_asm( @@ -101,7 +99,7 @@ def run_gemm_a4w4_blockscale_asm( if splitK is not None and splitK > 0: out_reset = torch.zeros(out.shape[0], out.shape[1], dtype=dtype) out = out_reset - res = aiter.gemm_a4w4_asm( + aiter.gemm_a4w4_asm( x, weight_shuffle, x_scale, @@ -112,7 +110,7 @@ def run_gemm_a4w4_blockscale_asm( bpreshuffle=bpreshuffle, log2_k_split=splitK, ) - return res[:m] + return out[:m] def generate_data(m, n, k, useSplitK=False, dtype=dtypes.bf16): diff --git a/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale.h b/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale.h index c9a53fefbf..be0d24a93b 100644 --- a/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale.h +++ b/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale.h @@ -3,14 +3,14 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include -torch::Tensor gemm_a4w4_blockscale(torch::Tensor& A, +void gemm_a4w4_blockscale(torch::Tensor& A, torch::Tensor& B, torch::Tensor& a_scale, torch::Tensor& b_scale, torch::Tensor& C, int splitK); -torch::Tensor gemm_a4w4_blockscale_tune(torch::Tensor& XQ, +void gemm_a4w4_blockscale_tune(torch::Tensor& XQ, torch::Tensor& WQ, torch::Tensor& x_scale, torch::Tensor& w_scale, diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu index 25fcbf3827..1cea6a9497 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu @@ -6,6 +6,7 @@ #include "gemm_moe_ck2stages_lookup.h" #include "gemm_moe_ck2stages.h" #include "gemm_moe_ck2stages_heuristic_dispatch.hpp" +#include "moe_ck.h" #include using MoeKernelMap = std::unordered_map; @@ -52,7 +53,9 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token std::optional w1_scale = std::nullopt, // [e, 1, n], gate(up) scale std::optional a1_scale = std::nullopt, // [m, 1], token scale std::optional block_m = 32, - std::optional sorted_weights = std::nullopt) + std::optional sorted_weights = std::nullopt, + int quant_type = 0, + int activation = 0) { const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); at::cuda::getCurrentCUDAStream().stream(); @@ -107,7 +110,9 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token std::optional w2_scale = std::nullopt, // [e, 1, n], gate(up) scale std::optional a2_scale = std::nullopt, // [m, 1], token scale std::optional block_m = 32, - std::optional sorted_weights = std::nullopt) + std::optional sorted_weights = std::nullopt, + int quant_type = 0, + int activation = 0) { TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!") diff --git a/csrc/include/asm_gemm_a4w4.h b/csrc/include/asm_gemm_a4w4.h index 54d23eb8d2..4371226cad 100644 --- a/csrc/include/asm_gemm_a4w4.h +++ b/csrc/include/asm_gemm_a4w4.h @@ -3,7 +3,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 +void gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 torch::Tensor& B, // B:[N, K/2] f4x2 torch::Tensor& A_scale, // A_scale:[M, K/32] e8m0 paded torch::Tensor& B_scale, // B_scale:[N, K/32] e8m0 paded diff --git a/csrc/include/custom_all_reduce.h b/csrc/include/custom_all_reduce.h index b634c8cea1..8dbe8df9fb 100644 --- a/csrc/include/custom_all_reduce.h +++ b/csrc/include/custom_all_reduce.h @@ -35,7 +35,7 @@ int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor &t, const std::vector &handles, const std::vector &offsets); -std::tuple get_graph_buffer_ipc_meta( +std::vector get_graph_buffer_ipc_meta( fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector &handles, const std::vector &offsets); diff --git a/csrc/include/moe_ck.h b/csrc/include/moe_ck.h index 39667f7bd4..c3e023bb96 100644 --- a/csrc/include/moe_ck.h +++ b/csrc/include/moe_ck.h @@ -16,7 +16,9 @@ void ck_moe_stage1(torch::Tensor& hidden_states, // [m, k], input token std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional a1_scale, // [m, 1], token scale std::optional block_m, - std::optional sorted_weights); + std::optional sorted_weights, + int quant_type, + int activation); void ck_moe_stage2(torch::Tensor& inter_states, // [m, k], input token torch::Tensor& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) @@ -30,4 +32,6 @@ void ck_moe_stage2(torch::Tensor& inter_states, // [m, k], input token std::optional w2_scale, // [e, 1, n], gate(up) scale std::optional a2_scale, // [m, 1], token scale std::optional block_m, - std::optional sorted_weights); // [max_num_tokens_padded]); + std::optional sorted_weights, // [max_num_tokens_padded]); + int quant_type, + int activation); \ No newline at end of file diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 30c5599945..ded3f6c313 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -558,11 +558,14 @@ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ - py::arg("kernelName"), \ + py::arg("kernelName") = "", \ py::arg("w1_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, \ py::arg("block_m") = 32, \ - py::arg("sorted_weights") = std::nullopt); \ + py::arg("sorted_weights") = std::nullopt, \ + py::arg("quant_type") = 0, \ + py::arg("activation") = 0); \ + \ \ m.def("ck_moe_stage2", \ &ck_moe_stage2, \ @@ -574,11 +577,14 @@ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ - py::arg("kernelName"), \ + py::arg("kernelName") = "", \ py::arg("w2_scale") = std::nullopt, \ py::arg("a2_scale") = std::nullopt, \ py::arg("block_m") = 32, \ - py::arg("sorted_weights") = std::nullopt); + py::arg("sorted_weights") = std::nullopt, \ + py::arg("quant_type") = 0, \ + py::arg("activation") = 0); \ + #define MOE_CKTILE_2STAGES_PYBIND \ @@ -836,7 +842,7 @@ py::arg("input"), \ py::arg("weight"), \ py::arg("bias"), \ - py::arg("epsilon"), \ + py::arg("epsilon") = 1e-5f, \ py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_add", \ &layernorm2d_with_add, \ diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index 1f9305ba44..58f436a787 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -165,7 +165,7 @@ void register_buffer(fptr_t _fa, torch::Tensor &t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::tuple get_graph_buffer_ipc_meta( +std::vector get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu index 0086bd323e..cd22a6b2cb 100644 --- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu +++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu @@ -155,7 +155,7 @@ std::tuple get_heuristic_kernel(int M, // A4W4 asm gemm kernel // D=A*B*alpha+beta*C -torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 +void gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 torch::Tensor& B, // B:[N, K/2] f4x2 torch::Tensor& A_scale, // A_scale:[M, K/32] e8m0 paded torch::Tensor& B_scale, // B_scale:[N, K/32] e8m0 paded @@ -288,5 +288,5 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 1, // bdy 1, // bdz stream}); - return out; + // return out; } diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index edecacbe7c..37a390de82 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -327,7 +327,7 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, const int solution_index, std::optional bias, - std::optional out_dtype, + std::optional out_dtype, std::optional scaleA, std::optional scaleB, std::optional scaleOut) @@ -347,7 +347,7 @@ torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, auto inDtype{mat1.options().dtype().toScalarType()}; auto outDtype{ out_dtype.has_value() - ? torch::python::detail::py_object_to_dtype(out_dtype.value()) + ? out_dtype.value() : inDtype}; auto options{at::TensorOptions().dtype(outDtype).device(at::kCUDA)}; auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; @@ -446,7 +446,7 @@ torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, std::vector hipb_findallsols( const torch::Tensor &mat1, const torch::Tensor &mat2, std::optional bias, - std::optional out_dtype, + std::optional out_dtype, std::optional scaleA, std::optional scaleB, std::optional scaleC) @@ -465,7 +465,7 @@ std::vector hipb_findallsols( auto inType{mat1.options().dtype().toScalarType()}; auto outType{ out_dtype.has_value() - ? torch::python::detail::py_object_to_dtype(out_dtype.value()) + ? out_dtype.value() : inType}; auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; diff --git a/gradlib/include/hipbsolgemm.cuh b/gradlib/include/hipbsolgemm.cuh index 90fd9c9b25..f457228e9d 100644 --- a/gradlib/include/hipbsolgemm.cuh +++ b/gradlib/include/hipbsolgemm.cuh @@ -41,7 +41,7 @@ void hipb_destroy_extension(); torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, const int solution_index, std::optional bias = std::nullopt, - std::optional out_dtype = std::nullopt, + std::optional out_dtype = std::nullopt, std::optional scaleA = std::nullopt, std::optional scaleB = std::nullopt, std::optional scaleOut = std::nullopt); @@ -49,7 +49,7 @@ torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, std::vector hipb_findallsols( const torch::Tensor &mat1, const torch::Tensor &mat2, std::optional bias = std::nullopt, - std::optional out_dtype = std::nullopt, + std::optional out_dtype = std::nullopt, std::optional scaleA = std::nullopt, std::optional scaleB = std::nullopt, std::optional scaleC = std::nullopt); diff --git a/op_tests/test_aiter_add.py b/op_tests/test_aiter_add.py index b958d139d8..32d0961d2b 100644 --- a/op_tests/test_aiter_add.py +++ b/op_tests/test_aiter_add.py @@ -4,6 +4,7 @@ import torch import aiter from torch.profiler import profile, ProfilerActivity +from aiter.test_common import checkAllclose from aiter import dtypes input_shapes = [ @@ -160,8 +161,9 @@ # cache_flush1 = torch.randn(10000, 10000, requires_grad=True, device="cuda", dtype=dtypes.fp32).to(dtypes.i32) # output = torch.empty_like(tensor1) output = aiter.add(tensor0, tensor1) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + checkAllclose(result, output, msg="add") print(torch.equal(result, output)) # print("result:", result) # print("output:", output) diff --git a/op_tests/test_aiter_sigmoid.py b/op_tests/test_aiter_sigmoid.py index 24cf7fcacf..62beb90a98 100644 --- a/op_tests/test_aiter_sigmoid.py +++ b/op_tests/test_aiter_sigmoid.py @@ -4,6 +4,7 @@ import torch import aiter from aiter import dtypes +from aiter.test_common import checkAllclose # from ater.test_common import checkAllclose, perftest from torch.profiler import profile, ProfilerActivity @@ -53,5 +54,6 @@ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) print(torch.equal(result, output)) +checkAllclose(result, output, msg="sigmoid") print("result:", result) print("output:", output) diff --git a/op_tests/test_gemm_a4w4.py b/op_tests/test_gemm_a4w4.py index ba76f6c793..17e1da3fe2 100644 --- a/op_tests/test_gemm_a4w4.py +++ b/op_tests/test_gemm_a4w4.py @@ -39,7 +39,8 @@ def run_torch(x, w, x_scales, w_scales, dtype): @perftest() def run_gemm_ck(x, weight, x_scale, w_scale, out): - return aiter.gemm_a4w4_blockscale(x, weight, x_scale, w_scale, out) + aiter.gemm_a4w4_blockscale(x, weight, x_scale, w_scale, out) + return out @perftest() @@ -69,7 +70,7 @@ def run_gemm_asm( # ) # out = out_reset - return aiter.gemm_a4w4_asm( + aiter.gemm_a4w4_asm( x, weightshuffle, x_scale, @@ -80,6 +81,7 @@ def run_gemm_asm( bpreshuffle=bpreshuffle, log2_k_split=log2_k_split, ) + return out @benchmark()