From b41ea16320a3feb773e7867fc6442b4e4eca44ee Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 21 Jul 2025 19:12:23 +0800 Subject: [PATCH 1/7] CUDA: add fused rms norm --- ggml/src/ggml-cuda/ggml-cuda.cu | 42 +++++++++++++++++ ggml/src/ggml-cuda/norm.cu | 81 +++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/norm.cuh | 2 + 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 548bc31ce21..999a74314ae 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -55,6 +55,7 @@ #include #include #include +#include #include #include #include @@ -2765,6 +2766,40 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { } #endif +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if(!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if(ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + //rms norm only supports F32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + //if rms norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + mul->src[0]->ne[1] != rms_norm->src[1]->ne[1]) { + return false; + } + + //rms_norm kernel assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + + return true; +} + static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { // flag used to
determine whether it is an integrated_gpu @@ -2774,6 +2809,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2781,6 +2817,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 0020dbcec5f..6d3992c99b1 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -104,10 +104,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template +template static __global__ void rms_norm_f32( const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps) { + const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -119,6 +120,13 @@ static __global__ void rms_norm_f32( x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; + const float * mul_ptr = nullptr; + if constexpr (do_multiply) { + if (mul != nullptr) { + 
mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; + } + } + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -145,7 +153,11 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col]; + if constexpr (do_multiply) { + dst[col] = scale * x[col] * (mul_ptr ? mul_ptr[col] : 1.0f); + } else { + dst[col] = scale * x[col]; + } } } @@ -310,10 +322,25 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +static void rms_norm_mul_f32_cuda( + const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample); } } @@ -407,6 +434,50 @@ void 
ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + + if(mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + } else if(mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) mul_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_tensor->type); + const int64_t mul_s01 = mul_tensor->nb[1] / ts_mul; + const int64_t mul_s02 = mul_tensor->nb[2] / ts_mul; + const int64_t mul_s03 = mul_tensor->nb[3] / ts_mul; + + rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, eps, stream); +} + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * grad = dst->src[0]; // gradients const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass diff --git 
a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 706a5660a68..7ea7bd4df3c 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From a8b1b872dfdb7635f99d3b929717690bf2f51289 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 22 Jul 2025 10:41:59 +0800 Subject: [PATCH 2/7] assume mul_ptr is not null when calling fused ops, formatting changes --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/norm.cu | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 999a74314ae..de1773864bf 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2767,11 +2767,11 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { #endif static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { - if(!ggml_can_fuse(cgraph, node_idx, ops)) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { return false; } - if(ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; const ggml_tensor *mul = cgraph->nodes[node_idx+1]; diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 6d3992c99b1..b74eaf5e91f 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -122,9 +122,7 @@ static 
__global__ void rms_norm_f32( const float * mul_ptr = nullptr; if constexpr (do_multiply) { - if (mul != nullptr) { - mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; - } + mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; } float tmp = 0.0f; // partial sum for thread in warp @@ -154,7 +152,7 @@ static __global__ void rms_norm_f32( for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply) { - dst[col] = scale * x[col] * (mul_ptr ? mul_ptr[col] : 1.0f); + dst[col] = scale * x[col] * mul_ptr[col]; } else { dst[col] = scale * x[col]; } @@ -335,6 +333,10 @@ static void rms_norm_mul_f32_cuda( const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); + if(mul == nullptr) { + rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); + return; + } if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample); @@ -443,7 +445,7 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const float * src0_d = (const float *) rms_norm_src->data; const float * mul_d = nullptr; - if(mul_tensor->src[0] == dst) { + if (mul_tensor->src[0] == dst) { mul_d = (float *) mul_tensor->src[1]->data; } else if(mul_tensor->src[1] == dst) { mul_d = (float *) mul_tensor->src[0]->data; From 0c6d097a6089f11d486f8f3eb644907e71c5e511 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 22 Jul 2025 17:06:46 +0800 Subject: [PATCH 3/7] Replace mul_ptr with mul --- ggml/src/ggml-cuda/norm.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index b74eaf5e91f..51651b74b8e 
100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -120,9 +120,8 @@ static __global__ void rms_norm_f32( x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; - const float * mul_ptr = nullptr; if constexpr (do_multiply) { - mul_ptr = mul + sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; + mul += sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; } float tmp = 0.0f; // partial sum for thread in warp @@ -152,7 +151,7 @@ static __global__ void rms_norm_f32( for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply) { - dst[col] = scale * x[col] * mul_ptr[col]; + dst[col] = scale * x[col] * mul[col]; } else { dst[col] = scale * x[col]; } @@ -333,7 +332,7 @@ static void rms_norm_mul_f32_cuda( const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); - if(mul == nullptr) { + if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } From db341d2a9a5f6956d5cdf19f23a5ff6b83c5d2e7 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 22 Jul 2025 18:34:18 +0800 Subject: [PATCH 4/7] Use mul tensor for broadcast --- ggml/src/ggml-cuda/norm.cu | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 51651b74b8e..bddcca51b7b 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -108,7 +108,8 @@ template static __global__ void rms_norm_f32( const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, - const int64_t 
mul_stride_channel = 0, const int64_t mul_stride_sample = 0) { + const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, + const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -121,7 +122,10 @@ static __global__ void rms_norm_f32( dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - mul += sample*mul_stride_sample + channel*mul_stride_channel + row*mul_stride_row; + const int mul_row = row % mul_nrows; + const int mul_channel = channel % mul_nchannels; + const int mul_sample = sample % mul_nsamples; + mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; } float tmp = 0.0f; // partial sum for thread in warp @@ -151,7 +155,8 @@ static __global__ void rms_norm_f32( for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply) { - dst[col] = scale * x[col] * mul[col]; + const int mul_col = col % mul_ncols; + dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; } @@ -330,6 +335,7 @@ static void rms_norm_mul_f32_cuda( const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { @@ -338,10 +344,10 @@ static void rms_norm_mul_f32_cuda( } if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample); + rms_norm_f32<<>>(x, dst, ncols, 
stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample); + rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); } } @@ -443,11 +449,14 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const float * src0_d = (const float *) rms_norm_src->data; const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; if (mul_tensor->src[0] == dst) { mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; } else if(mul_tensor->src[1] == dst) { mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; } else { GGML_ASSERT(false); } @@ -471,12 +480,18 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t s02 = rms_norm_src->nb[2] / ts0; const int64_t s03 = rms_norm_src->nb[3] / ts0; - const size_t ts_mul = ggml_type_size(mul_tensor->type); - const int64_t mul_s01 = mul_tensor->nb[1] / ts_mul; - const int64_t mul_s02 = mul_tensor->nb[2] / ts_mul; - const int64_t mul_s03 = mul_tensor->nb[3] / ts_mul; + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; - rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, eps, stream); + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = 
mul_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { From f38c610cd9e9dba2f567983d1e0899b1015d838b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 22 Jul 2025 19:52:15 +0800 Subject: [PATCH 5/7] Add testcase about the broadcast --- tests/test-backend-ops.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a6d00542dd2..c424e248d94 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2561,7 +2561,7 @@ struct test_rms_norm : public test_case { const float eps; std::string vars() override { - return VARS_TO_STR4(type, ne, v, eps); + return VARS_TO_STR5(type, ne, v, eps, v); } test_rms_norm(ggml_type type = GGML_TYPE_F32, @@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case { const ggml_type type; const std::array ne; const float eps; + const bool broadcast; std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -2655,13 +2656,18 @@ struct test_rms_norm_mul_add : public test_case { test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, - float eps = 1e-6f) - : type(type), ne(ne), eps(eps) {} + float eps = 1e-6f, bool broadcast = false) + : type(type), ne(ne), eps(eps), broadcast(broadcast) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + std::array broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4}; + + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? 
broadcast_dims.data() : ne.data()); ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data()); + + + ggml_set_param(a); ggml_set_name(a, "a"); ggml_set_param(b); @@ -5354,6 +5360,7 @@ static std::vector> make_test_cases_eval() { } for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); From 2ebe86ac9ea531891daa532f791a54085ce3633b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 22 Jul 2025 20:10:38 +0800 Subject: [PATCH 6/7] Fix test print --- tests/test-backend-ops.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c424e248d94..4898094c918 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2561,7 +2561,7 @@ struct test_rms_norm : public test_case { const float eps; std::string vars() override { - return VARS_TO_STR5(type, ne, v, eps, v); + return VARS_TO_STR4(type, ne, v, eps); } test_rms_norm(ggml_type type = GGML_TYPE_F32, @@ -2651,7 +2651,7 @@ struct test_rms_norm_mul_add : public test_case { bool run_whole_graph() override { return true; } std::string vars() override { - return VARS_TO_STR3(type, ne, eps); + return VARS_TO_STR4(type, ne, eps, broadcast); } test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32, @@ -2666,8 +2666,6 @@ struct test_rms_norm_mul_add : public test_case { ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data()); - - ggml_set_param(a); ggml_set_name(a, "a"); ggml_set_param(b); From ed9f84e2ea83cb40973691e8b75332fd754aa804 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 23 Jul 2025 00:36:40 +0800 Subject: [PATCH 7/7] Fix condition for 
broadcast --- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de1773864bf..03c380897cd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2786,8 +2786,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } //if rms norm is the B operand, then we don't handle broadcast - if (rms_norm == mul->src[1] && - mul->src[0]->ne[1] != rms_norm->src[1]->ne[1]) { + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; }