aiter/jit/core.py (84 changes: 51 additions & 33 deletions)
@@ -24,6 +24,8 @@

AITER_REBUILD = int(os.environ.get("AITER_REBUILD", "0"))

aiter_lib = None


def mp_lock(
lockPath: str,
@@ -514,6 +516,8 @@ def convert(d_ops: dict):
"dispose",
"meta_size",
"get_padded_m",
"compile_mha_fwd",
"compile_mha_bwd",
]


@@ -598,32 +602,8 @@ def compile_ops(
):

def decorator(func):
import torch
from csrc.cpp_itfs.torch_utils import aiter_lib
import torch.library
import inspect

func.arg_checked = False

schema = ""
if func.__name__ in MANUAL_SCHEMA_OPS:
schema = generate_schema(func)
else:
sig = inspect.signature(func)
mutates_args = []
for name, param in sig.parameters.items():
if param.annotation is torch.Tensor:
mutates_args.append(name)
if hasattr(torch.library, "infer_schema"):
sig = torch.library.infer_schema(func, mutates_args="unknown")
else:
# for pytorch 2.4
import torch._custom_op.impl

sig = torch._custom_op.impl.infer_schema(func, mutates_args)
schema = f"{sig}"
loadName = func.__name__

@functools.wraps(func)
def wrapper(*args, custom_build_args={}, **kwargs):
loadName = fc_name
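
The hunk above removes the schema inference that used to run eagerly at decoration time (including the import of `aiter_lib` from `csrc.cpp_itfs.torch_utils` and the PyTorch 2.4 fallback through `torch._custom_op.impl`); the same logic reappears further down inside `wrapper_register`, so it now runs only when an op is actually wrapped. As a reference for reviewers, here is a minimal standalone sketch of that inference step; `my_add` and its annotations are made up for illustration, and only the `torch.library` calls match the code in this file:

```python
import inspect

import torch
from torch import Tensor


# Hypothetical op: the type annotations are what drive schema inference.
def my_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y


# Mirror this file's logic: collect the Tensor-annotated parameters first.
mutates_args = [
    name
    for name, param in inspect.signature(my_add).parameters.items()
    if param.annotation is Tensor
]

if hasattr(torch.library, "infer_schema"):
    # Newer PyTorch: "unknown" conservatively marks every Tensor argument
    # as potentially mutated.
    schema = torch.library.infer_schema(my_add, mutates_args="unknown")
else:
    # Fallback for PyTorch 2.4, as in the code above.
    import torch._custom_op.impl

    schema = torch._custom_op.impl.infer_schema(my_add, mutates_args)

print(schema)  # e.g. "(Tensor(a0!) x, Tensor(a1!) y) -> Tensor"
```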
@@ -777,6 +757,8 @@ def check_args():

log_args(func, *args, **kwargs)

import inspect

sig = inspect.signature(func)
params = list(sig.parameters.keys())
if loadName in activation_list:
@@ -809,19 +791,55 @@ def abstract_impl(*args, custom_build_args={}, **kwargs):
return gen_fake(*args, **kwargs)
return func(*args, **kwargs)

if loadName in NONE_WRAPPED_OP:
if func.__name__ in NONE_WRAPPED_OP:
return wrapper

if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
op_schema = f"aiter::wrapper_{loadName}" + schema
aiter_lib.define(op_schema)
aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CUDA")
aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CPU")
aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl)
def wrapper_register(func):
import torch
import torch.library
import inspect
from torch.library import Library

global aiter_lib
aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
schema = ""
if func.__name__ in MANUAL_SCHEMA_OPS:
schema = generate_schema(func)
else:
sig = inspect.signature(func)
mutates_args = []
for name, param in sig.parameters.items():
if param.annotation is torch.Tensor:
mutates_args.append(name)
if hasattr(torch.library, "infer_schema"):
sig = torch.library.infer_schema(func, mutates_args="unknown")
else:
# for pytorch 2.4
import torch._custom_op.impl

sig = torch._custom_op.impl.infer_schema(func, mutates_args)
schema = f"{sig}"
return schema

schema = wrapper_register(func)

def wrapper_custom(*args, custom_build_args={}, **kwargs):
import torch

loadName = func.__name__
if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
op_schema = f"aiter::wrapper_{loadName}" + schema
aiter_lib.define(op_schema, tags=())
aiter_lib.impl(
f"aiter::wrapper_{loadName}", wrapper, dispatch_key="CUDA"
)
aiter_lib.impl(
f"aiter::wrapper_{loadName}", wrapper, dispatch_key="CPU"
)
aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl)

def wrapper_return(*args, **kwargs):
return getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs)

return wrapper_return
return wrapper_custom

return decorator
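
To summarize the new control flow for reviewers: `compile_ops` no longer defines the `torch.library` op at decoration time. Instead, `wrapper_custom` defines, implements, and attaches the fake (meta) kernel on the first call, after which every call routes through `torch.ops.aiter.wrapper_<name>`. Below is a self-contained sketch of that pattern under a made-up `demo` namespace; `scale`, `fake_scale`, and `call_scale` are illustrative stand-ins, not aiter APIs:

```python
import torch
from torch import Tensor
from torch.library import Library

# Lazily created module-level library handle, as in aiter/jit/core.py.
demo_lib = None


def scale(x: Tensor, factor: float) -> Tensor:
    # Real kernel; in aiter this would call into the JIT-compiled module.
    return x * factor


def fake_scale(x: Tensor, factor: float) -> Tensor:
    # Fake (meta) kernel: produces shape/dtype only, no real compute.
    return torch.empty_like(x)


def call_scale(x: Tensor, factor: float) -> Tensor:
    global demo_lib
    demo_lib = Library("demo", "FRAGMENT") if demo_lib is None else demo_lib
    # Register on first use only; later calls hit the cached op directly.
    if not hasattr(torch.ops.demo, "wrapper_scale"):
        schema = torch.library.infer_schema(scale, mutates_args="unknown")
        demo_lib.define("demo::wrapper_scale" + schema, tags=())
        demo_lib.impl("demo::wrapper_scale", scale, dispatch_key="CUDA")
        demo_lib.impl("demo::wrapper_scale", scale, dispatch_key="CPU")
        demo_lib._register_fake("wrapper_scale", fake_scale)
    return torch.ops.demo.wrapper_scale(x, factor)


print(call_scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```

The `_register_fake` call is what lets the wrapped op trace under FakeTensor and `torch.compile` without launching a real kernel. A quick check against the sketch above (again illustrative, not part of this PR):

```python
from torch._subclasses.fake_tensor import FakeTensorMode

call_scale(torch.ones(3), 2.0)  # make sure the op is registered first

with FakeTensorMode() as mode:
    fake_x = mode.from_tensor(torch.ones(3))
    out = torch.ops.demo.wrapper_scale(fake_x, 2.0)  # runs fake_scale

print(out.shape)  # torch.Size([3]), produced without any real compute
```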
Binary file added op_tests/cpp/mha/benchmark_mha_fwd