diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a5713753..0e7f22ba4df 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -146,22 +146,48 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; }; - if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous - is_contiguous_2d(op->src[1]) && // src1 must be contiguous - op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && - op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) - op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x - (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { - // src1 must be host buffer - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + if (op->op != GGML_OP_MUL_MAT) { + return false; + } + if (!is_contiguous_2d(op->src[0]) || !is_contiguous_2d(op->src[1])) { + return false; + } + if (!op->src[0]->buffer || op->src[0]->buffer->buft != ggml_backend_amx_buffer_type()) { + return false; + } + if (op->ne[0] % (TILE_N * 2) != 0) { + return false; + } + + int alignment; + switch (op->src[0]->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + alignment = TILE_K; + break; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + alignment = 256; // QK_K + break; + case GGML_TYPE_F16: + alignment = 16; + break; + default: return false; - } - // src1 must be float32 - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } } - return false; + if (op->src[0]->ne[0] % alignment != 0) { + return false; + } + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type != GGML_TYPE_F32) { + return false; + } + return true; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 47c61b88164..b8e4781401c 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2079,7 +2079,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); if (need_unpack) { - unpack_B(Tile1, B_blk0); + unpack_B(Tile1, B_blk1); _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); } else { _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);