15 changes: 14 additions & 1 deletion aiter/jit/utils/cpp_extension.py
@@ -1534,7 +1534,20 @@ def _write_ninja_file_to_build_library(
    extra_ldflags = [flag.strip() for flag in extra_ldflags]
    extra_include_paths = [flag.strip() for flag in extra_include_paths]
    # include_paths() gives us the location of torch/extension.h
    system_includes = [] if torch_exclude else include_paths(with_cuda)
    # system_includes = [] if torch_exclude else include_paths(with_cuda)
    import torch

    _TORCH_PATH = os.path.dirname(torch.__file__)
    TORCH_INCLUDE_ROOT = os.path.join(_TORCH_PATH, "include")
    system_includes = [
        TORCH_INCLUDE_ROOT,
        os.path.join(TORCH_INCLUDE_ROOT, "torch/csrc/api/include"),
        os.path.join(TORCH_INCLUDE_ROOT, "TH"),
        os.path.join(TORCH_INCLUDE_ROOT, "THC"),
    ]
    if not torch_exclude:
        system_includes += include_paths(with_cuda)
    system_includes = list(set(system_includes))

    # FIXME: build the Python module without torch, using `pybind11` directly.
    # We can't do this yet because every aiter op is based on torch.
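As a quick way to see what the hardcoded list resolves to, the same paths can be rebuilt and checked in isolation (a minimal sketch assuming a local PyTorch install; it is not part of this PR):

# Sketch: rebuild the hardcoded torch include list and report which directories exist.
# Assumes torch is importable; actual paths depend on the installation.
import os
import torch

torch_include_root = os.path.join(os.path.dirname(torch.__file__), "include")
candidates = [
    torch_include_root,
    os.path.join(torch_include_root, "torch/csrc/api/include"),
    os.path.join(torch_include_root, "TH"),
    os.path.join(torch_include_root, "THC"),
]
for path in candidates:
    print(path, "->", "exists" if os.path.isdir(path) else "missing")

Note that the patch's `list(set(...))` removes duplicates but does not preserve the original ordering of the include directories.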
135 changes: 109 additions & 26 deletions csrc/include/rocm_ops.hpp
@@ -3,6 +3,7 @@
#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

#define ACTIVATION_PYBIND \
@@ -658,6 +659,64 @@ namespace py = pybind11;
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define ROCSOLGEMM_PYBIND \
m.def("rocb_create_extension", &rocb_create_extension, "create_extension"); \
m.def("rocb_destroy_extension", &rocb_destroy_extension, "destroy_extension"); \
m.def("rocb_mm", &RocSolIdxBlas, "mm"); \
m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");

#define HIPBSOLGEMM_PYBIND \
m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); \
m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); \
m.def("hipb_mm", \
&hipb_mm, \
"hipb_mm", \
py::arg("mat1"), \
py::arg("mat2"), \
py::arg("solution_index"), \
py::arg("bias") = std::nullopt, \
py::arg("out_dtype") = std::nullopt, \
py::arg("scaleA") = std::nullopt, \
py::arg("scaleB") = std::nullopt, \
py::arg("scaleOut") = std::nullopt, \
py::arg("bpreshuffle") = std::nullopt); \
m.def("hipb_findallsols", \
&hipb_findallsols, \
"hipb_findallsols", \
py::arg("mat1"), \
py::arg("mat2"), \
py::arg("bias") = std::nullopt, \
py::arg("out_dtype") = std::nullopt, \
py::arg("scaleA") = std::nullopt, \
py::arg("scaleB") = std::nullopt, \
py::arg("scaleC") = std::nullopt, \
py::arg("bpreshuffle") = false); \
m.def("getHipblasltKernelName", &getHipblasltKernelName);

#define LIBMHA_BWD_PYBIND \
m.def("libmha_bwd", \
&aiter::torch_itfs::mha_bwd, \
py::arg("dout"), \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("out"), \
py::arg("softmax_lse"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("deterministic"), \
py::arg("dq") = std::nullopt, \
py::arg("dk") = std::nullopt, \
py::arg("dv") = std::nullopt, \
py::arg("dbias") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MHA_VARLEN_BWD_ASM_PYBIND \
m.def("fmha_v3_varlen_bwd", \
&aiter::torch_itfs::fmha_v3_varlen_bwd, \
@@ -756,32 +815,56 @@ namespace py = pybind11;
py::arg("v_descale") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MHA_VARLEN_FWD_ASM_PYBIND \
m.def("fmha_v3_varlen_fwd", \
&aiter::torch_itfs::fmha_v3_varlen_fwd, \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("min_seqlen_q"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("logits_soft_cap"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("how_v3_bf16_cvt"), \
py::arg("out") = std::nullopt, \
py::arg("block_table") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("gen") = std::nullopt, \
#define LIBMHA_FWD_PYBIND \
m.def("libmha_fwd", \
&aiter::torch_itfs::mha_fwd, \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("sink_size"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("cu_seqlens_q") = std::nullopt, \
py::arg("cu_seqlens_kv") = std::nullopt, \
py::arg("out") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("q_descale") = std::nullopt, \
py::arg("k_descale") = std::nullopt, \
py::arg("v_descale") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MHA_VARLEN_FWD_ASM_PYBIND \
m.def("fmha_v3_varlen_fwd", \
&aiter::torch_itfs::fmha_v3_varlen_fwd, \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("min_seqlen_q"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("logits_soft_cap"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("how_v3_bf16_cvt"), \
py::arg("out") = std::nullopt, \
py::arg("block_table") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("gen") = std::nullopt, \
py::arg("cu_seqlens_q_padded") = std::nullopt, \
py::arg("cu_seqlens_k_padded") = std::nullopt);

173 changes: 49 additions & 124 deletions setup.py
@@ -4,6 +4,7 @@
import os
import shutil
import sys
import json

from setuptools import Distribution, setup

@@ -81,131 +82,55 @@ def is_develop_mode():
shutil.copytree("gradlib", "aiter_meta/gradlib")
shutil.copytree("csrc", "aiter_meta/csrc")

def _load_modules_from_config():
    cfg_path = os.path.join(this_dir, "aiter", "jit", "optCompilerConfig.json")
    try:
        with open(cfg_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        return []
    if isinstance(data, dict):
        return list(data.keys())
    return []
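This helper only consumes the top-level keys, so the expectation is that `optCompilerConfig.json` is a JSON object keyed by module name; a small illustrative sketch (contents invented, not copied from the repo):

# Hedged sketch of the expected config shape: the keys become the module list,
# the values (whatever they contain) are ignored by _load_modules_from_config.
import json

example = json.loads("""
{
  "module_aiter_enum": {},
  "module_mha_fwd": {},
  "module_gemm_a8w8_tune": {}
}
""")
print(list(example.keys()))  # ['module_aiter_enum', 'module_mha_fwd', 'module_gemm_a8w8_tune']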

def get_exclude_ops():
    if PREBUILD_KERNELS == 1:
        return [
            "libmha_fwd",
            "libmha_bwd",
            "module_fmha_v3_fwd",
            "module_mha_fwd",
            "module_mha_varlen_fwd",
            "module_mha_batch_prefill",
            "module_fmha_v3_bwd",
            "module_fmha_v3_varlen_bwd",
            "module_fmha_v3_varlen_fwd",
            "module_mha_bwd",
            "module_mha_varlen_bwd",
            "module_batched_gemm_bf16_tune",
            "module_batched_gemm_a8w8_tune",
            "module_gemm_a8w8_tune",
            "module_gemm_a8w8_blockscale_tune",
            "module_gemm_a8w8_blockscale_bpreshuffle_tune",
            "module_gemm_a4w4_blockscale_tune",
            "module_gemm_a8w8_bpreshuffle_tune",
            "module_gemm_a8w8_bpreshuffle_cktile_tune",
            "module_gemm_mi350_a8w8_blockscale_asm",
        ]
    elif PREBUILD_KERNELS == 2:
        return [
            "libmha_bwd",
            "module_mha_batch_prefill",
            "module_fmha_v3_bwd",
            "module_fmha_v3_varlen_bwd",
            "module_mha_bwd",
            "module_mha_varlen_bwd",
            "module_batched_gemm_bf16_tune",
            "module_batched_gemm_a8w8_tune",
            "module_gemm_a8w8_tune",
            "module_gemm_a8w8_blockscale_tune",
            "module_gemm_a8w8_blockscale_bpreshuffle_tune",
            "module_gemm_a4w4_blockscale_tune",
            "module_gemm_a8w8_bpreshuffle_tune",
            "module_gemm_a8w8_bpreshuffle_cktile_tune",
            "module_gemm_mi350_a8w8_blockscale_asm",
        ]
    elif PREBUILD_KERNELS == 3:
        return [
            "module_activation",
            "module_attention",
            "module_pa_ragged",
            "module_pa_v1",
            "module_attention_asm",
            "module_pa",
            "module_mla_asm",
            "module_cache",
            "module_custom_all_reduce",
            "module_quick_all_reduce",
            "module_custom",
            "module_gemm_common",
            "module_batched_gemm_bf16",
            "module_batched_gemm_a8w8",
            "module_gemm_a8w8",
            "module_gemm_a8w8_blockscale",
            "module_gemm_a8w8_blockscale_bpreshuffle",
            "module_gemm_a4w4_blockscale",
            "module_gemm_a8w8_bpreshuffle",
            "module_deepgemm",
            "module_gemm_a8w8_bpreshuffle_cktile",
            "module_gemm_a8w8_asm",
            "module_gemm_a16w16_asm",
            "module_gemm_a4w4_asm",
            "module_gemm_a8w8_blockscale_asm",
            "module_gemm_a8w8_blockscale_bpreshuffle_asm",
            "module_gemm_mi350_a8w8_blockscale_asm",
            "module_moe_asm",
            "module_moe_ck2stages",
            "module_moe_cktile2stages",
            "module_moe_sorting",
            "module_moe_topk",
            "module_norm",
            "module_pos_encoding",
            "module_rmsnorm",
            "module_smoothquant",
            "module_batched_gemm_bf16_tune",
            "module_batched_gemm_a8w8_tune",
            "module_gemm_a8w8_tune",
            "module_gemm_a8w8_blockscale_tune",
            "module_gemm_a8w8_blockscale_bpreshuffle_tune",
            "module_gemm_a4w4_blockscale_tune",
            "module_gemm_a8w8_bpreshuffle_tune",
            "module_gemm_a8w8_bpreshuffle_cktile_tune",
            "module_aiter_operator",
            "module_aiter_unary",
            "module_quant",
            "module_sample",
            "module_rope_general_fwd",
            "module_rope_general_bwd",
            "module_rope_pos_fwd",
            "module_fused_mrope_rms",
            # "module_fmha_v3_fwd",
            "module_mha_fwd",
            "module_mha_varlen_fwd",
            # "module_fmha_v3_bwd",
            "module_fmha_v3_varlen_bwd",
            "module_fmha_v3_varlen_fwd",
            "module_mha_bwd",
            "module_mha_varlen_bwd",
            "libmha_fwd",
            "libmha_bwd",
            "module_rocsolgemm",
            "module_hipbsolgemm",
            "module_top_k_per_row",
            "module_mla_metadata",
            "module_mla_reduce",
            "module_topk_plain",
        ]
    else:
        return [
            "module_gemm_mi350_a8w8_blockscale_asm",
            "module_batched_gemm_bf16_tune",
            "module_batched_gemm_a8w8_tune",
            "module_gemm_a8w8_tune",
            "module_gemm_a8w8_blockscale_tune",
            "module_gemm_a8w8_blockscale_bpreshuffle_tune",
            "module_gemm_a4w4_blockscale_tune",
            "module_gemm_a8w8_bpreshuffle_tune",
            "module_gemm_a8w8_bpreshuffle_cktile_tune",
        ]
    all_modules = _load_modules_from_config()
    exclude_ops = []

    for module in all_modules:
        if PREBUILD_KERNELS == 1:
            # Exclude mha, _tune, and specific module
            if (
                "mha" in module
                or "_tune" in module
                or module == "module_gemm_mi350_a8w8_blockscale_asm"
            ):
                exclude_ops.append(module)
        elif PREBUILD_KERNELS == 2:
            # Exclude _bwd, _tune, and specific module
            if (
                "_bwd" in module
                or "_tune" in module
                or module == "module_gemm_mi350_a8w8_blockscale_asm"
            ):
                exclude_ops.append(module)
        elif PREBUILD_KERNELS == 3:
            # Keep only module_fmha_v3* and module_aiter_enum
            if not (
                module.startswith("module_fmha_v3")
                or module == "module_aiter_enum"
                or module == "module_gemm_mi350_a8w8_blockscale_asm"
            ):
                exclude_ops.append(module)
        else:
            # Default behavior: exclude tunes and specific mi350 module
            if (
                "_tune" in module
                or module == "module_gemm_mi350_a8w8_blockscale_asm"
            ):
                exclude_ops.append(module)

    return exclude_ops
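The loop replaces the hand-maintained per-mode lists with pattern rules; a standalone sketch of how those rules classify a few sample names (the helper and samples are illustrative only, not part of this PR):

# Hedged sketch: the same PREBUILD_KERNELS filtering rules applied to sample module names.
def excluded(module, prebuild_kernels):
    special = "module_gemm_mi350_a8w8_blockscale_asm"
    if prebuild_kernels == 1:
        return "mha" in module or "_tune" in module or module == special
    if prebuild_kernels == 2:
        return "_bwd" in module or "_tune" in module or module == special
    if prebuild_kernels == 3:
        return not (
            module.startswith("module_fmha_v3")
            or module in ("module_aiter_enum", special)
        )
    return "_tune" in module or module == special

for name in ["module_mha_bwd", "module_gemm_a8w8_tune", "module_fmha_v3_fwd", "module_norm"]:
    print(name, [excluded(name, mode) for mode in (0, 1, 2, 3)])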

exclude_ops = get_exclude_ops()
