Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 210 files
1 change: 1 addition & 0 deletions aiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def getLogger():
from .ops.gemm_op_a4w4 import *
from .ops.batched_gemm_op_a8w8 import *
from .ops.batched_gemm_op_bf16 import *
from .ops.deepgemm import *
from .ops.aiter_operator import *
from .ops.activation import *
from .ops.attention import *
Expand Down
17 changes: 17 additions & 0 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,23 @@
"is_standalone": "False",
"blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE}'"
},
"module_deepgemm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/deepgemm_pybind.cu'",
"f'{AITER_CSRC_DIR}/ck_deepgemm/deepgemm.cu'"

],
"flags_extra_cc": [],
"flags_extra_hip": [],
"md_name": "'module_deepgemm'",
"extra_ldflags": "None",
"extra_include": ["f'{CK_DIR}/example/ck_tile/18_flatmm'", "f'{AITER_CSRC_DIR}/ck_deepgemm/include'"],
"verbose": "False",
"is_python_module": "True",
"is_standalone": "False",
"hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')",
"blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_deepgemm/gen_instances.py --working_path {{}}'"
},
"module_gemm_a8w8_asm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_asm_pybind.cu'",
Expand Down
30 changes: 30 additions & 0 deletions aiter/ops/deepgemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

from torch import Tensor
from typing import Optional
from ..jit.core import (
compile_ops,
)


# Signature-only stub: the body is just `...`, so whatever runs when this is
# called is supplied by the `compile_ops` decorator, which is given the module
# name "module_deepgemm" and the function name "deepgemm".
# NOTE(review): presumably compile_ops builds/loads the JIT extension and binds
# this signature to its `deepgemm` symbol -- confirm against aiter/jit/core.
# x_scale and w_scale are optional and default to None.
@compile_ops("module_deepgemm", fc_name="deepgemm")
def deepgemm_ck(
    XQ: Tensor,
    WQ: Tensor,
    Y: Tensor,
    group_layout: Tensor,
    x_scale: Optional[Tensor] = None,
    w_scale: Optional[Tensor] = None,
) -> Tensor: ...


def deepgemm(
    XQ: Tensor,
    WQ: Tensor,
    Y: Tensor,
    group_layout: Tensor,
    x_scale: Optional[Tensor] = None,
    w_scale: Optional[Tensor] = None,
):
    """Public wrapper around ``deepgemm_ck``.

    Every argument is forwarded positionally, unchanged and in declaration
    order, and the value produced by ``deepgemm_ck`` is returned as-is.
    ``x_scale`` and ``w_scale`` may be omitted (they default to ``None``).
    """
    # Forward positionally on purpose: the callee is reached by position, not
    # by keyword.
    forwarded = (XQ, WQ, Y, group_layout, x_scale, w_scale)
    return deepgemm_ck(*forwarded)
149 changes: 149 additions & 0 deletions csrc/ck_deepgemm/deepgemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "deepgemm_common.cuh"
#include "deepgemm_lookup.h"
#include "deepgemm_manifest.h"
#include <cmath>
#include "py_itfs_common.h"

// Type-erased callable through which every deepgemm kernel instance is
// invoked: four tensors taken by non-const reference, followed by two optional
// tensors taken by value, returning a tensor. deepgemm() in this file calls it
// with (XQ, WQ, Y, grouped_layout, x_scale, w_scale).
using RowwiseKernel = std::function<
    torch::Tensor(torch::Tensor &, torch::Tensor &,
                  torch::Tensor &, torch::Tensor &,
                  std::optional<torch::Tensor>, std::optional<torch::Tensor>)>;

// Hash functor for std::tuple<int, int, int> keys (std::hash provides no
// specialization for tuples).
//
// The three component hashes are folded in order with a
// boost::hash_combine-style mix. A plain XOR of the components is
// order-insensitive and self-cancelling: a key would collide with every
// permutation of itself, and any key with two equal components would reduce to
// the hash of the remaining one, e.g. (1, 2, 2) and (1, 3, 3).
struct IntTupleHash
{
    // Returns the combined hash of the three tuple elements.
    size_t operator()(const std::tuple<int, int, int> &t) const
    {
        size_t seed = 0;
        mix(seed, std::hash<int>{}(std::get<0>(t)));
        mix(seed, std::hash<int>{}(std::get<1>(t)));
        mix(seed, std::hash<int>{}(std::get<2>(t)));
        return seed;
    }

private:
    // Folds `value` into `seed`. The shifted copies of the running seed make
    // the result depend on the order in which values are mixed in.
    static void mix(size_t &seed, size_t value)
    {
        seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
};

// For certain high priority shapes, we directly use the best kernel rather
// than use heuristics.
// This is the table type for that direct lookup: a three-int tuple key mapped
// to the kernel callable, hashed with IntTupleHash. The disabled lookup code
// in rowwise_dispatch() keys it by {M, N, K}; that code is currently commented
// out, so nothing in this file instantiates the map at present.
using RowwiseKernelMap = std::unordered_map<
    std::tuple<int, int, int>,
    RowwiseKernel,
    IntTupleHash>;

// Shape-based fallback: chooses a kernel instance from the problem size alone.
// Only M is consulted; N and K are accepted but not used.
//   M <  128 -> deepgemm_256x32x64x256_16x16x64_1x4
//   M >= 128 -> deepgemm_256x128x128x128_16x16x64_1x4
template <typename ABDataType, typename AccDataType, typename CDataType>
RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K)
{
    // M values at or above this threshold get the 128x128x128 block tile.
    constexpr int kLargeMThreshold = 128;

    if (M >= kLargeMThreshold)
    {
        return deepgemm_256x128x128x128_16x16x64_1x4<ABDataType, AccDataType, CDataType>;
    }
    return deepgemm_256x32x64x256_16x16x64_1x4<ABDataType, AccDataType, CDataType>;
}

// Rounds `num` up to a power of two (a value that already is a power of two is
// returned unchanged); 0 and 1 both map to 1.
// The result is an int, so powers of two above 2^30 are not representable;
// keep `num` at or below 2^30.
static constexpr int nextPow2(unsigned int num)
{
    if (num > 1)
    {
        // Width of `num` in bits, minus the leading zero bits of (num - 1),
        // is the number of bits needed to hold (num - 1); shifting 1 by that
        // amount yields the smallest power of two that is >= num.
        const auto total_bits = CHAR_BIT * sizeof(num);
        const auto leading_zeros = __builtin_clz(num - 1);
        return 1 << (total_bits - leading_zeros);
    }
    return 1;
}

// Selects the kernel instance to run for an (M, N, K) problem.
//
// Only the heuristic path is active right now: the call is forwarded straight
// to rowwise_heuristic_dispatch(). The tuned-table path is kept inside the
// body, commented out, as the reference for the tuner work. It looks up the
// exact {M, N, K} first and then retries with M rounded up to a bucket
// (16, the next power of two, or 20480) before falling back to heuristics.
template <typename ABDataType, typename AccDataType, typename CDataType>
RowwiseKernel rowwise_dispatch(int M, int N, int K)
{
    // TODO: add tuner @lalala-sh
    // ---- disabled: direct lookup of tuned kernels --------------------------
    // For a given shape, either find the best kernel via lookup or heuristic.
    // For many small M shapes, we bucket them to the next largest kernel.
    // This is fine since kernels are padded anyway.

    // static const auto lookup = [&]
    // {
    //   return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(ABDataType, AccDataType, CDataType)};
    // }();

    // // First check if this shape(M,N,K) is available in the direct lookup.
    // auto it = lookup.find({M, N, K});
    // // If we found an optimal kernel, use it.
    // if (it != lookup.end())
    // {
    //   return it->second;
    // }

    // int padded_m = M;
    // if (M > 1 && M <= 16)
    // {
    //   padded_m = 16;
    // }
    // else if (M <= 16384)
    // {
    //   padded_m = nextPow2(M);
    // }
    // else if (M <= 20480)
    // {
    //   padded_m = 20480;
    // }
    // // Second check if this shape(padded_m,N,K) is available in the direct lookup.
    // it = lookup.find({padded_m, N, K});
    // // If we found an optimal kernel, use it.
    // if (it != lookup.end())
    // {
    //   return it->second;
    // }
    // ---- end of disabled lookup --------------------------------------------

    // Active path: shape heuristics only.
    RowwiseKernel selected =
        rowwise_heuristic_dispatch<ABDataType, AccDataType, CDataType>(M, N, K);
    return selected;
}

// Entry point of the deepgemm op: validates dtypes, picks a kernel instance
// from the problem size and input/output dtypes, runs it, and returns Y.
//
// Parameters:
//   XQ, WQ          input tensors; they must share a dtype (fp16, bf16 or fp8).
//   Y               output tensor, handed to the kernel and returned to the
//                   caller. For fp8 inputs its dtype (fp16 or bf16) selects the
//                   kernel's output type.
//   grouped_layout  forwarded untouched to the kernel.
//   x_scale/w_scale optional scale tensors, forwarded untouched to the kernel;
//                   when both are present they must share a dtype.
//
// The problem size is read as M = XQ.size(0), N = WQ.size(0), K = XQ.size(1).
// NOTE(review): presumably XQ is (M, K) and WQ is (N, K) -- confirm against
// the kernel implementation.
//
// Raises (via TORCH_CHECK) when the input dtypes differ, when the scale dtypes
// differ, when the input dtype is unsupported, or when fp8 inputs are paired
// with an output that is neither fp16 nor bf16.
torch::Tensor deepgemm(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &Y,
    torch::Tensor &grouped_layout,
    std::optional<torch::Tensor> x_scale,
    std::optional<torch::Tensor> w_scale)
{
    TORCH_CHECK(XQ.dtype() == WQ.dtype(),
                "Weights and activations should have the same dtype!");
    if (x_scale.has_value() && w_scale.has_value())
    {
        TORCH_CHECK(x_scale.value().dtype() == w_scale.value().dtype(),
                    "Scales should have the same dtype!");
    }

    const int M = XQ.size(0);
    const int N = WQ.size(0);
    const int K = XQ.size(1);

    // The kernel's own return value is ignored in every branch; Y is what is
    // handed back to the caller.
    if (XQ.dtype() == at::ScalarType::Half)
    {
        rowwise_dispatch<fp16, float, fp16>(M, N, K)(XQ, WQ, Y, grouped_layout, x_scale, w_scale);
    }
    else if (XQ.dtype() == at::ScalarType::BFloat16)
    {
        rowwise_dispatch<bf16, float, bf16>(M, N, K)(XQ, WQ, Y, grouped_layout, x_scale, w_scale);
    }
    else if (XQ.dtype() == torch_fp8)
    {
        // For fp8 inputs the output dtype decides the kernel's CDataType.
        if (Y.dtype() == at::ScalarType::Half)
        {
            rowwise_dispatch<fp8, float, fp16>(M, N, K)(XQ, WQ, Y, grouped_layout, x_scale, w_scale);
        }
        else if (Y.dtype() == at::ScalarType::BFloat16)
        {
            rowwise_dispatch<fp8, float, bf16>(M, N, K)(XQ, WQ, Y, grouped_layout, x_scale, w_scale);
        }
        else
        {
            // Without this branch the function would return Y untouched and
            // the caller would never learn that nothing was computed.
            TORCH_CHECK(false, "Unsupported output dtype for fp8 inputs: expected fp16 or bf16!");
        }
    }
    else
    {
        TORCH_CHECK(false, "Unsupported input dtype: expected fp16, bf16 or fp8!");
    }
    return Y;
}
64 changes: 64 additions & 0 deletions csrc/ck_deepgemm/deepgemm_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import dataclass


@dataclass
class kernelInstance:
    """Tile configuration of one deepgemm kernel instance.

    The fields are plain integers given positionally in the order declared
    below. ``name`` renders them as the instance identifier, e.g.
    ``deepgemm_256x128x128x128_16x16x64_1x4``.
    """

    BLOCK_SIZE: int
    # GroupCount: int
    MPerBLOCK: int
    NPerBLOCK: int
    KPerBLOCK: int
    WAVE_TILE_M: int
    WAVE_TILE_N: int
    WAVE_TILE_K: int
    WAVE_MAP_M: int
    WAVE_MAP_N: int

    @property
    def name(self) -> str:
        """Return ``deepgemm_<block>_<wave tile>_<wave map>``.

        Each group joins its integer fields with ``x``:
        block is BLOCK_SIZE x MPerBLOCK x NPerBLOCK x KPerBLOCK, wave tile is
        WAVE_TILE_M x WAVE_TILE_N x WAVE_TILE_K, wave map is
        WAVE_MAP_M x WAVE_MAP_N.
        """
        groups = (
            (self.BLOCK_SIZE, self.MPerBLOCK, self.NPerBLOCK, self.KPerBLOCK),
            (self.WAVE_TILE_M, self.WAVE_TILE_N, self.WAVE_TILE_K),
            (self.WAVE_MAP_M, self.WAVE_MAP_N),
        )
        parts = ["deepgemm"]
        for group in groups:
            parts.append("x".join(str(value) for value in group))
        return "_".join(parts)


# fmt: off
# Candidate kernel instances, keyed by a positive integer id (1..4).
# NOTE(review): presumably this is the pool the instance generator / tuner
# picks from -- confirm against gen_instances.py.
kernels_list = {
    # Columns follow the kernelInstance field order:
    # id: kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N
    1: kernelInstance( 256, 128, 128, 128, 16, 16, 64, 1, 4),
    2: kernelInstance( 256, 128, 128, 128, 16, 16, 32, 1, 4),
    3: kernelInstance( 256, 32, 64, 256, 16, 16, 64, 1, 4),
    4: kernelInstance( 256, 32, 64, 256, 16, 16, 32, 1, 4),
}


# Default kernel instances, keyed by a negative integer id (-1..-4). The four
# entries carry the same configurations as kernels_list entries 1..4.
# NOTE(review): presumably the negative ids keep defaults distinct from
# shape-specific entries -- confirm against gen_instances.py.
default_kernels_dict = {
    # Columns follow the kernelInstance field order:
    # id: kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N
    (-1): kernelInstance( 256, 128, 128, 128, 16, 16, 64, 1, 4),
    (-2): kernelInstance( 256, 128, 128, 128, 16, 16, 32, 1, 4),
    (-3): kernelInstance( 256, 32, 64, 256, 16, 16, 64, 1, 4),
    (-4): kernelInstance( 256, 32, 64, 256, 16, 16, 32, 1, 4),

}
# fmt: on
Loading