4 changes: 2 additions & 2 deletions ggml/src/ggml-cpu/common.h
@@ -6,8 +6,8 @@
#include "ggml-impl.h"
#include "simd-mappings.h"

#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q 64
#define GGML_FA_TILE_KV 64
Contributor
Did this not need to be adjusted with the GEMM_RM/GEMM_RN sizes?

But I don't know whether it is related to GEMM_RM or GEMM_RN (or both?).

Contributor Author
It should be adjusted according to the ISA available; I just tuned on Zen 2 (AMD Rome) since that's the only one I have available. I think the scratch space should be close to the L1 cache size, so maybe that's one factor to tune this by.
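As a rough aid for that tuning, here is a minimal sketch (not part of the PR) of the per-thread scratch footprint implied by the tile sizes, mirroring the wdata formula in ggml_graph_plan below; the head sizes DK/DV are example values, not taken from this PR:

#include <cstdio>

// Per-thread FA scratch in bytes: Q_q + KQ + mask + VKQ32 + V32 + K_f32
static size_t fa_scratch_bytes(size_t Q, size_t KV, size_t DK, size_t DV) {
    return sizeof(float) * (Q*DK + 2*Q*KV + Q*DV + KV*DV + KV*DK);
}

int main() {
    // Example: 64x64 tiles with 128-dim heads -> 163840 bytes (160 KiB) per thread
    printf("%zu KiB per thread\n", fa_scratch_bytes(64, 64, 128, 128) / 1024);
    return 0;
}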

Contributor
Yes, but I think there is something else: it is blocks of blocks, so

GGML_FA_TILE_Q   = 16*GEMM_RM;  // or 16*GEMM_RN 
GGML_FA_TILE_KV = 16*GEMM_RM;  // or 16*GEMM_RN

so that we use at most the most efficient µGEMM block size?

Note: I did not take the time to check whether GGML_FA_TILE_Q and GGML_FA_TILE_KV relate to GEMM_RM or GEMM_RN.

And yes for the L1/L2/L3 cache sizes, but to do best we would need to block on K, not only on NxM.
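As a sketch of that K-blocking idea (purely illustrative; simd_gemm_strided is a hypothetical variant with explicit leading dimensions, not something in this PR): since the kernel accumulates C += A*B, K can be split into cache-sized panels and applied one panel at a time:

#include <algorithm>
#include <cstdint>

// Hypothetical variant of simd_gemm that takes explicit row strides (lda/ldb),
// so a K-slice of A and B can be passed without copying.
void simd_gemm_strided(float * C, const float * A, const float * B,
                       int64_t M, int64_t K, int64_t N, int64_t lda, int64_t ldb);

// C[MxN] += A[MxK] * B[KxN], processing K in panels of at most KC columns
static void simd_gemm_blocked_k(float * C, const float * A, const float * B,
                                int64_t M, int64_t K, int64_t N, int64_t KC) {
    for (int64_t k0 = 0; k0 < K; k0 += KC) {
        const int64_t kc = std::min(KC, K - k0);
        // A panel starts at column k0 (row stride K), B panel at row k0 (row stride N);
        // a real implementation would also pack these panels contiguously
        simd_gemm_strided(C, A + k0, B + k0 * N, M, kc, N, /*lda=*/K, /*ldb=*/N);
    }
}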

Contributor Author
@am17an (Feb 15, 2026)

Right, I guess for AVX2 it works out because GEMM_RM = 4 and 4*16 = 64. Let me try to tune it according to this and see the difference.

Contributor

What bench command did you use?

Contributor Author

-m llama_8b_q4_0.gguf,gpt-oss-20b.mxfp4 -fa 1 -d 0,8196,16348 -t 16,32,64 -n 0 -p 512

Contributor

So no -ctk f32 -ctv f32 is needed?

Contributor Author

No, because it converts to f32 from f16.


#ifdef __cplusplus

4 changes: 2 additions & 2 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
const int64_t DV = node->src[2]->ne[0];

// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;

// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
121 changes: 68 additions & 53 deletions ggml/src/ggml-cpu/ops.cpp
@@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;

const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);

// broadcast factors
const int64_t rk2 = neq2/nek2;
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;

GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");

int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@@ -8452,18 +8447,20 @@
}

// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);

void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
float * V32 = VKQ32 + Q_TILE_SZ * DV;
float * K_f32 = V32 + KV_TILE_SZ * DV;

memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8476,42 +8473,71 @@
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
{
float * Q_f32 = (float *)Q_q;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
}
}

memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));

for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);

// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
for (int tk = 0; tk < kv_tile; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
// Pad remaining mask entries with -inf
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}

if (can_skip) {
continue;
}
}

for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);

// Set padded KQ entries to -inf so softmax gives them zero weight
if (kv_tile < KV_TILE_SZ) {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
}

@@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}

// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
// V accumulation: VKQ32 += softmax(KQ) * V
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) {
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
}
}
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
}

// sinks (apply only to valid rows in the tile)
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(

const int64_t dr = (nr + nchunk - 1) / nchunk;

static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);

#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
int current_chunk = ith;

while (current_chunk < nchunk) {
139 changes: 139 additions & 0 deletions ggml/src/ggml-cpu/simd-gemm.h
@@ -0,0 +1,139 @@
#pragma once

// Computes C[M x N] += A[M x K] * B[K x N]

#include "ggml-cpu-impl.h"
#include "vec.h"
#include "common.h"
am17an marked this conversation as resolved.
#include "simd-mappings.h"

// TODO: add support for sizeless vector types
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)

// TODO: untested on avx512
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
Contributor

You can try:

static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 4;   // 24 + 4 + 1 = 29 ...

That is what AMD uses.
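(For context, these X+Y+1 = Z/32 comments appear to count the vector registers used by the micro-kernel below: acc[RM][RN] needs RM*RN accumulators, Bv[RN] needs RN registers for the B tile, and broadcasting one A element needs one more. So 4x4 gives 16+4+1 = 25 of 32 registers on AVX-512/NEON, the suggested 6x4 gives 24+4+1 = 29/32, and the AVX2 choice of 6x2 gives 12+2+1 = 15 of 16 ymm registers.)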

Contributor Author

It was this before #19422 (comment), but it was changed for ARM_NEON; I will create a separate branch for AVX512.

Contributor

#if defined(__AVX512F__)
    static constexpr int GEMM_RM = 4;
    static constexpr int GEMM_RN = 6; // 24+4+2 = 30/32  (+2 for pre-load)
#elif defined (__ARM_NEON__)
    static constexpr int GEMM_RM = 4;
    static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32

#elif defined(__AVX2__) || defined(__AVX__)
Comment on lines +15 to +18
Member

On M2 Ultra, I get better results using:

# GGML_F32_EPR = 4

GEMM_RM = 4
GEMM_RN = 4

Here is before and after:

• 6x4

Model                   Test           t/s master   t/s pr/19422   Speedup
gpt-oss 20B MXFP4 MoE   pp512@d1024    166.40       177.12         1.06
gpt-oss 20B MXFP4 MoE   pp512@d4096    134.16       155.12         1.16
qwen2 1.5B F16          pp512@d1024    298.68       307.97         1.03
qwen2 1.5B F16          pp512@d4096    226.77       247.86         1.09
qwen2 3B F16            pp512@d1024    143.64       151.70         1.06
qwen2 3B F16            pp512@d4096    112.81       131.39         1.16

• 4x4

Model                   Test           t/s master   t/s pr/19422   Speedup
gpt-oss 20B MXFP4 MoE   pp512@d1024    167.38       178.00         1.06
gpt-oss 20B MXFP4 MoE   pp512@d4096    134.28       160.09         1.19
qwen2 1.5B F16          pp512@d1024    299.04       313.51         1.05
qwen2 1.5B F16          pp512@d4096    226.24       259.23         1.15
qwen2 3B F16            pp512@d1024    144.24       152.88         1.06
qwen2 3B F16            pp512@d4096    113.59       135.72         1.19

Contributor Author

Thanks for testing. Can you try one with a very high depth like 16384? That's where it would be really clear.

Contributor Author

Also, for me, after this PR FA=1 is always faster than FA=0, at least for PP. For TG the results are better with more threads. I guess our GEMV implementation can be improved.

Member

Here is a comparison of 6x4 (current) vs 4x4 (new):

Model                   Test            t/s 8debab3   t/s dce1b0911   Speedup
gpt-oss 20B MXFP4 MoE   pp512@d1024     176.48        178.33          1.01
gpt-oss 20B MXFP4 MoE   pp512@d4096     156.05        159.91          1.02
gpt-oss 20B MXFP4 MoE   pp512@d16384    106.38        113.16          1.06
qwen2 1.5B F16          pp512@d1024     307.25        313.73          1.02
qwen2 1.5B F16          pp512@d4096     247.65        258.67          1.04
qwen2 1.5B F16          pp512@d16384    136.64        153.64          1.12
qwen2 3B F16            pp512@d1024     151.11        152.89          1.01
qwen2 3B F16            pp512@d4096     130.56        133.45          1.02
qwen2 3B F16            pp512@d16384    85.33         91.88           1.08

I.e. it is also better at 16k context.

static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
#endif
Comment on lines +26 to +29
Member

I noticed the compile warnings in the CI. Are we confident these are false-positives and safe to ignore?

Contributor Author

Yes, it was complaining about the iteration overflowing when ii is 2^58.


template <int RM, int RN>
static inline void simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t K, int64_t N,
int64_t ii, int64_t jj)
{
static constexpr int KN = GGML_F32_EPR;

GGML_F32_VEC acc[RM][RN];
for (int i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN);
}
}

for (int64_t kk = 0; kk < K; kk++) {
GGML_F32_VEC Bv[RN];
for (int r = 0; r < RN; r++) {
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
}
for (int i = 0; i < RM; i++) {
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
Contributor
@Djip007 (Feb 15, 2026)

This is nice for x86, doing it like that.

But on ARM (NEON) I remember there is an op:
regC += regB * regA[i] for FMA
so it is possible to load a full register for A.
For FP16, with 32 registers on NEON, you can have:

  • RN = 1
  • RM = 16 // => 1 vector load

For FP32 it may be 1x8 / 2x8 / 1x16.

[edit] But it may need some transpose on A for that.
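A minimal sketch of that lane-broadcast idea on NEON (illustrative only; it assumes four consecutive A values for the same k are stored contiguously, i.e. the transposed-A layout mentioned above):

#include <arm_neon.h>

// acc[i] += b * av[i]: one vld1q_f32 of A feeds four FMAs,
// replacing four scalar broadcasts (GGML_F32_VEC_SET1) with a single load.
static inline void fma_by_lane(float32x4_t acc[4], float32x4_t b, const float * a_col) {
    const float32x4_t av = vld1q_f32(a_col); // A[(ii+0..3)*K + kk], pre-transposed
    acc[0] = vfmaq_laneq_f32(acc[0], b, av, 0);
    acc[1] = vfmaq_laneq_f32(acc[1], b, av, 1);
    acc[2] = vfmaq_laneq_f32(acc[2], b, av, 2);
    acc[3] = vfmaq_laneq_f32(acc[3], b, av, 3);
}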

for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
}
}
}

for (int i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]);
}
}
}

// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t M, int64_t K, int64_t N)
{
static constexpr int KN = GGML_F32_EPR;

int64_t ii = 0;
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C, A, B, K, N, ii, jj);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<GEMM_RM, 1>(C, A, B, K, N, ii, jj);
}
for (; jj < N; jj++) {
for (int i = 0; i < GEMM_RM; i++) {
float a = C[(ii + i) * N + jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[(ii + i) * K + kk] * B[kk * N + jj];
}
C[(ii + i) * N + jj] = a;
}
}
}

// Tail rows: one at a time
for (; ii < M; ii++) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
}
Contributor

For some gain (or not...):
tinyblas uses an 0xMN encoding to switch over "all" possible MxN tiles,
so we could use some <2,GEMM_RN> ...

int64_t mc, nc, mp, np;

But yes, it needs more dev / time / ...
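A rough sketch of that 0xMN dispatch (the helper name and the set of cases shown are illustrative, not from this PR): pick the largest tile that fits the remainder and switch on the packed (m << 4) | n code:

#include <algorithm>

// Dispatch the templated micro-kernel on the remaining rows / column-vectors.
// Only a few RMxRN combinations are shown; a full version enumerates them all.
static void simd_gemm_dispatch(float * C, const float * A, const float * B,
                               int64_t K, int64_t N, int64_t ii, int64_t jj,
                               int64_t rows_left, int64_t vecs_left) {
    const int mc = (int) std::min<int64_t>(rows_left, GEMM_RM);
    const int nc = (int) std::min<int64_t>(vecs_left, GEMM_RN);
    switch ((mc << 4) | nc) {
        case 0x44: simd_gemm_ukernel<4, 4>(C, A, B, K, N, ii, jj); break;
        case 0x24: simd_gemm_ukernel<2, 4>(C, A, B, K, N, ii, jj); break;
        case 0x14: simd_gemm_ukernel<1, 4>(C, A, B, K, N, ii, jj); break;
        case 0x41: simd_gemm_ukernel<4, 1>(C, A, B, K, N, ii, jj); break;
        case 0x21: simd_gemm_ukernel<2, 1>(C, A, B, K, N, ii, jj); break;
        case 0x11: simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj); break;
        default:   GGML_ASSERT(false && "tile size not handled in this sketch");
    }
}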

for (; jj < N; jj++) {
float a = C[ii * N + jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[ii * K + kk] * B[kk * N + jj];
}
C[ii * N + jj] = a;
}
}
}

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif

#else // scalar path

static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int64_t M, int64_t K, int64_t N)
{
for (int64_t i = 0; i < M; i++) {
for (int64_t j = 0; j < N; j++) {
float sum = C[i * N + j];
for (int64_t kk = 0; kk < K; kk++) {
sum += A[i * K + kk] * B[kk * N + j];
}
C[i * N + j] = sum;
}
}
}

#endif // GGML_SIMD