From a5a22b30d3ee428efec6f1561ab3d02fe63a00ce Mon Sep 17 00:00:00 2001
From: Todor Boinovski
Date: Fri, 16 Jan 2026 14:02:43 -0800
Subject: [PATCH 1/6] hexagon: add ssm_conv op

---
 ggml/src/ggml-hexagon/ggml-hexagon.cpp   |  64 ++++++
 ggml/src/ggml-hexagon/htp/CMakeLists.txt |   1 +
 ggml/src/ggml-hexagon/htp/htp-msg.h      |   1 +
 ggml/src/ggml-hexagon/htp/htp-ops.h      |   1 +
 ggml/src/ggml-hexagon/htp/main.c         |  49 +++++
 ggml/src/ggml-hexagon/htp/ssm-conv.c     | 260 +++++++++++++++++++++++
 6 files changed, 376 insertions(+)
 create mode 100644 ggml/src/ggml-hexagon/htp/ssm-conv.c

diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index b70da8f3b28..ff851b2711b 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2152,6 +2152,51 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
     return true;
 }
 
+static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    // Only support FP32 for now
+    if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    // Check IO tensor shapes
+    if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
+        return false; // src0 should be effectively 3D
+    }
+
+    const int d_conv  = src1->ne[0];
+    const int d_inner = src0->ne[1];
+    const int n_t     = dst->ne[1];
+    const int n_s     = dst->ne[2];
+
+    if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {
+        return false;
+    }
+    if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {
+        return false;
+    }
+    if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
+        return false;
+    }
+
+    if (src0->nb[0] != sizeof(float) || src0->nb[1] != src0->ne[0] * sizeof(float)) {
+        return false;
+    }
+    if (src1->nb[0] != sizeof(float)) {
+        return false;
+    }
+
+    // TODO: add support for non-contiguous tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    return true;
+}
+
 enum dspqbuf_type {
     DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
     DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2468,6 +2513,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
     return n_bufs;
 }
 
+static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_SSM_CONV;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
     auto sess = static_cast<ggml_hexagon_session *>(backend->context);
     return sess->name.c_str();
@@ -2606,6 +2662,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             ggml_hexagon_dispatch_op(sess, node, flags);
             break;
 
+        case GGML_OP_SSM_CONV:
+            ggml_hexagon_dispatch_op(sess, node, flags);
+            break;
+
         default:
             GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
     }
@@ -3024,6 +3084,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_argsort(sess, op);
             break;
 
+        case GGML_OP_SSM_CONV:
+            supp = ggml_hexagon_supported_ssm_conv(sess, op);
+            break;
+
         default:
             break;
     }

diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index 2c23b60da3d..02d07a503d5 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED
     get-rows-ops.c
     cpy-ops.c
     argsort-ops.c
+    ssm-conv.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE

diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index 25403bb1126..52dcc36d8f7 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -68,6 +68,7 @@ enum htp_op {
     HTP_OP_SQR,
     HTP_OP_SQRT,
    HTP_OP_SUM_ROWS,
+    HTP_OP_SSM_CONV,
     INVALID
 };
 

diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 127ab1d6659..da65d5b54d7 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -61,5 +61,6 @@ int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
 int op_cpy(struct htp_ops_context * octx);
 int op_argsort(struct htp_ops_context * octx);
+int op_ssm_conv(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */

diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index 92a1422896c..3f99dbb32c4 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req *
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We've written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[2].fd;
+    rsp_bufs[0].ptr    = bufs[2].ptr;
+    rsp_bufs[0].offset = bufs[2].offset;
+    rsp_bufs[0].size   = bufs[2].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    // Setup OP context
+    struct htp_ops_context octx = { 0 };
+
+    octx.ctx   = ctx;
+    octx.src0  = req->src0;
+    octx.src1  = req->src1;
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_ssm_conv(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_activations_req(struct htp_context *      ctx,
                                  struct htp_general_req *  req,
                                  struct dspqueue_buffer *  bufs,
@@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_argsort_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_SSM_CONV:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad ssm-conv-req buffer list");
+                    continue;
+                }
+                proc_ssm_conv_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;
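
NOTE (editorial, not part of the patch series): for context, SSM_CONV is a depthwise 1D convolution: each inner channel i1 has its own d_conv-tap filter sliding over a window of d_conv - 1 + n_t input samples. A minimal scalar sketch of those semantics, matching the shapes and contiguity that ggml_hexagon_supported_ssm_conv() checks above — the function and argument names are illustrative, not code from this series:

    // src0: {d_conv - 1 + n_t, d_inner, n_s}  conv_x input
    // src1: {d_conv, d_inner}                 per-channel filter taps
    // dst:  {d_inner, n_t, n_s}               output
    static void ssm_conv_ref_f32(const float * src0, const float * src1, float * dst,
                                 int d_conv, int d_inner, int n_t, int n_s) {
        const int ncs = d_conv - 1 + n_t; // src0 row length
        for (int i3 = 0; i3 < n_s; i3++) {          // sequence
            for (int i2 = 0; i2 < n_t; i2++) {      // token
                for (int i1 = 0; i1 < d_inner; i1++) { // channel
                    float sum = 0.0f;
                    for (int i0 = 0; i0 < d_conv; i0++) { // filter tap
                        sum += src0[(i3 * d_inner + i1) * ncs + i2 + i0] * src1[i1 * d_conv + i0];
                    }
                    dst[(i3 * n_t + i2) * d_inner + i1] = sum;
                }
            }
        }
    }

End note.
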
diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
new file mode 100644
index 00000000000..2ad436bc724
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c
@@ -0,0 +1,260 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "hex-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
+// Scalar FP32 SSM_CONV implementation
+static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+    const struct htp_tensor * src0 = &octx->src0; // conv_x input -> {d_conv - 1 + n_t, d_inner, n_seqs}
+    const struct htp_tensor * src1 = &octx->src1; // conv1d weights -> {d_conv, d_inner}
+    struct htp_tensor *       dst  = &octx->dst;  // output -> {d_inner, n_t, n_seqs}
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t src0_stride_seq   = src0->nb[2] / sizeof(float); // stride for sequence dimension
+    const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t dst_stride_token  = dst->nb[1] / sizeof(float); // stride for token dimension
+    const uint32_t dst_stride_seq    = dst->nb[2] / sizeof(float); // stride for sequence dimension
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const uint32_t d_inner_per_thread = (d_inner + nth - 1) / nth;
+    const uint32_t d_inner_start      = d_inner_per_thread * ith;
+    const uint32_t d_inner_end        = MIN(d_inner_start + d_inner_per_thread, d_inner);
+
+    // No work for this thread
+    if (d_inner_start >= d_inner_end) {
+        return;
+    }
+
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) {
+                float sumf = 0.0f;
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    // src0: window starting at position i2, element at window offset i0
+                    // src0 layout: {d_conv - 1 + n_t, d_inner, n_seqs}
+                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
+                    // src1: conv weight at position i0, inner dim i1
+                    // src1 layout: {d_conv, d_inner}
+                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
+
+                    sumf += src0_data[src0_idx] * src1_data[src1_idx];
+                }
+
+                // dst: inner dim i1, token i2, sequence i3
+                // dst layout: {d_inner, n_t, n_seqs}
+                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
+                dst_data[dst_idx] = sumf;
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// HVX FP32 SSM_CONV implementation
+// Vectorizes across d_inner dimension, processing 32 inner dims at once
+static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t src0_stride_seq   = src0->nb[2] / sizeof(float); // stride for sequence dimension
+    const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t dst_stride_token  = dst->nb[1] / sizeof(float); // stride for token dimension
+    const uint32_t dst_stride_seq    = dst->nb[2] / sizeof(float); // stride for sequence dimension
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const uint32_t d_inner_per_thread = (d_inner + nth - 1) / nth;
+    const uint32_t d_inner_start      = d_inner_per_thread * ith;
+    const uint32_t d_inner_end        = MIN(d_inner_start + d_inner_per_thread, d_inner);
+
+    if (d_inner_start >= d_inner_end) {
+        return; // No work for this thread
+    }
+
+    // Align start to VLEN_FP32 boundary
+    const uint32_t d_inner_vec_start = (d_inner_start + VLEN_FP32 - 1) & ~(VLEN_FP32 - 1);
+    const uint32_t d_inner_vec_end   = d_inner_end & ~(VLEN_FP32 - 1);
+
+    // Per sequence
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        // Per token
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            // Handle scalar remainder at the beginning (when start is not aligned)
+            for (uint32_t i1 = d_inner_start; i1 < MIN(d_inner_vec_start, d_inner_end); ++i1) {
+                float sumf = 0.0f;
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
+                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
+                    sumf += src0_data[src0_idx] * src1_data[src1_idx];
+                }
+                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
+                dst_data[dst_idx] = sumf;
+            }
+
+            for (uint32_t i1_vec = d_inner_vec_start; i1_vec < d_inner_vec_end; i1_vec += VLEN_FP32) {
+                HVX_Vector acc_vec = Q6_V_vzero();
+
+                // Per kernel element
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    // Load 32 elements from src0: window at position (i2+i0), inner dims [i1_vec, i1_vec+32)
+                    const float * src0_ptr = src0_data + (i2 + i0) + i1_vec * src0_stride_inner + i3 * src0_stride_seq;
+                    HVX_Vector src0_vec = *(const HVX_Vector *) src0_ptr;
+
+                    // Load 32 elements from src1: kernel at position i0, inner dims [i1_vec, i1_vec+32)
+                    const float * src1_ptr = src1_data + i0 + i1_vec * src1_stride_inner;
+                    HVX_Vector src1_vec = *(const HVX_Vector *) src1_ptr;
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(src0_vec, src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                HVX_Vector result_vec = Q6_Vsf_equals_Vqf32(acc_vec);
+                float * dst_ptr = dst_data + i1_vec + i2 * dst_stride_token + i3 * dst_stride_seq;
+                *(HVX_Vector *) dst_ptr = result_vec;
+            }
+
+            // Handle scalar remainder at the end (if end is not aligned)
+            for (uint32_t i1 = d_inner_vec_end; i1 < d_inner_end; ++i1) {
+                float sumf = 0.0f;
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
+                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
+                    sumf += src0_data[src0_idx] * src1_data[src1_idx];
+                }
+                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
+                dst_data[dst_idx] = sumf;
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void ssm_conv_work_f32_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    ssm_conv_thread_f32_f32(octx, n, i);
+}
+
+static void ssm_conv_work_f32_f32_hvx(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    ssm_conv_thread_f32_f32_hvx(octx, n, i);
+}
+
+int op_ssm_conv_f32(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
+        FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t nc  = src1->ne[0]; // d_conv
+    const uint32_t ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int      nr  = src0->ne[1]; // d_inner
+    const int      n_t = dst->ne[1];  // tokens per sequence
+    const int      n_s = dst->ne[2];  // number of sequences in the batch
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        const uint32_t n_jobs = MIN(octx->n_threads, nr);
+
+        const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float);
+        const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float);
+
+        int use_hvx = 0;
+        if (nr >= VLEN_FP32) {
+            int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
+                             hex_is_aligned((void *) src1->data, VLEN) &&
+                             hex_is_aligned((void *) dst->data, VLEN);
+
+            int strides_aligned = !(src0_stride_inner & (VLEN_FP32 - 1)) && !(src1_stride_inner & (VLEN_FP32 - 1));
+
+            if (is_aligned && strides_aligned) {
+                use_hvx = 1;
+            }
+
+            FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d is_aligned %d strides_aligned %d\n",
+                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx, is_aligned, strides_aligned);
+        }
+
+        if (use_hvx) {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32_hvx, octx, n_jobs);
+        } else {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32, octx, n_jobs);
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+int op_ssm_conv(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    struct htp_tensor * dst = &octx->dst;
+
+    switch (dst->type) {
+        case HTP_TYPE_F32:
+            err = op_ssm_conv_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
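
NOTE (editorial, not part of the patch series): the PATCH 1 HVX path loads HVX_Vector values straight from src0/src1, but with the {d_conv - 1 + n_t, d_inner, n_s} layout the 32 channel values feeding one output vector sit a full row apart, not contiguously — which is what PATCH 2 below addresses by switching to per-lane word gathers. A host-side sketch of the offset math it relies on (no HVX intrinsics; the function name is illustrative):

    #include <stdint.h>

    /* Byte offsets for a 32-lane word gather across d_inner: consecutive channels
     * are ncs = d_conv - 1 + n_t floats apart, so lane k reads base + k*ncs*4. */
    static void ssm_conv_gather_offsets(uint32_t offs[32], uint32_t ncs) {
        for (uint32_t k = 0; k < 32; ++k) {
            offs[k] = k * ncs * (uint32_t) sizeof(float);
        }
    }
    /* With base = spad + (i0 + i1*ncs + i2) * sizeof(float), lane k then yields
     * src0[i2 + i0, i1 + k] of the thread's row chunk: exactly the 32 channel
     * values needed for one output vector. */

End note.
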
#include "hvx-div.h" #include "hvx-base.h" +#ifndef GATHER_TYPE +# ifdef __hexagon__ +# define GATHER_TYPE(_a) (uint32_t) (_a) +# else +# define GATHER_TYPE(_a) (_a) +# endif +#endif + +#if defined(__hexagon__) +# define SCATTER_TYPE(_a) (intptr_t) _a +#else +# define SCATTER_TYPE(_a) (HVX_Vector *) _a +#endif + #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index 2ad436bc724..06bc83acf14 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -20,6 +20,51 @@ #include "htp-ops.h" #include "hvx-utils.h" +#include "hvx-dump.h" + +#define htp_ssm_conv_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +#define htp_ssm_conv_preamble \ + htp_ssm_conv_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + // Scalar FP32 SSM_CONV implementation static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { const struct htp_tensor * src0 = &octx->src0; // conv_x input -> {d_conv - 1 + n_t, d_inner, n_seqs} @@ -60,24 +105,19 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, float sumf = 0.0f; for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - // src0: window starting at position i2, element at window offset i0 - // src0 layout: {d_conv - 1 + n_t, d_inner, n_seqs} const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; - // src1: conv weight at position i0, inner dim i1 - // src1 layout: {d_conv, d_inner} const uint32_t src1_idx = i0 + i1 * src1_stride_inner; sumf += src0_data[src0_idx] * src1_data[src1_idx]; } - // dst: inner dim i1, token i2, sequence i3 - // dst layout: {d_inner, n_t, n_seqs} const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; dst_data[dst_idx] = sumf; } } } + t2 = HAP_perf_get_qtimer_count(); FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", @@ -89,152 +129,162 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, // HVX FP32 SSM_CONV implementation // Vectorizes across d_inner dimension, processing 32 inner dims at once static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * 
octx, uint32_t nth, uint32_t ith) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + htp_ssm_conv_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const uint32_t d_conv = src1->ne[0]; const uint32_t d_inner = src0->ne[1]; const uint32_t n_t = dst->ne[1]; const uint32_t n_s = dst->ne[2]; - const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension - const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension - const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension - const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension - const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension - const float * src0_data = (const float *) src0->data; const float * src1_data = (const float *) src1->data; float * dst_data = (float *) dst->data; // Calculate row range for this thread - const uint32_t d_inner_per_thread = (d_inner + nth - 1) / nth; - const uint32_t d_inner_start = d_inner_per_thread * ith; - const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); + const int dr = (d_inner + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); + const int ir = ir1 - ir0; - if (d_inner_start >= d_inner_end) { + if (ir0 >= ir1) { return; // No work for this thread } - // Align start to VLEN_FP32 boundary - const uint32_t d_inner_vec_start = (d_inner_start + VLEN_FP32 - 1) & ~(VLEN_FP32 - 1); - const uint32_t d_inner_vec_end = d_inner_end & ~(VLEN_FP32 - 1); + // gather op src0 offsets + uint32_t src0_offsets[VLEN_FP32] = { 0 }; + for (uint32_t i = 0; i < VLEN_FP32; ++i) { + src0_offsets[i] = i * (ncs) * sizeof(float); + } + + // gather op src1 offsets + uint32_t src1_offsets[VLEN_FP32] = { 0 }; + for (uint32_t i = 0; i < VLEN_FP32; ++i) { + src1_offsets[i] = i * (d_conv) * sizeof(float); + } + + uint32_t src0_gather_len = (src0->ne[0] * src0->ne[1] * src0->ne[2] * src0->ne[3]) * sizeof(float); + uint32_t src1_gather_len = (src1->ne[0] * dr * src1->ne[2] * src1->ne[3]) * sizeof(float); + + HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN); + HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + 1024 + ith * VLEN); - // Per sequence for (uint32_t i3 = 0; i3 < n_s; ++i3) { - // Per token for (uint32_t i2 = 0; i2 < n_t; ++i2) { - // Handle scalar remainder at the beginning (when start is not aligned) - for (uint32_t i1 = d_inner_start; i1 < MIN(d_inner_vec_start, d_inner_end); ++i1) { - float sumf = 0.0f; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; - const uint32_t src1_idx = i0 + i1 * src1_stride_inner; - sumf += src0_data[src0_idx] * src1_data[src1_idx]; - } - const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; - dst_data[dst_idx] = sumf; - } - - for (uint32_t i1_vec = d_inner_vec_start; i1_vec < d_inner_vec_end; i1_vec += VLEN_FP32) { + for (uint32_t i1 = 0; i1 < ir; i1 += VLEN_FP32) { HVX_Vector acc_vec = Q6_V_vzero(); - // Per kernel element for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - // Load 32 elements from src0: window at position (i2+i0), inner dims [i1_vec, i1_vec+32) - const float * src0_ptr = src0_data + (i2 + i0) + i1_vec * 
src0_stride_inner + i3 * src0_stride_seq; - HVX_Vector src0_vec = *(const HVX_Vector *) src0_ptr; + // src0 -> {d_conv, d_inner, n_s} + const float * src0_ptr = (const float *) ((const char *) octx->src0_spad.data + (i0 + i1*ncs) * sizeof(float) + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); + // src1 -> {d_conv, d_inner} + const float * src1_ptr = (const float *) ((const char *) octx->src1_spad.data + (i0 + i1*nc) * sizeof(float) + ir0*(src1->nb[1])); - // Load 32 elements from src1: kernel at position i0, inner dims [i1_vec, i1_vec+32) - const float * src1_ptr = src1_data + i0 + i1_vec * src1_stride_inner; - HVX_Vector src1_vec = *(const HVX_Vector *) src1_ptr; + Q6_vgather_ARMVw(src0_vec, SCATTER_TYPE(src0_ptr), src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, SCATTER_TYPE(src1_ptr), src1_gather_len, (*(const HVX_Vector *) src1_offsets)); - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(src0_vec, src1_vec); + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); } + // dst -> {d_inner, n_t, n_s} + float * dst_ptr = (float *) ((char *) octx->dst_spad.data + i1*sizeof(float) + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); HVX_Vector result_vec = Q6_Vsf_equals_Vqf32(acc_vec); - float * dst_ptr = dst_data + i1_vec + i2 * dst_stride_token + i3 * dst_stride_seq; *(HVX_Vector *) dst_ptr = result_vec; } - - // Handle scalar remainder at the end (if end is not aligned) - for (uint32_t i1 = d_inner_vec_end; i1 < d_inner_end; ++i1) { - float sumf = 0.0f; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; - const uint32_t src1_idx = i0 + i1 * src1_stride_inner; - sumf += src0_data[src0_idx] * src1_data[src1_idx]; - } - const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; - dst_data[dst_idx] = sumf; - } } } t2 = HAP_perf_get_qtimer_count(); FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end, + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void ssm_conv_work_f32_f32(unsigned int n, unsigned int i, void * data) { +static void ssm_conv_work_f32_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - ssm_conv_thread_f32_f32(octx, n, i); + ssm_conv_thread_f32_f32(octx, nth, ith); } -static void ssm_conv_work_f32_f32_hvx(unsigned int n, unsigned int i, void * data) { +static void ssm_conv_work_f32_f32_hvx(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - ssm_conv_thread_f32_f32_hvx(octx, n, i); + ssm_conv_thread_f32_f32_hvx(octx, nth, ith); } +float tmp_dst[4096] = { 0.0f }; + int op_ssm_conv_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + htp_ssm_conv_tensors_preamble; + + assert(sizeof(float) == SIZEOF_FP32); if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported"); return 
HTP_STATUS_NO_SUPPORT; } - const uint32_t nc = src1->ne[0]; // d_conv - const uint32_t ncs = src0->ne[0]; // d_conv - 1 + n_t - const int nr = src0->ne[1]; // d_inner - const int n_t = dst->ne[1]; // tokens per sequence - const int n_s = dst->ne[2]; // number of sequences in the batch + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; // tokens per sequence + const uint32_t n_s = dst->ne[2]; // number of sequences in the batch if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_jobs = MIN(octx->n_threads, nr); - - const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); - const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t n_jobs = MIN(octx->n_threads, d_inner); int use_hvx = 0; - if (nr >= VLEN_FP32) { + if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && hex_is_aligned((void *) src1->data, VLEN) && hex_is_aligned((void *) dst->data, VLEN); - int strides_aligned = !(src0_stride_inner & (VLEN_FP32 - 1)) && !(src1_stride_inner & (VLEN_FP32 - 1)); - - if (is_aligned && strides_aligned) { + if (is_aligned && n_t > 8) { use_hvx = 1; } - FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d is_aligned %d strides_aligned %d\n", - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx, is_aligned, strides_aligned); + FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d is_aligned %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx, is_aligned); } + // rows per thread + const int dr = (src0->ne[1] + n_jobs - 1) / n_jobs; + + octx->dst_spad.size_per_thread = hex_round_up(dr * ne1 * ne2 * ne3 * sizeof(float), 256); + octx->src0_spad.size_per_thread = hex_round_up(ne00 * dr * ne02 * ne03 * sizeof(float), 256); + octx->src1_spad.size_per_thread = hex_round_up(ne10 * dr * ne12 * ne13 * sizeof(float), 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + octx->src0_spad.data = octx->ctx->vtcm_base + 2048; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + FARF(ERROR, + "ssm_conv: dr: %u spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-per-thread-data:(%p:%p:%p)\n", dr, + octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, octx->dst_spad.size_per_thread, + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, octx->src0_spad.data, + octx->src1_spad.data, octx->dst_spad.data); + if (use_hvx) { + //// Remove me + //worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32, octx, n_jobs); + //memcpy((uint8_t *) tmp_dst, (const float *) dst->data, 4096); + + memcpy(octx->src0_spad.data, (const uint8_t *) src0->data, octx->src0_spad.size); + memcpy(octx->src1_spad.data, (const uint8_t *) src1->data, octx->src1_spad.size); + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32_hvx, octx, n_jobs); + + memcpy((uint8_t *) dst->data, octx->dst_spad.data, octx->dst_spad.size); } else { worker_pool_run_func(octx->ctx->worker_pool, 
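
NOTE (editorial, not part of the patch series): both the scalar and HVX kernels split work by dividing the d_inner rows evenly over the worker threads, and that split stays stable through the rest of the series. A standalone sketch of the partitioning arithmetic (the helper name and signature are illustrative; the ceil-divide matches the patches):

    #include <stdint.h>

    /* Returns the number of d_inner rows assigned to worker ith of nth,
     * writing the half-open row range [*ir0, *ir1). */
    static uint32_t ssm_conv_rows_for_thread(uint32_t d_inner, uint32_t nth, uint32_t ith,
                                             uint32_t * ir0, uint32_t * ir1) {
        const uint32_t dr = (d_inner + nth - 1) / nth; // ceil(d_inner / nth)
        *ir0 = dr * ith;
        *ir1 = (*ir0 + dr < d_inner) ? *ir0 + dr : d_inner;
        return (*ir0 < *ir1) ? (*ir1 - *ir0) : 0;      // 0 -> this thread has no rows
    }
    /* e.g. d_inner = 2048, nth = 6 -> dr = 342; thread 5 gets rows 1710..2047
     * (338 rows), and any thread whose start falls past d_inner gets none. */

End note.
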
From 9aba15bb271ce649666096e361cbd65f306e0d87 Mon Sep 17 00:00:00 2001
From: Todor Boinovski
Date: Thu, 19 Feb 2026 18:52:03 -0800
Subject: [PATCH 3/6] hexagon: improvements to ssm-conv hvx kernel

---
 ggml/src/ggml-hexagon/htp/ssm-conv.c | 77 +++++++++++++---------------
 1 file changed, 36 insertions(+), 41 deletions(-)

diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
index 06bc83acf14..eda44b237d0 100644
--- a/ggml/src/ggml-hexagon/htp/ssm-conv.c
+++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c
@@ -117,7 +117,6 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth,
         }
     }
 
-
     t2 = HAP_perf_get_qtimer_count();
 
     FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
@@ -126,8 +125,7 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth,
         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-// HVX FP32 SSM_CONV implementation
-// Vectorizes across d_inner dimension, processing 32 inner dims at once
+// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
 static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_ssm_conv_preamble;
 
@@ -156,44 +154,53 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
         return; // No work for this thread
     }
 
-    // gather op src0 offsets
+    // src0 gather offsets
     uint32_t src0_offsets[VLEN_FP32] = { 0 };
     for (uint32_t i = 0; i < VLEN_FP32; ++i) {
         src0_offsets[i] = i * (ncs) * sizeof(float);
     }
+    uint32_t src0_gather_len = VLEN * ncs;
 
-    // gather op src1 offsets
+    // src1 gather offsets
     uint32_t src1_offsets[VLEN_FP32] = { 0 };
     for (uint32_t i = 0; i < VLEN_FP32; ++i) {
         src1_offsets[i] = i * (d_conv) * sizeof(float);
     }
-
-    uint32_t src0_gather_len = (src0->ne[0] * src0->ne[1] * src0->ne[2] * src0->ne[3]) * sizeof(float);
-    uint32_t src1_gather_len = (src1->ne[0] * dr * src1->ne[2] * src1->ne[3]) * sizeof(float);
+    uint32_t src1_gather_len = VLEN * d_conv;
 
     HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN);
     HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + 1024 + ith * VLEN);
 
+    float * data_src0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
+    float * data_src1 = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
+
+    uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
+    uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
+
+    memcpy(spad_src1, data_src1, octx->src1_spad.size_per_thread);
+
     for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
+
+        memcpy(spad_src0, src0_data_ptr, octx->src0_spad.size_per_thread);
+
         for (uint32_t i2 = 0; i2 < n_t; ++i2) {
             for (uint32_t i1 = 0; i1 < ir; i1 += VLEN_FP32) {
                 HVX_Vector acc_vec = Q6_V_vzero();
 
                 for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
-                    // src0 -> {d_conv, d_inner, n_s}
-                    const float * src0_ptr = (const float *) ((const char *) octx->src0_spad.data + (i0 + i1*ncs) * sizeof(float) + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2]));
-                    // src1 -> {d_conv, d_inner}
-                    const float * src1_ptr = (const float *) ((const char *) octx->src1_spad.data + (i0 + i1*nc) * sizeof(float) + ir0*(src1->nb[1]));
-
-                    Q6_vgather_ARMVw(src0_vec, SCATTER_TYPE(src0_ptr), src0_gather_len, (*(const HVX_Vector *) src0_offsets));
-                    Q6_vgather_ARMVw(src1_vec, SCATTER_TYPE(src1_ptr), src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+                    Q6_vgather_ARMVw(src0_vec,
+                                     SCATTER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, SCATTER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
 
                     HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
                     acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
                 }
 
-                // dst -> {d_inner, n_t, n_s}
-                float * dst_ptr = (float *) ((char *) octx->dst_spad.data + i1*sizeof(float) + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2]));
+                float * dst_ptr = (float *) ((char *) dst->data + i1*sizeof(float) + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2]));
                 HVX_Vector result_vec = Q6_Vsf_equals_Vqf32(acc_vec);
                 *(HVX_Vector *) dst_ptr = result_vec;
             }
@@ -218,8 +225,6 @@ static void ssm_conv_work_f32_f32_hvx(unsigned int nth, unsigned int ith, void *
     ssm_conv_thread_f32_f32_hvx(octx, nth, ith);
 }
 
-float tmp_dst[4096] = { 0.0f };
-
 int op_ssm_conv_f32(struct htp_ops_context * octx) {
     htp_ssm_conv_tensors_preamble;
 
@@ -232,8 +237,8 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
 
     const uint32_t d_conv  = src1->ne[0];
     const uint32_t d_inner = src0->ne[1];
-    const uint32_t n_t     = dst->ne[1]; // tokens per sequence
-    const uint32_t n_s     = dst->ne[2]; // number of sequences in the batch
+    const uint32_t n_t     = dst->ne[1];  // tokens per sequence
+    const uint32_t n_s     = dst->ne[2];  // number of sequences in the batch
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         const uint32_t n_jobs = MIN(octx->n_threads, d_inner);
@@ -244,21 +249,21 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
                              hex_is_aligned((void *) src1->data, VLEN) &&
                              hex_is_aligned((void *) dst->data, VLEN);
 
-            if (is_aligned && n_t > 8) {
+            if (is_aligned && n_t > 3) {
                 use_hvx = 1;
             }
-
-            FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d is_aligned %d\n",
-                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
-                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx, is_aligned);
         }
 
-        // rows per thread
-        const int dr = (src0->ne[1] + n_jobs - 1) / n_jobs;
+        FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n",
+             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+             dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
+
+        // chunks per thread
+        const int dr = (d_inner + n_jobs - 1) / n_jobs;
 
         octx->dst_spad.size_per_thread  = hex_round_up(dr * ne1 * ne2 * ne3 * sizeof(float), 256);
-        octx->src0_spad.size_per_thread = hex_round_up(ne00 * dr * ne02 * ne03 * sizeof(float), 256);
-        octx->src1_spad.size_per_thread = hex_round_up(ne10 * dr * ne12 * ne13 * sizeof(float), 256);
+        octx->src0_spad.size_per_thread = hex_round_up(ne00 * dr * sizeof(float), 256);
+        octx->src1_spad.size_per_thread = hex_round_up(ne10 * dr * sizeof(float), 256);
 
         octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
         octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -268,23 +273,13 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
         octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
         octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-        FARF(ERROR,
-             "ssm_conv: dr: %u spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-per-thread-data:(%p:%p:%p)\n", dr,
+        FARF(HIGH, "ssm_conv-f32: spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
             octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, octx->dst_spad.size_per_thread,
             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, octx->src0_spad.data,
            octx->src1_spad.data, octx->dst_spad.data);
 
         if (use_hvx) {
-            //// Remove me
-            //worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32, octx, n_jobs);
-            //memcpy((uint8_t *) tmp_dst, (const float *) dst->data, 4096);
-
-            memcpy(octx->src0_spad.data, (const uint8_t *) src0->data, octx->src0_spad.size);
-            memcpy(octx->src1_spad.data, (const uint8_t *) src1->data, octx->src1_spad.size);
-
             worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32_hvx, octx, n_jobs);
-
-            memcpy((uint8_t *) dst->data, octx->dst_spad.data, octx->dst_spad.size);
         } else {
             worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32, octx, n_jobs);
         }
From e0e10d7908645f81a9748185a8f5090a1d8d8004 Mon Sep 17 00:00:00 2001
From: Todor Boinovski
Date: Fri, 20 Feb 2026 12:29:00 -0800
Subject: [PATCH 4/6] hexagon: added dma to ssm-conv hvx kernel

---
 ggml/src/ggml-hexagon/htp/ssm-conv.c | 134 ++++++++++++++------------
 1 file changed, 72 insertions(+), 62 deletions(-)

diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
index eda44b237d0..00caac93808 100644
--- a/ggml/src/ggml-hexagon/htp/ssm-conv.c
+++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c
@@ -20,9 +20,7 @@
 #include "htp-ops.h"
 #include "hvx-utils.h"
 
-#include "hvx-dump.h"
-
-#define htp_ssm_conv_tensors_preamble                          \
+#define htp_ssm_conv_tensors_preamble                          \
     struct htp_tensor * restrict src0      = &octx->src0;      \
     struct htp_tensor * restrict src1      = &octx->src1;      \
     struct htp_tensor * restrict dst       = &octx->dst;       \
@@ -30,46 +28,46 @@
     struct htp_spad * restrict   src1_spad = &octx->src1_spad; \
     struct htp_spad * restrict   dst_spad  = &octx->dst_spad;  \
                                                                \
-    const uint32_t ne00 = src0->ne[0];  \
-    const uint32_t ne01 = src0->ne[1];  \
-    const uint32_t ne02 = src0->ne[2];  \
-    const uint32_t ne03 = src0->ne[3];  \
-                                        \
-    const uint32_t ne10 = src1->ne[0];  \
-    const uint32_t ne11 = src1->ne[1];  \
-    const uint32_t ne12 = src1->ne[2];  \
-    const uint32_t ne13 = src1->ne[3];  \
-                                        \
-    const uint32_t ne0 = dst->ne[0];    \
-    const uint32_t ne1 = dst->ne[1];    \
-    const uint32_t ne2 = dst->ne[2];    \
-    const uint32_t ne3 = dst->ne[3];    \
-                                        \
-    const uint32_t nb00 = src0->nb[0];  \
-    const uint32_t nb01 = src0->nb[1];  \
-    const uint32_t nb02 = src0->nb[2];  \
-    const uint32_t nb03 = src0->nb[3];  \
-                                        \
-    const uint32_t nb10 = src1->nb[0];  \
-    const uint32_t nb11 = src1->nb[1];  \
-    const uint32_t nb12 = src1->nb[2];  \
-    const uint32_t nb13 = src1->nb[3];  \
-                                        \
-    const uint32_t nb0 = dst->nb[0];    \
-    const uint32_t nb1 = dst->nb[1];    \
-    const uint32_t nb2 = dst->nb[2];    \
+    const uint32_t ne00 = src0->ne[0];                         \
+    const uint32_t ne01 = src0->ne[1];                         \
+    const uint32_t ne02 = src0->ne[2];                         \
+    const uint32_t ne03 = src0->ne[3];                         \
+                                                               \
+    const uint32_t ne10 = src1->ne[0];                         \
+    const uint32_t ne11 = src1->ne[1];                         \
+    const uint32_t ne12 = src1->ne[2];                         \
+    const uint32_t ne13 = src1->ne[3];                         \
+                                                               \
+    const uint32_t ne0 = dst->ne[0];                           \
+    const uint32_t ne1 = dst->ne[1];                           \
+    const uint32_t ne2 = dst->ne[2];                           \
+    const uint32_t ne3 = dst->ne[3];                           \
+                                                               \
+    const uint32_t nb00 = src0->nb[0];                         \
+    const uint32_t nb01 = src0->nb[1];                         \
+    const uint32_t nb02 = src0->nb[2];                         \
+    const uint32_t nb03 = src0->nb[3];                         \
+                                                               \
+    const uint32_t nb10 = src1->nb[0];                         \
+    const uint32_t nb11 = src1->nb[1];                         \
+    const uint32_t nb12 = src1->nb[2];                         \
+    const uint32_t nb13 = src1->nb[3];                         \
+                                                               \
+    const uint32_t nb0 = dst->nb[0];                           \
+    const uint32_t nb1 = dst->nb[1];                           \
+    const uint32_t nb2 = dst->nb[2];                           \
     const uint32_t nb3 = dst->nb[3];
 
-#define htp_ssm_conv_preamble \
-    htp_ssm_conv_tensors_preamble; \
-    dma_queue *dma_queue = octx->ctx->dma[ith]; \
-    uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+#define htp_ssm_conv_preamble                                      \
+    htp_ssm_conv_tensors_preamble;                                 \
+    dma_queue * dma_queue = octx->ctx->dma[ith];                   \
+    uint32_t    src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+#define SSM_CONV_GATHER_SPAD_SIZE 2048
 
 // Scalar FP32 SSM_CONV implementation
 static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
-    const struct htp_tensor * src0 = &octx->src0; // conv_x input -> {d_conv - 1 + n_t, d_inner, n_seqs}
-    const struct htp_tensor * src1 = &octx->src1; // conv1d weights -> {d_conv, d_inner}
-    struct htp_tensor *       dst  = &octx->dst;  // output -> {d_inner, n_t, n_seqs}
+    htp_ssm_conv_tensors_preamble;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
@@ -168,6 +166,7 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
     }
     uint32_t src1_gather_len = VLEN * d_conv;
 
+    // gather scratchpads
     HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN);
     HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + 1024 + ith * VLEN);
 
@@ -177,14 +176,20 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
     uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
     uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
 
-    memcpy(spad_src1, data_src1, octx->src1_spad.size_per_thread);
+    // copy src1 workload to VTCM
+    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
 
     for (uint32_t i3 = 0; i3 < n_s; ++i3) {
         float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
 
-        memcpy(spad_src0, src0_data_ptr, octx->src0_spad.size_per_thread);
+        // copy src0 workload to VTCM
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
+        dma_queue_flush(dma_queue);
 
         for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            float * dst_ptr =
+                (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
+
             for (uint32_t i1 = 0; i1 < ir; i1 += VLEN_FP32) {
                 HVX_Vector acc_vec = Q6_V_vzero();
 
@@ -192,17 +197,15 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
                     Q6_vgather_ARMVw(src0_vec,
                                      SCATTER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
                                      src0_gather_len, (*(const HVX_Vector *) src0_offsets));
-                    Q6_vgather_ARMVw(src1_vec,
-                                     SCATTER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                    Q6_vgather_ARMVw(src1_vec, SCATTER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
                                      src1_gather_len, (*(const HVX_Vector *) src1_offsets));
 
                     HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
-                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                    acc_vec          = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
                 }
 
-                float * dst_ptr = (float *) ((char *) dst->data + i1*sizeof(float) + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2]));
-                HVX_Vector result_vec = Q6_Vsf_equals_Vqf32(acc_vec);
-                *(HVX_Vector *) dst_ptr = result_vec;
+                HVX_Vector result_vec          = Q6_Vsf_equals_Vqf32(acc_vec);
+                *(HVX_Vector *) (dst_ptr + i1) = result_vec;
             }
         }
     }
@@ -228,8 +231,6 @@ static void ssm_conv_work_f32_f32_hvx(unsigned int nth, unsigned int ith, void *
 int op_ssm_conv_f32(struct htp_ops_context * octx) {
     htp_ssm_conv_tensors_preamble;
 
-    assert(sizeof(float) == SIZEOF_FP32);
-
     if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
         FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
         return HTP_STATUS_NO_SUPPORT;
@@ -243,40 +244,49 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         const uint32_t n_jobs = MIN(octx->n_threads, d_inner);
 
-        int use_hvx = 0;
+        uint32_t use_hvx = 0;
         if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
             int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
                              hex_is_aligned((void *) src1->data, VLEN) &&
                              hex_is_aligned((void *) dst->data, VLEN);
 
-            if (is_aligned && n_t > 3) {
+            if (is_aligned) {
                 use_hvx = 1;
             }
         }
 
-        FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n",
-             src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
-             dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
-
-        // chunks per thread
+        // d_inner chunks per thread
         const int dr = (d_inner + n_jobs - 1) / n_jobs;
 
-        octx->dst_spad.size_per_thread  = hex_round_up(dr * ne1 * ne2 * ne3 * sizeof(float), 256);
-        octx->src0_spad.size_per_thread = hex_round_up(ne00 * dr * sizeof(float), 256);
-        octx->src1_spad.size_per_thread = hex_round_up(ne10 * dr * sizeof(float), 256);
+        octx->src0_spad.size_per_thread = hex_round_up(dr * nb01, 256);
+        octx->src1_spad.size_per_thread = hex_round_up(dr * nb11, 256);
+        octx->dst_spad.size_per_thread  = hex_round_up(dr * sizeof(float), 256);
 
         octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
         octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
         octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
 
-        octx->src0_spad.data = octx->ctx->vtcm_base + 2048;
+        octx->src0_spad.data = octx->ctx->vtcm_base + SSM_CONV_GATHER_SPAD_SIZE;
         octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
         octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
         FARF(HIGH, "ssm_conv-f32: spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
-             octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, octx->dst_spad.size_per_thread,
-             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, octx->src0_spad.data,
-             octx->src1_spad.data, octx->dst_spad.data);
+             octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, octx->dst_spad.size_per_thread,
+             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, octx->src0_spad.data,
+             octx->src1_spad.data, octx->dst_spad.data);
+
+        const size_t total_spad_size =
+            SSM_CONV_GATHER_SPAD_SIZE + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+        if (total_spad_size > octx->ctx->vtcm_size) {
+            FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu",
+                 total_spad_size, octx->ctx->vtcm_size);
+            use_hvx = 0;
+        }
+
+        FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
+             src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
+             dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
 
         if (use_hvx) {
             worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32_hvx, octx, n_jobs);
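
NOTE (editorial, not part of the patch series): the DMA flow PATCH 4 introduces replaces the bulk memcpy staging with queued DDR-to-VTCM transfers. Condensed from the kernel above, purely schematically (dma_queue_push_ddr_to_vtcm, dma_make_ptr, and dma_queue_flush are the htp helpers the patch itself calls; this fragment is not standalone code):

    /* once per worker: queue this thread's ir filter rows (src1) into VTCM;
     * rows are packed, so nb11 is both source and destination pitch */
    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);

    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
        /* per sequence: queue the ir input rows (src0), then block until all
         * queued transfers have landed before gathering from VTCM */
        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_seq_ptr), nb01, nb01, ir);
        dma_queue_flush(dma_queue);
        /* ... gather + qf32 multiply-accumulate over the staged rows ... */
    }

End note.
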
From 1d6cc3b47e4cf16f6a56b3261806d32317d57978 Mon Sep 17 00:00:00 2001
From: Todor Boinovski
Date: Thu, 26 Feb 2026 18:11:27 -0800
Subject: [PATCH 5/6] hexagon: ssm-conv dynamically compute gather scratchpad

---
 ggml/src/ggml-hexagon/ggml-hexagon.cpp |  9 +----
 ggml/src/ggml-hexagon/htp/hvx-utils.h  | 12 ++----
 ggml/src/ggml-hexagon/htp/ssm-conv.c   | 54 +++++++++++++-------------
 3 files changed, 32 insertions(+), 43 deletions(-)

diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index ff851b2711b..d6e9776b878 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2162,7 +2162,7 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
         return false;
     }
 
-    // Check IO tensor shapes
+    // Check IO tensor shapes and dims
     if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
         return false; // src0 should be effectively 3D
     }
@@ -2182,13 +2182,6 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
         return false;
     }
 
-    if (src0->nb[0] != sizeof(float) || src0->nb[1] != src0->ne[0] * sizeof(float)) {
-        return false;
-    }
-    if (src1->nb[0] != sizeof(float)) {
-        return false;
-    }
-
     // TODO: add support for non-contiguous tensors
     if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
         return false;

diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h
index 31d1e139c1c..08343798794 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h
@@ -16,17 +16,11 @@
 #include "hvx-base.h"
 
 #ifndef GATHER_TYPE
-# ifdef __hexagon__
-#  define GATHER_TYPE(_a) (uint32_t) (_a)
+# if defined(__hexagon__)
+#  define GATHER_TYPE(_a) (intptr_t) _a
 # else
-#  define GATHER_TYPE(_a) (_a)
+#  define GATHER_TYPE(_a) (HVX_Vector *) _a
 # endif
 #endif
 
-#if defined(__hexagon__)
-# define SCATTER_TYPE(_a) (intptr_t) _a
-#else
-# define SCATTER_TYPE(_a) (HVX_Vector *) _a
-#endif
-
 #endif /* HVX_UTILS_H */

diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
index 00caac93808..9716b0a79a6 100644
--- a/ggml/src/ggml-hexagon/htp/ssm-conv.c
+++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c
@@ -63,8 +63,6 @@
     dma_queue * dma_queue = octx->ctx->dma[ith];                   \
     uint32_t    src0_nrows_per_thread = octx->src0_nrows_per_thread;
 
-#define SSM_CONV_GATHER_SPAD_SIZE 2048
-
 // Scalar FP32 SSM_CONV implementation
 static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_ssm_conv_tensors_preamble;
@@ -194,10 +192,9 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
                 HVX_Vector acc_vec = Q6_V_vzero();
 
                 for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
-                    Q6_vgather_ARMVw(src0_vec,
-                                     SCATTER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
                                      src0_gather_len, (*(const HVX_Vector *) src0_offsets));
-                    Q6_vgather_ARMVw(src1_vec, SCATTER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
                                      src1_gather_len, (*(const HVX_Vector *) src1_offsets));
 
                     HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
@@ -255,33 +252,38 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
             }
         }
 
-        // d_inner chunks per thread
-        const int dr = (d_inner + n_jobs - 1) / n_jobs;
+        if (use_hvx) {
+            // d_inner chunks per thread
+            const int dr = (d_inner + n_jobs - 1) / n_jobs;
+
+            octx->src0_spad.size_per_thread = hex_round_up(dr * nb01, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(dr * nb11, 256);
+            octx->dst_spad.size_per_thread  = hex_round_up(dr * sizeof(float), 256);
 
-        octx->src0_spad.size_per_thread = hex_round_up(dr * nb01, 256);
-        octx->src1_spad.size_per_thread = hex_round_up(dr * nb11, 256);
-        octx->dst_spad.size_per_thread  = hex_round_up(dr * sizeof(float), 256);
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
 
-        octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
-        octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-        octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            // Compute gather scratchpad size for src0 and src1
+            const size_t gather_spad_size = n_jobs * VLEN * 2;
 
-        octx->src0_spad.data = octx->ctx->vtcm_base + SSM_CONV_GATHER_SPAD_SIZE;
-        octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
-        octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+            octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
+            octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+            octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-        FARF(HIGH, "ssm_conv-f32: spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
-             octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, octx->dst_spad.size_per_thread,
-             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, octx->src0_spad.data,
-             octx->src1_spad.data, octx->dst_spad.data);
+            FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
+                 gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
+                 octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
+                 octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
 
-        const size_t total_spad_size =
-            SSM_CONV_GATHER_SPAD_SIZE + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+            const size_t total_spad_size =
+                gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
 
-        if (total_spad_size > octx->ctx->vtcm_size) {
-            FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu",
-                 total_spad_size, octx->ctx->vtcm_size);
-            use_hvx = 0;
+            if (total_spad_size > octx->ctx->vtcm_size) {
+                FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
+                     octx->ctx->vtcm_size);
+                use_hvx = 0;
+            }
         }
 
         FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
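
NOTE (editorial, not part of the patch series): PATCH 5 above sizes the gather scratch dynamically (two VLEN lines per job) and falls back to the scalar kernel when the combined footprint would not fit in VTCM. Continuing the illustrative shapes from the earlier note, with VLEN assumed to be 128 bytes (VLEN_FP32 = 32 floats):

    /* gather scratch: n_jobs * VLEN * 2 = 6 * 128 * 2            =   1,536 B
     * per-thread spads: 48,128 + 5,632 + 1,536 = 55,296 B; x6    = 331,776 B
     * total ~ 326 KiB - well inside a multi-MiB VTCM, so use_hvx
     * stays set; the fallback only triggers for very large dr * ncs workloads. */
    size_t total_spad_size = gather_spad_size
                           + src0_spad_size + src1_spad_size + dst_spad_size;
    if (total_spad_size > vtcm_size) {
        use_hvx = 0; // scalar fallback instead of overflowing VTCM
    }

End note.
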
uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index 9716b0a79a6..b3c1ef9572e 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -58,14 +58,21 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_ssm_conv_preamble \ - htp_ssm_conv_tensors_preamble; \ - dma_queue * dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +struct htp_ssm_conv_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint64_t t_start; +}; + +#define htp_ssm_conv_preamble \ + struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; // Scalar FP32 SSM_CONV implementation -static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { - htp_ssm_conv_tensors_preamble; +static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -76,17 +83,17 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, const uint32_t n_s = dst->ne[2]; const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension - const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension - const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension - const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension const float * src0_data = (const float *) src0->data; const float * src1_data = (const float *) src1->data; float * dst_data = (float *) dst->data; // Calculate row range for this thread - const uint32_t d_inner_per_thread = (d_inner + nth - 1) / nth; + const uint32_t d_inner_per_thread = scctx->nrows_per_thread; const uint32_t d_inner_start = d_inner_per_thread * ith; const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); @@ -122,7 +129,7 @@ static void ssm_conv_thread_f32_f32(struct htp_ops_context * octx, uint32_t nth, } // HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension -static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { htp_ssm_conv_preamble; uint64_t t1, t2; @@ -141,7 +148,7 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t float * dst_data = (float *) dst->data; // Calculate row range for this thread - const int dr = (d_inner + nth - 1) / nth; + const int dr = scctx->nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, d_inner); const int ir = ir1 - ir0; @@ -150,26 +157,24 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t return; // No work for this thread } - // src0 gather offsets - 

@@ -150,26 +157,24 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
         return;  // No work for this thread
     }
 
-    // src0 gather offsets
-    uint32_t src0_offsets[VLEN_FP32] = { 0 };
-    for (uint32_t i = 0; i < VLEN_FP32; ++i) {
-        src0_offsets[i] = i * (ncs) * sizeof(float);
-    }
-    uint32_t src0_gather_len = VLEN * ncs;
+    // src0 and src1 gather offsets
+    uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
+    uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
 
-    // src1 gather offsets
-    uint32_t src1_offsets[VLEN_FP32] = { 0 };
     for (uint32_t i = 0; i < VLEN_FP32; ++i) {
+        src0_offsets[i] = i * (ncs) * sizeof(float);
         src1_offsets[i] = i * (d_conv) * sizeof(float);
     }
-    uint32_t src1_gather_len = VLEN * d_conv;
+
+    const uint32_t src0_gather_len = VLEN * ncs;
+    const uint32_t src1_gather_len = VLEN * d_conv;
 
     // gather scratchpads
-    HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN);
-    HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + 1024 + ith * VLEN);
+    HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN * 2 + 0);
+    HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN * 2 + VLEN);
 
-    float * data_src0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
-    float * data_src1 = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
+    float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
+    float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
 
     uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
     uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
@@ -177,19 +182,27 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
     // copy src1 workload to VTCM
     dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
 
+    // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
+
     for (uint32_t i3 = 0; i3 < n_s; ++i3) {
         float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
 
         // copy src0 workload to VTCM
         dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
+
+        // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
+
         dma_queue_flush(dma_queue);
 
         for (uint32_t i2 = 0; i2 < n_t; ++i2) {
-            float * dst_ptr =
-                (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
+            float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
 
-            for (uint32_t i1 = 0; i1 < ir; i1 += VLEN_FP32) {
-                HVX_Vector acc_vec = Q6_V_vzero();
+            const uint32_t nvec = ir / VLEN_FP32;
+            const uint32_t nloe = ir % VLEN_FP32;
+            uint32_t       i1   = 0;
+
+            for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
 
                 for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
                     Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
                                      src0_gather_len, (*(const HVX_Vector *) src0_offsets));
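
The pair of gathers amounts to a strided transpose-load: lane i of the
destination vector picks up tap i0 of row (i1 + i) from a row-major
scratchpad, with the per-lane byte offsets precomputed before the loops.
Setting aside VTCM, alignment, and the qf32 accumulation, each
Q6_vgather_ARMVw call above behaves like this scalar model (a conceptual
equivalent, not the intrinsic's definition):

    #include <stdint.h>

    // dst_vec[i] = the 32-bit word at base + offsets[i], for each lane;
    // with offsets[i] = i * row_stride_bytes, consecutive lanes read the
    // same column of consecutive rows
    static void gather_words_model(uint32_t * dst_vec, const uint8_t * base,
                                   const uint32_t * offsets, unsigned lanes) {
        for (unsigned i = 0; i < lanes; i++) {
            dst_vec[i] = *(const uint32_t *) (base + offsets[i]);
        }
    }

For src0 the base pointer spad_src0 + (i0 + i1 * ncs) * sizeof(float) +
i2 * src0->nb[0] advances by one tap per i0 step and one token per i2 step,
while the offset vector fans the read out across VLEN_FP32 consecutive rows.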

@@ -201,8 +214,24 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
                     acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
                 }
 
-                HVX_Vector result_vec = Q6_Vsf_equals_Vqf32(acc_vec);
-                *(HVX_Vector *) (dst_ptr + i1) = result_vec;
+                *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
+                i1 += VLEN_FP32;
+            }
+
+            if (nloe) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
+            }
         }
     }
 
@@ -215,16 +244,6 @@ static void ssm_conv_thread_f32_f32_hvx(struct htp_ops_context * octx, uint32_t
         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void ssm_conv_work_f32_f32(unsigned int nth, unsigned int ith, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    ssm_conv_thread_f32_f32(octx, nth, ith);
-}
-
-static void ssm_conv_work_f32_f32_hvx(unsigned int nth, unsigned int ith, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    ssm_conv_thread_f32_f32_hvx(octx, nth, ith);
-}
-
 int op_ssm_conv_f32(struct htp_ops_context * octx) {
     htp_ssm_conv_tensors_preamble;
 
@@ -233,14 +252,18 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
         return HTP_STATUS_NO_SUPPORT;
     }
 
+    struct htp_ssm_conv_context scctx = { 0 };
+    scctx.octx = octx;
+
     const uint32_t d_conv  = src1->ne[0];
     const uint32_t d_inner = src0->ne[1];
     const uint32_t n_t     = dst->ne[1];  // tokens per sequence
     const uint32_t n_s     = dst->ne[2];  // number of sequences in the batch
 
-    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        const uint32_t n_jobs = MIN(octx->n_threads, d_inner);
+    const uint32_t n_threads = MIN(octx->n_threads, d_inner);
+    scctx.nrows_per_thread   = (d_inner + n_threads - 1) / n_threads;  // d_inner rows per thread (scalar and HVX paths)
 
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         uint32_t use_hvx = 0;
         if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
             int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
@@ -253,19 +276,18 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
         }
 
         if (use_hvx) {
-            // d_inner chunks per thread
-            const int dr = (d_inner + n_jobs - 1) / n_jobs;
+            scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);  // round rows per thread up to even for HVX
 
-            octx->src0_spad.size_per_thread = hex_round_up(dr * nb01, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(dr * nb11, 256);
-            octx->dst_spad.size_per_thread  = hex_round_up(dr * sizeof(float), 256);
+            octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
+            octx->dst_spad.size_per_thread  = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
 
-            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread * n_threads;
 
             // Compute gather scratchpad size for src0 and src1
-            const size_t gather_spad_size = n_jobs * VLEN * 2;
+            const size_t gather_spad_size = n_threads * VLEN * 2;
 
             octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
             octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
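
Plugging numbers into the new layout makes the indexing fix easy to verify.
With VLEN = 128 (one HVX vector) and an illustrative n_threads = 4:

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        const uint32_t VLEN = 128, n_threads = 4;
        const uint32_t gather_spad_size = n_threads * VLEN * 2;  // 1024 bytes

        for (uint32_t ith = 0; ith < n_threads; ith++) {
            // interleaved per-thread slots: [src0_vec | src1_vec], 2 * VLEN apart
            printf("thread %u: src0_vec @ vtcm_base+%u, src1_vec @ vtcm_base+%u\n",
                   ith, ith * VLEN * 2, ith * VLEN * 2 + VLEN);
        }
        printf("src0_spad.data starts @ vtcm_base+%u\n", gather_spad_size);
        return 0;
    }

Each thread now owns a contiguous 2 * VLEN slot and the tensor scratchpads
start exactly where the gather area ends. The previous scheme placed src1_vec
at the hard-coded vtcm_base + 1024 + ith * VLEN, which could overlap either
the src0_vec slots or the tensor scratchpads as soon as the thread count and
that constant disagreed.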

@@ -291,9 +313,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
          src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
          dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
 
     if (use_hvx) {
-        worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32_hvx, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);
     } else {
-        worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_work_f32_f32, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);
     }
 }
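
A final note on the nvec/nloe split: even when d_inner passes the HVX entry
check (a multiple of VLEN_FP32, i.e. 32 fp32 lanes), the per-thread slice
after the even rounding usually is not, so threads routinely finish through
the masked hvx_vec_store_u path. Illustrative numbers:

    #include <stdio.h>

    #define MIN(a, b)  ((a) < (b) ? (a) : (b))
    #define VLEN_FP32  32  // fp32 lanes per 128-byte HVX vector

    int main(void) {
        const unsigned d_inner = 320, n_threads = 3;             // illustrative sizes
        unsigned nrows = (d_inner + n_threads - 1) / n_threads;  // ceil -> 107
        nrows += nrows & 1;                                      // round up to even -> 108

        for (unsigned ith = 0; ith < n_threads; ith++) {
            const unsigned ir0 = nrows * ith;
            const unsigned ir1 = MIN(ir0 + nrows, d_inner);
            const unsigned ir  = ir1 - ir0;
            printf("thread %u: rows [%3u,%3u)  nvec %u  nloe %u\n",
                   ith, ir0, ir1, ir / VLEN_FP32, ir % VLEN_FP32);
        }
        return 0;
    }

Here every thread ends with a partial vector (nloe of 12, 12 and 8 rows), so
the tail path is the common case rather than a rare fallback.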