Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions ggml/src/ggml-cpu/amx/amx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,48 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};

if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
// src1 must be host buffer
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
if (op->op != GGML_OP_MUL_MAT) {
return false;
}
if (!is_contiguous_2d(op->src[0]) || !is_contiguous_2d(op->src[1])) {
return false;
}
if (!op->src[0]->buffer || op->src[0]->buffer->buft != ggml_backend_amx_buffer_type()) {
return false;
}
if (op->ne[0] % (TILE_N * 2) != 0) {
return false;
}

int alignment;
switch (op->src[0]->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
alignment = TILE_K;
break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_XS:
alignment = 256; // QK_K
break;
case GGML_TYPE_F16:
alignment = 16;
break;
default:
return false;
}
// src1 must be float32
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
}
return false;
if (op->src[0]->ne[0] % alignment != 0) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[1]->type != GGML_TYPE_F32) {
return false;
}
return true;
}

ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cpu/amx/mmq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2079,7 +2079,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));

if (need_unpack) {
unpack_B<TB>(Tile1, B_blk0);
unpack_B<TB>(Tile1, B_blk1);
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
} else {
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
Expand Down
Loading