diff --git a/aiter/jit/utils/cpp_extension.py b/aiter/jit/utils/cpp_extension.py
index 5799e47205..2bd4c14a67 100644
--- a/aiter/jit/utils/cpp_extension.py
+++ b/aiter/jit/utils/cpp_extension.py
@@ -1534,7 +1534,20 @@ def _write_ninja_file_to_build_library(
     extra_ldflags = [flag.strip() for flag in extra_ldflags]
     extra_include_paths = [flag.strip() for flag in extra_include_paths]
     # include_paths() gives us the location of torch/extension.h
-    system_includes = [] if torch_exclude else include_paths(with_cuda)
+    # system_includes = [] if torch_exclude else include_paths(with_cuda)
+    import torch
+
+    _TORCH_PATH = os.path.dirname(torch.__file__)
+    TORCH_INCLUDE_ROOT = os.path.join(_TORCH_PATH, "include")
+    system_includes = [
+        TORCH_INCLUDE_ROOT,
+        os.path.join(TORCH_INCLUDE_ROOT, "torch/csrc/api/include"),
+        os.path.join(TORCH_INCLUDE_ROOT, "TH"),
+        os.path.join(TORCH_INCLUDE_ROOT, "THC"),
+    ]
+    if not torch_exclude:
+        system_includes += include_paths(with_cuda)
+    system_includes = list(set(system_includes))
     # FIXME: build python module excluded with torch, use `pybind11`
     # But we can't use this now because all aiter op based on torch
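
Note (not part of the patch): the hunk above derives the torch include roots directly
from the installed package and then merges in include_paths() unless torch_exclude is
set. Below is a minimal standalone sketch of that resolution, assuming only that torch
is importable; which of the listed directories actually exist depends on the installed
torch wheel, and the sorted() call is only there for stable output.

    import os
    import torch

    # Mirrors the names introduced in the hunk above.
    _TORCH_PATH = os.path.dirname(torch.__file__)
    TORCH_INCLUDE_ROOT = os.path.join(_TORCH_PATH, "include")
    system_includes = [
        TORCH_INCLUDE_ROOT,
        os.path.join(TORCH_INCLUDE_ROOT, "torch/csrc/api/include"),
        os.path.join(TORCH_INCLUDE_ROOT, "TH"),
        os.path.join(TORCH_INCLUDE_ROOT, "THC"),
    ]
    # list(set(...)) as used in the patch deduplicates but does not preserve order.
    for path in sorted(set(system_includes)):
        print(path, "exists" if os.path.isdir(path) else "missing")
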
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 925deb96ea..135e8ae03b 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -3,6 +3,7 @@
 #pragma once
 #include
+
 namespace py = pybind11;
 
 #define ACTIVATION_PYBIND \
@@ -658,6 +659,64 @@ namespace py = pybind11;
           py::arg("rng_state") = std::nullopt, \
           py::arg("gen") = std::nullopt);
 
+#define ROCSOLGEMM_PYBIND \
+    m.def("rocb_create_extension", &rocb_create_extension, "create_extension"); \
+    m.def("rocb_destroy_extension", &rocb_destroy_extension, "destroy_extension"); \
+    m.def("rocb_mm", &RocSolIdxBlas, "mm"); \
+    m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");
+
+#define HIPBSOLGEMM_PYBIND \
+    m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); \
+    m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); \
+    m.def("hipb_mm", \
+          &hipb_mm, \
+          "hipb_mm", \
+          py::arg("mat1"), \
+          py::arg("mat2"), \
+          py::arg("solution_index"), \
+          py::arg("bias") = std::nullopt, \
+          py::arg("out_dtype") = std::nullopt, \
+          py::arg("scaleA") = std::nullopt, \
+          py::arg("scaleB") = std::nullopt, \
+          py::arg("scaleOut") = std::nullopt, \
+          py::arg("bpreshuffle") = std::nullopt); \
+    m.def("hipb_findallsols", \
+          &hipb_findallsols, \
+          "hipb_findallsols", \
+          py::arg("mat1"), \
+          py::arg("mat2"), \
+          py::arg("bias") = std::nullopt, \
+          py::arg("out_dtype") = std::nullopt, \
+          py::arg("scaleA") = std::nullopt, \
+          py::arg("scaleB") = std::nullopt, \
+          py::arg("scaleC") = std::nullopt, \
+          py::arg("bpreshuffle") = false); \
+    m.def("getHipblasltKernelName", &getHipblasltKernelName);
+
+#define LIBMHA_BWD_PYBIND \
+    m.def("libmha_bwd", \
+          &aiter::torch_itfs::mha_bwd, \
+          py::arg("dout"), \
+          py::arg("q"), \
+          py::arg("k"), \
+          py::arg("v"), \
+          py::arg("out"), \
+          py::arg("softmax_lse"), \
+          py::arg("dropout_p"), \
+          py::arg("softmax_scale"), \
+          py::arg("is_causal"), \
+          py::arg("window_size_left"), \
+          py::arg("window_size_right"), \
+          py::arg("deterministic"), \
+          py::arg("dq") = std::nullopt, \
+          py::arg("dk") = std::nullopt, \
+          py::arg("dv") = std::nullopt, \
+          py::arg("dbias") = std::nullopt, \
+          py::arg("bias") = std::nullopt, \
+          py::arg("alibi_slopes") = std::nullopt, \
+          py::arg("rng_state") = std::nullopt, \
+          py::arg("gen") = std::nullopt);
+
 #define MHA_VARLEN_BWD_ASM_PYBIND \
     m.def("fmha_v3_varlen_bwd", \
           &aiter::torch_itfs::fmha_v3_varlen_bwd, \
@@ -756,32 +815,56 @@ namespace py = pybind11;
           py::arg("v_descale") = std::nullopt, \
           py::arg("gen") = std::nullopt);
 
-#define MHA_VARLEN_FWD_ASM_PYBIND \
-    m.def("fmha_v3_varlen_fwd", \
-          &aiter::torch_itfs::fmha_v3_varlen_fwd, \
-          py::arg("q"), \
-          py::arg("k"), \
-          py::arg("v"), \
-          py::arg("cu_seqlens_q"), \
-          py::arg("cu_seqlens_k"), \
-          py::arg("max_seqlen_q"), \
-          py::arg("max_seqlen_k"), \
-          py::arg("min_seqlen_q"), \
-          py::arg("dropout_p"), \
-          py::arg("softmax_scale"), \
-          py::arg("logits_soft_cap"), \
-          py::arg("zero_tensors"), \
-          py::arg("is_causal"), \
-          py::arg("window_size_left"), \
-          py::arg("window_size_right"), \
-          py::arg("return_softmax_lse"), \
-          py::arg("return_dropout_randval"), \
-          py::arg("how_v3_bf16_cvt"), \
-          py::arg("out") = std::nullopt, \
-          py::arg("block_table") = std::nullopt, \
-          py::arg("bias") = std::nullopt, \
-          py::arg("alibi_slopes") = std::nullopt, \
-          py::arg("gen") = std::nullopt, \
+#define LIBMHA_FWD_PYBIND \
+    m.def("libmha_fwd", \
+          &aiter::torch_itfs::mha_fwd, \
+          py::arg("q"), \
+          py::arg("k"), \
+          py::arg("v"), \
+          py::arg("dropout_p"), \
+          py::arg("softmax_scale"), \
+          py::arg("is_causal"), \
+          py::arg("window_size_left"), \
+          py::arg("window_size_right"), \
+          py::arg("sink_size"), \
+          py::arg("return_softmax_lse"), \
+          py::arg("return_dropout_randval"), \
+          py::arg("cu_seqlens_q") = std::nullopt, \
+          py::arg("cu_seqlens_kv") = std::nullopt, \
+          py::arg("out") = std::nullopt, \
+          py::arg("bias") = std::nullopt, \
+          py::arg("alibi_slopes") = std::nullopt, \
+          py::arg("q_descale") = std::nullopt, \
+          py::arg("k_descale") = std::nullopt, \
+          py::arg("v_descale") = std::nullopt, \
+          py::arg("gen") = std::nullopt);
+
+#define MHA_VARLEN_FWD_ASM_PYBIND \
+    m.def("fmha_v3_varlen_fwd", \
+          &aiter::torch_itfs::fmha_v3_varlen_fwd, \
+          py::arg("q"), \
+          py::arg("k"), \
+          py::arg("v"), \
+          py::arg("cu_seqlens_q"), \
+          py::arg("cu_seqlens_k"), \
+          py::arg("max_seqlen_q"), \
+          py::arg("max_seqlen_k"), \
+          py::arg("min_seqlen_q"), \
+          py::arg("dropout_p"), \
+          py::arg("softmax_scale"), \
+          py::arg("logits_soft_cap"), \
+          py::arg("zero_tensors"), \
+          py::arg("is_causal"), \
+          py::arg("window_size_left"), \
+          py::arg("window_size_right"), \
+          py::arg("return_softmax_lse"), \
+          py::arg("return_dropout_randval"), \
+          py::arg("how_v3_bf16_cvt"), \
+          py::arg("out") = std::nullopt, \
+          py::arg("block_table") = std::nullopt, \
+          py::arg("bias") = std::nullopt, \
+          py::arg("alibi_slopes") = std::nullopt, \
+          py::arg("gen") = std::nullopt, \
           py::arg("cu_seqlens_q_padded") = std::nullopt, \
           py::arg("cu_seqlens_k_padded") = std::nullopt);
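
Note (not part of the patch): a rough Python-side sketch of how the newly bound
hipBLASLt entry points could be exercised once the corresponding extension is built.
The module name module_hipbsolgemm is taken from setup.py below; whether it is
importable like this, and the exact return value of hipb_findallsols, are assumptions,
and a ROCm-capable torch install is required. Keyword names follow the py::arg lists
in HIPBSOLGEMM_PYBIND.

    import importlib
    import torch

    hipb = importlib.import_module("module_hipbsolgemm")  # assumed import path

    hipb.hipb_create_extension()
    a = torch.randn(128, 256, device="cuda", dtype=torch.float16)
    b = torch.randn(256, 512, device="cuda", dtype=torch.float16)

    # Enumerate candidate hipBLASLt solutions for this GEMM shape, then run the
    # first one (assuming a list of solution indices is returned).
    sols = hipb.hipb_findallsols(a, b, out_dtype=torch.float16)
    c = hipb.hipb_mm(a, b, solution_index=sols[0])

    hipb.hipb_destroy_extension()
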
"module_fmha_v3_varlen_fwd", - "module_mha_bwd", - "module_mha_varlen_bwd", - "module_batched_gemm_bf16_tune", - "module_batched_gemm_a8w8_tune", - "module_gemm_a8w8_tune", - "module_gemm_a8w8_blockscale_tune", - "module_gemm_a8w8_blockscale_bpreshuffle_tune", - "module_gemm_a4w4_blockscale_tune", - "module_gemm_a8w8_bpreshuffle_tune", - "module_gemm_a8w8_bpreshuffle_cktile_tune", - "module_gemm_mi350_a8w8_blockscale_asm", - ] - elif PREBUILD_KERNELS == 2: - return [ - "libmha_bwd", - "module_mha_batch_prefill", - "module_fmha_v3_bwd", - "module_fmha_v3_varlen_bwd", - "module_mha_bwd", - "module_mha_varlen_bwd", - "module_batched_gemm_bf16_tune", - "module_batched_gemm_a8w8_tune", - "module_gemm_a8w8_tune", - "module_gemm_a8w8_blockscale_tune", - "module_gemm_a8w8_blockscale_bpreshuffle_tune", - "module_gemm_a4w4_blockscale_tune", - "module_gemm_a8w8_bpreshuffle_tune", - "module_gemm_a8w8_bpreshuffle_cktile_tune", - "module_gemm_mi350_a8w8_blockscale_asm", - ] - elif PREBUILD_KERNELS == 3: - return [ - "module_activation", - "module_attention", - "module_pa_ragged", - "module_pa_v1", - "module_attention_asm", - "module_pa", - "module_mla_asm", - "module_cache", - "module_custom_all_reduce", - "module_quick_all_reduce", - "module_custom", - "module_gemm_common", - "module_batched_gemm_bf16", - "module_batched_gemm_a8w8", - "module_gemm_a8w8", - "module_gemm_a8w8_blockscale", - "module_gemm_a8w8_blockscale_bpreshuffle", - "module_gemm_a4w4_blockscale", - "module_gemm_a8w8_bpreshuffle", - "module_deepgemm", - "module_gemm_a8w8_bpreshuffle_cktile", - "module_gemm_a8w8_asm", - "module_gemm_a16w16_asm", - "module_gemm_a4w4_asm", - "module_gemm_a8w8_blockscale_asm", - "module_gemm_a8w8_blockscale_bpreshuffle_asm", - "module_gemm_mi350_a8w8_blockscale_asm", - "module_moe_asm", - "module_moe_ck2stages", - "module_moe_cktile2stages", - "module_moe_sorting", - "module_moe_topk", - "module_norm", - "module_pos_encoding", - "module_rmsnorm", - "module_smoothquant", - "module_batched_gemm_bf16_tune", - "module_batched_gemm_a8w8_tune", - "module_gemm_a8w8_tune", - "module_gemm_a8w8_blockscale_tune", - "module_gemm_a8w8_blockscale_bpreshuffle_tune", - "module_gemm_a4w4_blockscale_tune", - "module_gemm_a8w8_bpreshuffle_tune", - "module_gemm_a8w8_bpreshuffle_cktile_tune", - "module_aiter_operator", - "module_aiter_unary", - "module_quant", - "module_sample", - "module_rope_general_fwd", - "module_rope_general_bwd", - "module_rope_pos_fwd", - "module_fused_mrope_rms", - # "module_fmha_v3_fwd", - "module_mha_fwd", - "module_mha_varlen_fwd", - # "module_fmha_v3_bwd", - "module_fmha_v3_varlen_bwd", - "module_fmha_v3_varlen_fwd", - "module_mha_bwd", - "module_mha_varlen_bwd", - "libmha_fwd", - "libmha_bwd", - "module_rocsolgemm", - "module_hipbsolgemm", - "module_top_k_per_row", - "module_mla_metadata", - "module_mla_reduce", - "module_topk_plain", - ] - else: - return [ - "module_gemm_mi350_a8w8_blockscale_asm", - "module_batched_gemm_bf16_tune", - "module_batched_gemm_a8w8_tune", - "module_gemm_a8w8_tune", - "module_gemm_a8w8_blockscale_tune", - "module_gemm_a8w8_blockscale_bpreshuffle_tune", - "module_gemm_a4w4_blockscale_tune", - "module_gemm_a8w8_bpreshuffle_tune", - "module_gemm_a8w8_bpreshuffle_cktile_tune", - ] + all_modules = _load_modules_from_config() + exclude_ops = [] + + for module in all_modules: + if PREBUILD_KERNELS == 1: + # Exclude mha, _tune, and specific module + if ( + "mha" in module + or "_tune" in module + or module == "module_gemm_mi350_a8w8_blockscale_asm" + ): + 
+                exclude_ops.append(module)
+        elif PREBUILD_KERNELS == 2:
+            # Exclude _bwd, _tune, and specific module
+            if (
+                "_bwd" in module
+                or "_tune" in module
+                or module == "module_gemm_mi350_a8w8_blockscale_asm"
+            ):
+                exclude_ops.append(module)
+        elif PREBUILD_KERNELS == 3:
+            # Keep only module_fmha_v3*, module_aiter_enum, and the mi350 asm module
+            if not (
+                module.startswith("module_fmha_v3")
+                or module == "module_aiter_enum"
+                or module == "module_gemm_mi350_a8w8_blockscale_asm"
+            ):
+                exclude_ops.append(module)
+        else:
+            # Default behavior: exclude tunes and specific mi350 module
+            if (
+                "_tune" in module
+                or module == "module_gemm_mi350_a8w8_blockscale_asm"
+            ):
+                exclude_ops.append(module)
+
+    return exclude_ops
 
 exclude_ops = get_exclude_ops()
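
Note (not part of the patch): a small self-contained sketch of the PREBUILD_KERNELS
filtering rules introduced above, for a quick sanity check. The module names below are
a hand-picked sample taken from the old hard-coded lists; in setup.py the full list
comes from aiter/jit/optCompilerConfig.json via _load_modules_from_config(), and
exclude_for() here is only a local helper, not part of the patch.

    sample_modules = [
        "module_aiter_enum",
        "module_fmha_v3_fwd",
        "module_mha_varlen_bwd",
        "module_gemm_a8w8_tune",
        "module_gemm_mi350_a8w8_blockscale_asm",
        "module_moe_sorting",
    ]

    def exclude_for(prebuild_kernels, modules):
        # Reproduces the branch logic of the new get_exclude_ops().
        excluded = []
        for module in modules:
            if prebuild_kernels == 1:
                hit = ("mha" in module or "_tune" in module
                       or module == "module_gemm_mi350_a8w8_blockscale_asm")
            elif prebuild_kernels == 2:
                hit = ("_bwd" in module or "_tune" in module
                       or module == "module_gemm_mi350_a8w8_blockscale_asm")
            elif prebuild_kernels == 3:
                hit = not (module.startswith("module_fmha_v3")
                           or module in ("module_aiter_enum",
                                         "module_gemm_mi350_a8w8_blockscale_asm"))
            else:
                hit = ("_tune" in module
                       or module == "module_gemm_mi350_a8w8_blockscale_asm")
            if hit:
                excluded.append(module)
        return excluded

    for mode in (0, 1, 2, 3):
        print(mode, exclude_for(mode, sample_modules))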