Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions ggml/src/ggml-hexagon/ggml-hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
return true;
}

static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;

// Only support FP32 for now
if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
return false;
}

// Check IO tensor shapes and dims
if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
return false; // src0 should be effectively 3D
}

const int d_conv = src1->ne[0];
const int d_inner = src0->ne[1];
const int n_t = dst->ne[1];
const int n_s = dst->ne[2];

if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {
return false;
}
if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {
return false;
}
if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
return false;
}

// TODO: add support for non-contiguous tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}

return true;
}

enum dspqbuf_type {
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
Expand Down Expand Up @@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
return n_bufs;
}

static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_SSM_CONV;

size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

return n_bufs;
}

static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->name.c_str();
Expand Down Expand Up @@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;

case GGML_OP_SSM_CONV:
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
break;

default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
Expand Down Expand Up @@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_argsort(sess, op);
break;

case GGML_OP_SSM_CONV:
supp = ggml_hexagon_supported_ssm_conv(sess, op);
break;

default:
break;
}
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-hexagon/htp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED
get-rows-ops.c
cpy-ops.c
argsort-ops.c
ssm-conv.c
)

target_compile_definitions(${HTP_LIB} PRIVATE
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-hexagon/htp/htp-msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ enum htp_op {
HTP_OP_SQR,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
HTP_OP_SSM_CONV,
INVALID
};

Expand Down
4 changes: 1 addition & 3 deletions ggml/src/ggml-hexagon/htp/htp-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ struct htp_ops_context {
worker_pool_context_t * wpool; // worker pool
uint32_t n_threads; // num threads

uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;

uint32_t flags;
};

Expand All @@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
int op_ssm_conv(struct htp_ops_context * octx);

#endif /* HTP_OPS_H */
8 changes: 8 additions & 0 deletions ggml/src/ggml-hexagon/htp/hvx-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,12 @@
#include "hvx-div.h"
#include "hvx-base.h"

#ifndef GATHER_TYPE
# if defined(__hexagon__)
# define GATHER_TYPE(_a) (intptr_t) _a
# else
# define GATHER_TYPE(_a) (HVX_Vector *) _a
# endif
#endif

#endif /* HVX_UTILS_H */
49 changes: 49 additions & 0 deletions ggml/src/ggml-hexagon/htp/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req *
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];

// We've written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU

// Setup OP context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;

memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));

// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;

struct profile_data prof;
profile_start(&prof);

uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_ssm_conv(&octx);
vtcm_release(ctx);
}

profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
Expand Down Expand Up @@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_argsort_req(ctx, &req, bufs);
break;

case HTP_OP_SSM_CONV:
if (n_bufs != 3) {
FARF(ERROR, "Bad ssm-conv-req buffer list");
continue;
}
proc_ssm_conv_req(ctx, &req, bufs);
break;

default:
FARF(ERROR, "Unknown Op %u", req.op);
break;
Expand Down
Loading
Loading