aiter/jit/core.py (188 changes: 186 additions & 2 deletions)
@@ -483,14 +483,142 @@ def convert(d_ops: dict):
)


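# Ops whose torch.library schema is hand-built by generate_schema() below
# rather than derived with torch.library.infer_schema.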
MANUAL_SCHEMA_OPS = [
"register_graph_buffers",
"module_moe_ck2stages",
"mha_fwd",
"fmha_v3_fwd",
"mha_varlen_fwd",
"mha_bwd",
"fmha_v3_bwd",
"mha_varlen_bwd",
"fmha_v3_varlen_bwd",
"mha_batch_prefill",
"hipb_findallsols",
"rocb_findallsols",
"_ActivationType",
"_QuantType",
"init_custom_ar",
]

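# Ops that compile_ops returns as plain Python wrappers and never registers
# as torch.ops.aiter custom ops.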
NONE_WRAPPED_OP = [
"hipb_create_extension",
"hipb_destroy_extension",
"getHipblasltKernelName",
"rocb_create_extension",
"rocb_destroy_extension",
"get_meta_buffer_ipc_handle",
"get_graph_buffer_ipc_meta",
"_ActivationType",
"_QuantType",
"allocate_meta_buffer",
"dispose",
"meta_size",
"get_padded_m",
]


def generate_schema(func) -> str:
import inspect
import torch
from typing import Optional, Union, List, get_origin, get_args

sig = inspect.signature(func)
parameters = []

for idx, (name, param) in enumerate(sig.parameters.items()):
param_type = param.annotation
flag = True

if param_type is torch.Tensor:
type_str = f"Tensor(a{idx}!)"
elif param_type == Optional[torch.Tensor]:
type_str = f"Tensor(a{idx}!)?"
elif get_origin(param_type) is Union and torch.Tensor in get_args(param_type):
type_str = f"Tensor(a{idx}!)?"
elif param_type in (torch.SymInt, int):
type_str = "SymInt"
elif param_type in (float, bool, str):
type_str = param_type.__name__
elif param_type == Optional[torch.Generator]:
type_str = "Generator?"
elif (
get_origin(param_type) in (list, List)
and get_args(param_type)[0] is torch.Tensor
):
type_str = f"Tensor(a{idx}!)[]"
elif get_origin(param_type) in (list, List) and get_args(param_type)[0] is int:
type_str = "int[]"
else:
type_str = "*"
flag = False
if flag:
param_str = f"{type_str} {name}"

if param.default != inspect.Parameter.empty:
if param.default is None:
param_str += "=None"
else:
param_str += f"={param.default}"
else:
param_str = f"{type_str} "

parameters.append(param_str)
return_annotation = sig.return_annotation
return_type = ""
if return_annotation is type(None) or return_annotation is None:
return_type = "()"
elif return_annotation is torch.Tensor:
return_type = "Tensor"
elif (
get_origin(return_annotation) is list and get_args(return_annotation)[0] is int
):
return_type = "int[]"
elif return_annotation is int:
return_type = "int"
elif return_annotation is float:
return_type = "float"
elif return_annotation is bool:
return_type = "bool"
elif (
get_origin(return_annotation) is list
and get_args(return_annotation)[0] is torch.Tensor
):
return_type = "Tensor[]"

schema = f"({', '.join(parameters)}) -> {return_type}"

return schema


def compile_ops(
_md_name: str,
fc_name: Optional[str] = None,
gen_func: Optional[Callable[..., dict[str, Any]]] = None,
gen_fake: Optional[Callable[..., Any]] = None,
):

def decorator(func):
import torch
from csrc.cpp_itfs.torch_utils import aiter_lib
import torch.library
import inspect

func.arg_checked = False

schema = ""
if func.__name__ in MANUAL_SCHEMA_OPS:
schema = generate_schema(func)
else:
sig = inspect.signature(func)
mutates_args = []
for name, param in sig.parameters.items():
if param.annotation is torch.Tensor:
mutates_args.append(name)
sig = torch.library.infer_schema(func, mutates_args="unknown")
schema = f"{sig}"
loadName = func.__name__

@functools.wraps(func)
def wrapper(*args, custom_build_args={}, **kwargs):
loadName = fc_name
@@ -565,6 +693,16 @@ def wrapper(*args, custom_build_args={}, **kwargs):
op = getattr(module, loadName)
else:
return None
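# activation_list / quant_list: ops whose "activation" / "quant_type" arguments
# may arrive as plain ints and are converted back to ActivationType / QuantType
# before dispatch; these ops also skip the per-argument check in check_args().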
activation_index = 0
quant_index = 0
activation_list = [
"fmoe_g1u1",
"fmoe_int8_g1u0",
"fmoe_g1u1_tkw1",
"fmoe_fp8_blockscale_g1u1",
"moe_stage1_g1u1",
]
quant_list = ["moe_stage1_g1u1"]

def check_args():
get_asm_dir()
@@ -587,7 +725,10 @@ def check_args():
func.__signature__ = sig
ann = {k: v.annotation for k, v in sig.parameters.items()}
ann["return"] = sig.return_annotation

if loadName in activation_list:
return True
if loadName in quant_list:
return True
callargs = inspect.getcallargs(func, *args, **kwargs)
for el, arg in callargs.items():
expected_type = ann[el]
@@ -632,8 +773,51 @@ def check_args():

log_args(func, *args, **kwargs)

sig = inspect.signature(func)
params = list(sig.parameters.keys())
if loadName in activation_list:
activation_index = params.index("activation")
args_list = list(args)
from aiter import ActivationType, QuantType

if len(args_list) > activation_index and isinstance(
args_list[activation_index], int
):
args_list[activation_index] = ActivationType(
args_list[activation_index]
)
args = tuple(args_list)

if loadName in quant_list:
quant_index = params.index("quant_type")
args_list = list(args)
from aiter import ActivationType, QuantType

if len(args_list) > quant_index and isinstance(
args_list[quant_index], int
):
args_list[quant_index] = QuantType(args_list[quant_index])
args = tuple(args_list)
return op(*args, **kwargs)

return wrapper
def abstract_impl(*args, custom_build_args={}, **kwargs):
if gen_fake is not None:
return gen_fake(*args, **kwargs)
return func(*args, **kwargs)

if loadName in NONE_WRAPPED_OP:
return wrapper

if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
op_schema = f"aiter::wrapper_{loadName}" + schema
aiter_lib.define(op_schema)
aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CUDA")
aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CPU")
aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl)

def wrapper_return(*args, **kwargs):
return getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs)

return wrapper_return

return decorator
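As a rough sketch of what `generate_schema` emits (the signature below is hypothetical, not an op from this PR, and assumes the `aiter` package in this branch is importable): tensor arguments become mutable `Tensor(aN!)` entries, defaults are carried over, and the return annotation maps to the schema's return type.

```python
import torch
from typing import Optional
from aiter.jit.core import generate_schema

def fake_mha(q: torch.Tensor, k: torch.Tensor, dropout_p: float = 0.0,
             gen: Optional[torch.Generator] = None) -> torch.Tensor: ...

print(generate_schema(fake_mha))
# '(Tensor(a0!) q, Tensor(a1!) k, float dropout_p=0.0, Generator? gen=None) -> Tensor'
```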
aiter/ops/activation.py (8 changes: 4 additions & 4 deletions)
@@ -9,16 +9,16 @@


@compile_ops("module_activation")
def silu_and_mul(out: Tensor, input: Tensor): ...
def silu_and_mul(out: Tensor, input: Tensor) -> None: ...


@compile_ops("module_activation")
def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor): ...
def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor) -> None: ...


@compile_ops("module_activation")
def gelu_and_mul(out: Tensor, input: Tensor): ...
def gelu_and_mul(out: Tensor, input: Tensor) -> None: ...


@compile_ops("module_activation")
def gelu_tanh_and_mul(out: Tensor, input: Tensor): ...
def gelu_tanh_and_mul(out: Tensor, input: Tensor) -> None: ...
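The explicit `-> None` annotations matter because these ops are not in MANUAL_SCHEMA_OPS, so they now go through `torch.library.infer_schema` in `compile_ops`, which appears to require an explicit return annotation to emit a `-> ()` schema. A minimal sketch of the idea, assuming a PyTorch version where `infer_schema` is public:

```python
import torch
from torch.library import infer_schema

def silu_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: ...

# Roughly: "(Tensor(a0!) out, Tensor input) -> ()"
print(infer_schema(silu_and_mul, mutates_args=["out"]))
```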
aiter/ops/aiter_operator.py (73 changes: 63 additions & 10 deletions)
@@ -5,6 +5,7 @@
from ..jit.core import compile_ops, AITER_CSRC_DIR
from functools import partial
from typing import Any
import torch

MD_NAME = "module_aiter_operator"

@@ -20,47 +21,99 @@ def cmdGenFunc(op_name: str, input: Tensor, other: Tensor) -> dict[str, Any]:
}


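# Fake (meta) implementations passed to compile_ops(gen_fake=...): they compute only
# the output metadata (shape/dtype/device), so the registered custom op can be traced
# (e.g. by FakeTensor / torch.compile) without launching a kernel.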
def binary_fake_shape(input: Tensor, other: Tensor) -> Tensor:
shape1 = list(input.shape)
shape2 = list(other.shape)

max_dim = max(len(shape1), len(shape2))
shape1 = [1] * (max_dim - len(shape1)) + shape1
shape2 = [1] * (max_dim - len(shape2)) + shape2

result_shape = []
for dim1, dim2 in zip(shape1, shape2):
if dim1 == 1:
result_shape.append(dim2)
elif dim2 == 1:
result_shape.append(dim1)
elif dim1 == dim2:
result_shape.append(dim1)
else:
raise RuntimeError(
f"Incompatible shapes for binary operator: {input.shape} and {other.shape}"
)

return torch.empty(
size=result_shape,
dtype=input.dtype,
device=input.device,
)


def sigmoid_fake_shape(input: torch.Tensor) -> torch.Tensor:
return torch.empty(
size=input.shape,
dtype=input.dtype,
device=input.device,
)


binary_add_build_args = partial(cmdGenFunc, "add")
binary_sub_build_args = partial(cmdGenFunc, "sub")
binary_mul_build_args = partial(cmdGenFunc, "mul")
binary_div_build_args = partial(cmdGenFunc, "div")


@compile_ops("module_aiter_operator", gen_func=binary_add_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_add_build_args, gen_fake=binary_fake_shape
)
def add(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_sub_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_sub_build_args, gen_fake=binary_fake_shape
)
def sub(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_mul_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_mul_build_args, gen_fake=binary_fake_shape
)
def mul(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_div_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_div_build_args, gen_fake=binary_fake_shape
)
def div(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_add_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_add_build_args, gen_fake=binary_fake_shape
)
def add_(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_sub_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_sub_build_args, gen_fake=binary_fake_shape
)
def sub_(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_mul_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_mul_build_args, gen_fake=binary_fake_shape
)
def mul_(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_operator", gen_func=binary_div_build_args)
@compile_ops(
"module_aiter_operator", gen_func=binary_div_build_args, gen_fake=binary_fake_shape
)
def div_(input: Tensor, other: Tensor) -> Tensor: ...


@compile_ops("module_aiter_unary")
@compile_ops("module_aiter_unary", gen_fake=sigmoid_fake_shape)
def sigmoid(input: Tensor) -> Tensor: ...


@compile_ops("module_aiter_unary")
@compile_ops("module_aiter_unary", gen_fake=sigmoid_fake_shape)
def tanh(input: Tensor) -> Tensor: ...
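A small sanity check of the broadcasting logic in `binary_fake_shape` (illustrative only; it assumes the aiter build environment is importable, and uses meta tensors so no device memory is touched):

```python
import torch
from aiter.ops.aiter_operator import binary_fake_shape

a = torch.empty(4, 1, 8, device="meta")
b = torch.empty(3, 1, device="meta")

print(binary_fake_shape(a, b).shape)  # torch.Size([4, 3, 8])
```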