diff --git a/aiter/jit/core.py b/aiter/jit/core.py
index 712feea0d3..5ca666eb21 100644
--- a/aiter/jit/core.py
+++ b/aiter/jit/core.py
@@ -24,6 +24,8 @@
 
 AITER_REBUILD = int(os.environ.get("AITER_REBUILD", "0"))
 
+aiter_lib = None
+
 
 def mp_lock(
     lockPath: str,
@@ -514,6 +516,8 @@ def convert(d_ops: dict):
     "dispose",
     "meta_size",
     "get_padded_m",
+    "compile_mha_fwd",
+    "compile_mha_bwd",
 ]
 
 
@@ -598,32 +602,8 @@ def compile_ops(
 ):
     def decorator(func):
-        import torch
-        from csrc.cpp_itfs.torch_utils import aiter_lib
-        import torch.library
-        import inspect
-        func.arg_checked = False
-        schema = ""
-        if func.__name__ in MANUAL_SCHEMA_OPS:
-            schema = generate_schema(func)
-        else:
-            sig = inspect.signature(func)
-            mutates_args = []
-            for name, param in sig.parameters.items():
-                if param.annotation is torch.Tensor:
-                    mutates_args.append(name)
-            if hasattr(torch.library, "infer_schema"):
-                sig = torch.library.infer_schema(func, mutates_args="unknown")
-            else:
-                # for pytorch 2.4
-                import torch._custom_op.impl
-
-                sig = torch._custom_op.impl.infer_schema(func, mutates_args)
-            schema = f"{sig}"
-        loadName = func.__name__
-
         @functools.wraps(func)
         def wrapper(*args, custom_build_args={}, **kwargs):
             loadName = fc_name
@@ -777,6 +757,8 @@ def check_args():
 
             log_args(func, *args, **kwargs)
 
+            import inspect
+
             sig = inspect.signature(func)
             params = list(sig.parameters.keys())
             if loadName in activation_list:
@@ -809,19 +791,55 @@ def abstract_impl(*args, custom_build_args={}, **kwargs):
                 return gen_fake(*args, **kwargs)
             return func(*args, **kwargs)
 
-        if loadName in NONE_WRAPPED_OP:
+        if func.__name__ in NONE_WRAPPED_OP:
             return wrapper
 
-        if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
-            op_schema = f"aiter::wrapper_{loadName}" + schema
-            aiter_lib.define(op_schema)
-            aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CUDA")
-            aiter_lib.impl(f"wrapper_{loadName}", wrapper, "CPU")
-            aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl)
+        def wrapper_register(func):
+            import torch
+            import torch.library
+            import inspect
+            from torch.library import Library
+
+            global aiter_lib
+            aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
+            schema = ""
+            if func.__name__ in MANUAL_SCHEMA_OPS:
+                schema = generate_schema(func)
+            else:
+                sig = inspect.signature(func)
+                mutates_args = []
+                for name, param in sig.parameters.items():
+                    if param.annotation is torch.Tensor:
+                        mutates_args.append(name)
+                if hasattr(torch.library, "infer_schema"):
+                    sig = torch.library.infer_schema(func, mutates_args="unknown")
+                else:
+                    # for pytorch 2.4
+                    import torch._custom_op.impl
+
+                    sig = torch._custom_op.impl.infer_schema(func, mutates_args)
+                schema = f"{sig}"
+            return schema
+
+        schema = wrapper_register(func)
+
+        def wrapper_custom(*args, custom_build_args={}, **kwargs):
+            import torch
+
+            loadName = func.__name__
+            if not hasattr(torch.ops.aiter, f"wrapper_{loadName}"):
+                op_schema = f"aiter::wrapper_{loadName}" + schema
+                aiter_lib.define(op_schema, tags=())
+                aiter_lib.impl(
+                    f"aiter::wrapper_{loadName}", wrapper, dispatch_key="CUDA"
+                )
+                aiter_lib.impl(
+                    f"aiter::wrapper_{loadName}", wrapper, dispatch_key="CPU"
+                )
+                aiter_lib._register_fake(f"wrapper_{loadName}", abstract_impl)
 
-        def wrapper_return(*args, **kwargs):
             return getattr(torch.ops.aiter, f"wrapper_{loadName}")(*args, **kwargs)
 
-        return wrapper_return
+        return wrapper_custom
 
     return decorator
diff --git a/op_tests/cpp/mha/benchmark_mha_fwd b/op_tests/cpp/mha/benchmark_mha_fwd
new file mode 100755
index 0000000000..6b79b72762
Binary files /dev/null and b/op_tests/cpp/mha/benchmark_mha_fwd differ