From 288ce1e57413c0afa3a8c2873e011d4a1444fa48 Mon Sep 17 00:00:00 2001 From: chenghuaWang Date: Tue, 7 Oct 2025 07:05:13 +0000 Subject: [PATCH 1/8] refactor(cpu): rename paged_attn_x to radix_attn and update namespace references - Rename directory and files from `paged_attn_x` to `radix_attn` - Update namespace from `mllm::cpu::paged_attn_x` to `mllm::cpu::radix_attn` - Remove unused includes and context-related code in fwd_bshd.hpp - Add new ops: RadixAttnOp and Scatter2ShardsOp with CPU implementations - Introduce tensor creation utilities: fromVector and refVectorData - Add kRadixAttn and kScatter2Shards to OpTypes enum and optype2Str - Remove kPrefixCache from PagedAttnImplType - Add necessary headers and fix tensor ptr attribute to [[nodiscard]] --- .../{paged_attn_x => radix_attn}/README.md | 0 .../{paged_attn_x => radix_attn}/arch.hpp | 4 +- .../{paged_attn_x => radix_attn}/fwd_bshd.hpp | 28 +- .../impl-any-simd.hpp | 0 .../{paged_attn_x => radix_attn}/impl-any.hpp | 0 .../{paged_attn_x => radix_attn}/impl-arm.hpp | 6 +- mllm/backends/cpu/ops/RadixAttnOp.cpp | 14 + mllm/backends/cpu/ops/RadixAttnOp.hpp | 25 ++ mllm/backends/cpu/ops/Scatter2ShardsOp.cpp | 45 +++ mllm/backends/cpu/ops/Scatter2ShardsOp.hpp | 25 ++ mllm/core/OpTypes.hpp | 3 + mllm/core/Tensor.hpp | 44 ++- mllm/core/aops/PagedAttnOp.hpp | 1 - mllm/core/aops/RadixAttnOp.cpp | 33 ++ mllm/core/aops/RadixAttnOp.hpp | 31 ++ mllm/core/aops/Scatter2ShardsOp.cpp | 35 ++ mllm/core/aops/Scatter2ShardsOp.hpp | 33 ++ mllm/engine/service/Service.hpp | 1 + mllm/engine/service/Session.cpp | 3 + mllm/engine/service/Session.hpp | 2 + mllm/models/qwen3/modeling_qwen3_service.hpp | 358 +++++++++++++++--- mllm/nn/Functional.cpp | 6 + mllm/nn/Functional.hpp | 2 + mllm/nn/layers/PagedAttn.cpp | 11 - mllm/nn/layers/PagedAttn.hpp | 4 - 25 files changed, 619 insertions(+), 95 deletions(-) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/README.md (100%) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/arch.hpp (91%) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/fwd_bshd.hpp (76%) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/impl-any-simd.hpp (100%) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/impl-any.hpp (100%) rename mllm/backends/cpu/kernels/common/{paged_attn_x => radix_attn}/impl-arm.hpp (95%) create mode 100644 mllm/backends/cpu/ops/RadixAttnOp.cpp create mode 100644 mllm/backends/cpu/ops/RadixAttnOp.hpp create mode 100644 mllm/backends/cpu/ops/Scatter2ShardsOp.cpp create mode 100644 mllm/backends/cpu/ops/Scatter2ShardsOp.hpp create mode 100644 mllm/core/aops/RadixAttnOp.cpp create mode 100644 mllm/core/aops/RadixAttnOp.hpp create mode 100644 mllm/core/aops/Scatter2ShardsOp.cpp create mode 100644 mllm/core/aops/Scatter2ShardsOp.hpp diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/README.md b/mllm/backends/cpu/kernels/common/radix_attn/README.md similarity index 100% rename from mllm/backends/cpu/kernels/common/paged_attn_x/README.md rename to mllm/backends/cpu/kernels/common/radix_attn/README.md diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp similarity index 91% rename from mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp rename to mllm/backends/cpu/kernels/common/radix_attn/arch.hpp index e88f4be54..5fcd32f8f 100644 --- a/mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp @@ 
-5,7 +5,7 @@
 #include
 #include "mllm/utils/Common.hpp"
 
-namespace mllm::cpu::paged_attn_x::details {
+namespace mllm::cpu::radix_attn::details {
 
 struct __AnyArchTag {};
 using any_arch_tag = __AnyArchTag;
@@ -32,4 +32,4 @@ struct FMAConstArray {
   static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {}
 };
 
-}  // namespace mllm::cpu::paged_attn_x::details
+}  // namespace mllm::cpu::radix_attn::details
diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
similarity index 76%
rename from mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp
rename to mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
index d34e7676d..1312146ec 100644
--- a/mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp
+++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
@@ -9,17 +9,16 @@
 #include
 #include "mllm/core/Parallel.hpp"
 #include "mllm/utils/CPUArchHelper.hpp"
-#include "mllm/engine/prefix_cache/TLB.hpp"
-#include "mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp"
 
 #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH)
-#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp"
 #else
-#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp"
 #endif
-#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp"
 
-namespace mllm::cpu::paged_attn_x {
+namespace mllm::cpu::radix_attn {
 
 // BHSD
 // K: [S_KV], address, not contiguous
@@ -30,8 +29,7 @@ namespace mllm::cpu::paged_attn_x {
 template
 void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
-              const mllm::prefix_cache::vp_addr_t* __k, const mllm::prefix_cache::vp_addr_t* __v, __ODType* __restrict__ __out,
-              void* ctx, int32_t thread_count) {
+              const __KDType** __k, const __VDType** __v, __ODType* __restrict__ __out, int32_t thread_count) {
   int32_t head_repeat_times = H_Q / H_KV;
   __AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
@@ -47,7 +45,6 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
       __QDType* q_token = __q + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
       __ODType* acc_o = __out + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
 
-      // FIXME: Boost with SIMD
       for (int d_idx = 0; d_idx < D; ++d_idx) { acc_o[d_idx] = 0; }
       __AccDType scores_max = std::numeric_limits<__AccDType>::lowest();
       __AccDType scores_max_prev = std::numeric_limits<__AccDType>::lowest();
@@ -59,13 +56,8 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
       int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
       for (int s_kv_idx = 0; s_kv_idx < S_KV_BOUND; ++s_kv_idx) {
-        // TODO, prefetch next
-
-        // TODO using context.
-        // __KDType* k_token = (__KDType*)ctx->access(__k[s_kv_idx]);
-        // __VDType* v_token = (__VDType*)ctx->access(__v[s_kv_idx]);
-        __KDType* k_token = NULL;
-        __VDType* v_token = NULL;
+        __KDType* k_token = __k[s_kv_idx];
+        __VDType* v_token = __v[s_kv_idx];
 
         // Offset to one head.
        // k_token and v_token shape is [D]
@@ -90,8 +82,6 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
         // 4. MMA1.
         FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>(acc_o, acc_s, v_token, D);
-
-        // TODO, drop this mmap in the future.
       }
 
       // 5. Final Rescale.
@@ -101,4 +91,4 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
     }
   }
 
-}  // namespace mllm::cpu::paged_attn_x
+}  // namespace mllm::cpu::radix_attn
diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp
similarity index 100%
rename from mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp
rename to mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp
diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp
similarity index 100%
rename from mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp
rename to mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp
diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp
similarity index 95%
rename from mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp
rename to mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp
index 6327b45d2..962916bbb 100644
--- a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp
+++ b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp
@@ -4,11 +4,11 @@
 #pragma once
 
 #include "mllm/core/DataTypes.hpp"
-#include "mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp"
 
 #include <arm_neon.h>
 
-namespace mllm::cpu::paged_attn_x::details {
+namespace mllm::cpu::radix_attn::details {
 
 template<>
 struct VectorDotProduct<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> {
   static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs,
@@ -97,4 +97,4 @@ struct FMAConstArray<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> {
   }
 };
 
-}  // namespace mllm::cpu::paged_attn_x::details
+}  // namespace mllm::cpu::radix_attn::details
diff --git a/mllm/backends/cpu/ops/RadixAttnOp.cpp b/mllm/backends/cpu/ops/RadixAttnOp.cpp
new file mode 100644
index 000000000..413dc9068
--- /dev/null
+++ b/mllm/backends/cpu/ops/RadixAttnOp.cpp
@@ -0,0 +1,14 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/core/Tensor.hpp"
+#include "mllm/backends/cpu/ops/RadixAttnOp.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp"
+
+namespace mllm::cpu {
+
+CPURadixAttnOp::CPURadixAttnOp(const aops::RadixAttnOpOptions& options) : aops::RadixAttnOp(options) {}
+
+void CPURadixAttnOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {}
+
+}  // namespace mllm::cpu
diff --git a/mllm/backends/cpu/ops/RadixAttnOp.hpp b/mllm/backends/cpu/ops/RadixAttnOp.hpp
new file mode 100644
index 000000000..bc7f44130
--- /dev/null
+++ b/mllm/backends/cpu/ops/RadixAttnOp.hpp
@@ -0,0 +1,25 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/aops/RadixAttnOp.hpp"
+
+namespace mllm::cpu {
+
+class CPURadixAttnOp final : public aops::RadixAttnOp {
+ public:
+  explicit CPURadixAttnOp(const aops::RadixAttnOpOptions& options);
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+};
+
+class CPURadixAttnOpFactory : public TypedOpFactory<OpTypes::kRadixAttn, aops::RadixAttnOpOptions> {
+ public:
+  std::shared_ptr<BaseOp> createOpImpl(const aops::RadixAttnOpOptions& options) override {
+    return std::make_shared<CPURadixAttnOp>(options);
+  }
+};
+
+}  // namespace mllm::cpu
diff --git a/mllm/backends/cpu/ops/Scatter2ShardsOp.cpp b/mllm/backends/cpu/ops/Scatter2ShardsOp.cpp
new file mode 100644
index 000000000..cde949727
--- /dev/null
+++ b/mllm/backends/cpu/ops/Scatter2ShardsOp.cpp
@@ -0,0 +1,45 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include <cstring>
+#include "mllm/utils/Common.hpp"
+#include "mllm/core/DataTypes.hpp"
+#include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp"
+
+namespace mllm::cpu {
+
+CPUScatter2ShardsOp::CPUScatter2ShardsOp(const aops::Scatter2ShardsOpOptions& options) : aops::Scatter2ShardsOp(options) {}
+
+void CPUScatter2ShardsOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  const auto& src = inputs[0];
+  const auto& dst_ptrs = inputs[1];
+
+  // Validation
+  MLLM_RT_ASSERT(dst_ptrs.dtype() == kInt64 && dst_ptrs.rank() == 1 && dst_ptrs.shape()[0] == src.shape()[options_.dim]);
+
+  // [B, H, S, D]
+  // floating_shards_point_left = B * H
+  // floating_shards_point_right = D
+  int32_t floating_shards_point_left = 1;
+  int32_t floating_shards_point_right = 1;
+  for (int i = 0; i < options_.dim; ++i) { floating_shards_point_left *= src.shape()[i]; }
+  for (int i = options_.dim + 1; i < src.rank(); ++i) { floating_shards_point_right *= src.shape()[i]; }
+  int32_t floating_shards_point_stride = 1;
+  for (int i = options_.dim; i < src.rank(); ++i) { floating_shards_point_stride *= src.shape()[i]; }
+  if (options_.dim == src.rank() - 1) { floating_shards_point_stride = 1; }
+  if (options_.dim == 0) { floating_shards_point_left = src.stride()[0]; }
+
+  int32_t loop_times = src.shape()[options_.dim];
+  for (int lo = 0; lo < loop_times; ++lo) {
+    for (int ho = 0; ho < floating_shards_point_left; ++ho) {
+      auto src_ptr = src.ptr<char>()
+                     + (ho * floating_shards_point_stride + lo * floating_shards_point_right)
+                           * (bytesOfType(src.dtype()) / lanesOfType(src.dtype()));
+      auto dst_ptr = reinterpret_cast<char*>(dst_ptrs.ptr<int64_t>()[lo])
+                     + ho * floating_shards_point_right * (bytesOfType(src.dtype()) / lanesOfType(src.dtype()));
+      memcpy(dst_ptr, src_ptr, floating_shards_point_right * (bytesOfType(src.dtype()) / lanesOfType(src.dtype())));
+    }
+  }
+}
+
+}  // namespace mllm::cpu
diff --git a/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp b/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp
new file mode 100644
index 000000000..5073cb63b
--- /dev/null
+++ b/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp
@@ -0,0 +1,25 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/aops/Scatter2ShardsOp.hpp"
+
+namespace mllm::cpu {
+
+class CPUScatter2ShardsOp final : public aops::Scatter2ShardsOp {
+ public:
+  explicit CPUScatter2ShardsOp(const aops::Scatter2ShardsOpOptions& options);
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+};
+
+class CPUScatter2ShardsOpFactory : public TypedOpFactory<OpTypes::kScatter2Shards, aops::Scatter2ShardsOpOptions> {
+ public:
+  std::shared_ptr<BaseOp> createOpImpl(const aops::Scatter2ShardsOpOptions& options) override {
+    return std::make_shared<CPUScatter2ShardsOp>(options);
+  }
+};
+
+}  // namespace mllm::cpu
diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp
index 58c59dbf2..4adc215e5 100644
--- a/mllm/core/OpTypes.hpp
+++ b/mllm/core/OpTypes.hpp
@@ -73,6 +73,8 @@ enum class OpTypes : int32_t {
 
   // High-level Op or Fused Op
   kPagedAttn = 57,
+  kRadixAttn = 58,
+  kScatter2Shards = 59,
 
   // Dynamic Op Start for user to register their own ops.
   kDynamicOp_Start = 4096,
@@ -140,6 +142,8 @@ inline std::string optype2Str(OpTypes type) {
     case OpTypes::kGraphBegin: return "GraphBegin";
     case OpTypes::kGraphEnd: return "GraphEnd";
     case OpTypes::kPagedAttn: return "PagedAttn";
+    case OpTypes::kRadixAttn: return "RadixAttn";
+    case OpTypes::kScatter2Shards: return "Scatter2Shards";
     case OpTypes::kOpType_End: return "OpType_End";
     default: return "Unknown";
   }
diff --git a/mllm/core/Tensor.hpp b/mllm/core/Tensor.hpp
index a7b0a25bb..0011b871f 100644
--- a/mllm/core/Tensor.hpp
+++ b/mllm/core/Tensor.hpp
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include <span>
 #include
 #include
 #include
@@ -93,6 +94,47 @@ class Tensor {
     return tensor;
   }
 
+  template<typename T>
+  static inline Tensor fromVector(const std::span<T>& vec, const shape_t& shape, DataTypes dtype = kFloat32,
+                                  DeviceTypes device = kCPU) {
+    Tensor tensor = Tensor::empty(shape, dtype, device).alloc();
+    size_t tensor_size = tensor.numel();
+    if (vec.size() != tensor_size) {
+      MLLM_ERROR_EXIT(ExitCode::kShapeError, "Tensor size mismatch with std::vector size");
+      return Tensor::nil();
+    }
+    std::copy(vec.begin(), vec.end(), tensor.ptr<T>());
+    return tensor;
+  }
+
+  /**
+   * @brief Create a tensor from a std::vector, but reference the vector data, not copy it.
+   *
+   * @tparam T
+   * @param vec
+   * @param shape
+   * @param dtype
+   * @param device
+   * @return Tensor
+   */
+  template<typename T>
+  static inline Tensor refVectorData(const std::vector<T>& vec, const shape_t& shape, DataTypes dtype = kFloat32,
+                                     DeviceTypes device = kCPU) {
+    size_t expected_size = 1;
+    for (auto dim : shape) { expected_size *= dim; }
+
+    if (vec.size() != expected_size) {
+      MLLM_ERROR_EXIT(ExitCode::kShapeError, "Tensor shape mismatch with std::vector size");
+      return Tensor::nil();
+    }
+
+    Tensor tensor = Tensor::empty(shape, dtype, device);
+    tensor.impl_->storage()->ptr_ = const_cast<T*>(vec.data());
+    tensor.impl_->storage()->mem_type_ = kManual;
+
+    return tensor;
+  }
+
   template<typename T>
   inline std::vector<T> toVector() const {
     std::vector<T> vec;
@@ -522,7 +564,7 @@ class Tensor {
    * @return Typed base pointer.
    */
   template<typename T>
-  T* ptr() const {
+  [[nodiscard]] T* ptr() const {
     return impl_->ptr<T>();
   }
 
diff --git a/mllm/core/aops/PagedAttnOp.hpp b/mllm/core/aops/PagedAttnOp.hpp
index 43d97f5ff..675d6b41f 100644
--- a/mllm/core/aops/PagedAttnOp.hpp
+++ b/mllm/core/aops/PagedAttnOp.hpp
@@ -11,7 +11,6 @@ namespace mllm::aops {
 enum class PagedAttnImplType {
   kDefault = 0,
   kAllFp32 = 1,
-  kPrefixCache = 2,
 };
 
 struct PagedAttnOpOptions : public BaseOpOptions {
diff --git a/mllm/core/aops/RadixAttnOp.cpp b/mllm/core/aops/RadixAttnOp.cpp
new file mode 100644
index 000000000..4d9f33378
--- /dev/null
+++ b/mllm/core/aops/RadixAttnOp.cpp
@@ -0,0 +1,33 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/core/aops/RadixAttnOp.hpp"
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/Tensor.hpp"
+#include "mllm/utils/Common.hpp"
+
+namespace mllm::aops {
+
+RadixAttnOp::RadixAttnOp(const RadixAttnOpOptions& options) : BaseOp(OpTypes::kRadixAttn), options_(options) {}
+
+void RadixAttnOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; }
+
+void RadixAttnOp::trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  MLLM_WARN("RadixAttnOp::trace is not supported.");
+}
+
+void RadixAttnOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  NYI("RadixAttnOp::forward not implemented in aops base.");
+}
+
+void RadixAttnOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  auto& query = inputs[0];
+  auto& key = inputs[1];
+  auto& value = inputs[2];
+
+  outputs.emplace_back(Tensor::empty(query.shape(), query.dtype(), query.device()));
+}
+
+void RadixAttnOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { BaseOp::setup(inputs, outputs); }
+
+}  // namespace mllm::aops
diff --git a/mllm/core/aops/RadixAttnOp.hpp b/mllm/core/aops/RadixAttnOp.hpp
new file mode 100644
index 000000000..6b1e2428b
--- /dev/null
+++ b/mllm/core/aops/RadixAttnOp.hpp
@@ -0,0 +1,31 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/ParameterFile.hpp"
+
+namespace mllm::aops {
+
+struct RadixAttnOpOptions : public BaseOpOptions {};
+
+class RadixAttnOp : public BaseOp {
+ public:
+  explicit RadixAttnOp(const RadixAttnOpOptions& options);
+
+  void load(const ParameterFile::ptr_t& ploader) override;
+
+  void trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+ protected:
+  RadixAttnOpOptions options_;
+};
+
+}  // namespace mllm::aops
diff --git a/mllm/core/aops/Scatter2ShardsOp.cpp b/mllm/core/aops/Scatter2ShardsOp.cpp
new file mode 100644
index 000000000..bfcc8783b
--- /dev/null
+++ b/mllm/core/aops/Scatter2ShardsOp.cpp
@@ -0,0 +1,35 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/core/aops/Scatter2ShardsOp.hpp"
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/Tensor.hpp"
+#include "mllm/utils/Common.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+
+namespace mllm::aops {
+
+Scatter2ShardsOp::Scatter2ShardsOp(const Scatter2ShardsOpOptions& options)
+    : BaseOp(OpTypes::kScatter2Shards), options_(options) {}
+
+void Scatter2ShardsOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; }
+
+void Scatter2ShardsOp::trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  MLLM_WARN("Scatter2ShardsOp::trace is not supported in v2.0.0 yet. Please file a feature request issue on GitHub "
+            "if you need it.");
+}
+
+void Scatter2ShardsOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  NYI("Scatter2ShardsOp::forward not implemented in aops base.");
+}
+
+void Scatter2ShardsOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  // The scatter op has no outputs; it only consumes inputs.
+  MLLM_EMPTY_SCOPE;
+}
+
+void Scatter2ShardsOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  BaseOp::setup(inputs, outputs);
+}
+
+}  // namespace mllm::aops
diff --git a/mllm/core/aops/Scatter2ShardsOp.hpp b/mllm/core/aops/Scatter2ShardsOp.hpp
new file mode 100644
index 000000000..a45c0642f
--- /dev/null
+++ b/mllm/core/aops/Scatter2ShardsOp.hpp
@@ -0,0 +1,33 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/ParameterFile.hpp"
+
+namespace mllm::aops {
+
+struct Scatter2ShardsOpOptions : public BaseOpOptions {
+  int dim = 0;
+};
+
+class Scatter2ShardsOp : public BaseOp {
+ public:
+  explicit Scatter2ShardsOp(const Scatter2ShardsOpOptions& options);
+
+  void load(const ParameterFile::ptr_t& ploader) override;
+
+  void trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+ protected:
+  Scatter2ShardsOpOptions options_;
+};
+
+}  // namespace mllm::aops
diff --git a/mllm/engine/service/Service.hpp b/mllm/engine/service/Service.hpp
index 5b87968e2..54321678f 100644
--- a/mllm/engine/service/Service.hpp
+++ b/mllm/engine/service/Service.hpp
@@ -4,6 +4,7 @@
 
 #include
 #include
+#include
 #include
 #include
 
diff --git a/mllm/engine/service/Session.cpp b/mllm/engine/service/Session.cpp
index 2089eb5a0..818fc4b07 100644
--- a/mllm/engine/service/Session.cpp
+++ b/mllm/engine/service/Session.cpp
@@ -5,12 +5,15 @@
 #include
 #include
 
+#include "mllm/utils/Common.hpp"
 #include "mllm/engine/service/Session.hpp"
 
 namespace mllm::service {
 
 Session::Session(std::shared_ptr model) : model_(std::move(model)) {}
 
+void Session::fromPreTrain(const std::string& model_path) { MLLM_EMPTY_SCOPE; }
+
 NoneSession::NoneSession() : Session(nullptr) {}
 
 void NoneSession::streamGenerate(const nlohmann::json& request,
diff --git a/mllm/engine/service/Session.hpp b/mllm/engine/service/Session.hpp
index 339d6cb60..1f1efe1bf 100644
--- a/mllm/engine/service/Session.hpp
+++ b/mllm/engine/service/Session.hpp
@@ -17,6 +17,8 @@ class Session {
   virtual void streamGenerate(const nlohmann::json& request,
                               const std::function& callback) = 0;
 
+  virtual void fromPreTrain(const std::string& model_path);
+
  private:
   std::shared_ptr model_;
 };
diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp 
b/mllm/models/qwen3/modeling_qwen3_service.hpp
index 09b8cf8d2..da9a01ddc 100644
--- a/mllm/models/qwen3/modeling_qwen3_service.hpp
+++ b/mllm/models/qwen3/modeling_qwen3_service.hpp
@@ -1,15 +1,25 @@
 // Copyright (c) MLLM Team.
 // Licensed under the MIT License.
 
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
 #include "mllm/mllm.hpp"
 #include "mllm/nn/Nn.hpp"
 #include "mllm/nn/Module.hpp"
 #include "mllm/nn/Functional.hpp"
-#include "mllm/nn/lmcache/StaticCache.hpp"
-#include "mllm/models/qwen3/configuration_qwen3.hpp"
 #include "mllm/utils/Enumerate.hpp"
 #include "mllm/models/ARGeneration.hpp"
+#include "mllm/models/qwen3/tokenization_qwen3.hpp"
+#include "mllm/models/qwen3/configuration_qwen3.hpp"
+
+// Service related.
 #include "mllm/engine/service/Session.hpp"
+#include "mllm/engine/prefix_cache/Cache.hpp"
 
 namespace mllm::models::qwen3 {
 
@@ -103,8 +113,6 @@ class Qwen3Attention final : public nn::Module {
   nn::RMSNorm rms_norm_k_;
   nn::RoPE q_rope_;
   nn::RoPE k_rope_;
-  nn::CausalMask mask_;
-  nn::Softmax softmax_;
 
   int hidden_size_;
   int head_dim_;
@@ -136,16 +144,15 @@ class Qwen3Attention final : public nn::Module {
 
     q_rope_ = reg<nn::RoPE>("q_rope", cfg.rope_theta, cfg.max_position_embeddings);
     k_rope_ = reg<nn::RoPE>("k_rope", cfg.rope_theta, cfg.max_position_embeddings);
-
-    mask_ = reg<nn::CausalMask>("mask");
-    softmax_ = reg<nn::Softmax>("softmax", -1);
   }
 
   std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
     auto x = inputs[0];
     auto llm_embedding_sin = inputs[1];
     auto llm_embedding_cos = inputs[2];
-    auto past_kv_cache = args[0].get<nn::StaticCache*>();
+    auto k_cache_addr = args[0].get<std::vector<std::vector<prefix_cache::vp_addr_t>>*>();
+    auto v_cache_addr = args[1].get<std::vector<std::vector<prefix_cache::vp_addr_t>>*>();
+    auto prefix_cache_context = args[2].get<prefix_cache::Cache*>();
 
     // [B, S, H * D]
     auto query_states = q_proj_(x);
@@ -167,39 +174,54 @@ class Qwen3Attention final : public nn::Module {
     // [B, H, S, D]
     query_states = query_states.transpose(1, 2);
     key_states = key_states.transpose(1, 2);
-    value_states = value_states.transpose(1, 2);
+    value_states = value_states.transpose(1, 2).to(kFloat16);
 
     // [B, H, S, D]
-    query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos);
-    key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos);
+    query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos).to(kFloat16);
+    key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos).to(kFloat16);
 
-    // [B, H, S, D]
-    auto [key_states_new, value_states_new] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states);
-    key_states = key_states_new;
-    value_states = value_states_new;
-
-    Tensor attn;
-    if (key_states.dtype() == kFloat32) {
-      // attention weight
-      // [B, H, S, S]
-      attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_));
-      attn = mask_(attn);
-      attn = softmax_(attn);
-    } else if (key_states.dtype() == kFloat16) {
-      attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_));
-      attn = mask_(attn);
-      attn = softmax_(attn);
-      attn = attn.to(kFloat16);
+    // FIXME: Reconsider whether applying RoPE before caching is OK.
+
+    // Acquire cache
+    std::vector<prefix_cache::vp_addr_t> k_addr_wait_for_promote;
+    std::vector<prefix_cache::vp_addr_t> v_addr_wait_for_promote;
+    for (int s_idx = 0; s_idx < S; ++s_idx) {
+      k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(DeviceTypes::kCPU));
+      v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(DeviceTypes::kCPU));
     }
 
-    // attn output
-    // [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D]
-    auto output = nn::functional::matmul(attn, value_states);
-    // [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D]
-    output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_});
-    output = o_proj_(output);
+    // Prepare indices cache. sizeof(char*) == 8 == sizeof(int64_t)
+    std::vector<int64_t> k_phy_addr_wait_for_promote;
+    std::vector<int64_t> v_phy_addr_wait_for_promote;
+    for (int s_idx = 0; s_idx < S; ++s_idx) {
+      k_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(k_addr_wait_for_promote[s_idx]));
+      v_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(v_addr_wait_for_promote[s_idx]));
+    }
+    auto k_wait_for_promote = Tensor::refVectorData(k_phy_addr_wait_for_promote, {S}, kInt64, kCPU);
+    auto v_wait_for_promote = Tensor::refVectorData(v_phy_addr_wait_for_promote, {S}, kInt64, kCPU);
+
+    // Copy key_states and value_states to cache
+    nn::functional::scatter2Shards(key_states, k_wait_for_promote, 2);
+    nn::functional::scatter2Shards(value_states, v_wait_for_promote, 2);
+
+    // Gather the whole cache into index tensors
+    std::ranges::copy((*k_cache_addr)[layer_idx_], std::back_inserter(k_addr_wait_for_promote));
+    std::ranges::copy((*v_cache_addr)[layer_idx_], std::back_inserter(v_addr_wait_for_promote));
+    std::vector<int64_t> k_phy_cache_indices;
+    std::vector<int64_t> v_phy_cache_indices;
+    int32_t kv_cache_len = (*k_cache_addr)[layer_idx_].size();
+    k_phy_cache_indices.reserve(kv_cache_len);
+    v_phy_cache_indices.reserve(kv_cache_len);
+    for (int i = 0; i < kv_cache_len; ++i) {
+      k_phy_cache_indices.push_back(prefix_cache_context->physicalAddr((*k_cache_addr)[layer_idx_][i]));
+      v_phy_cache_indices.push_back(prefix_cache_context->physicalAddr((*v_cache_addr)[layer_idx_][i]));
+    }
+    auto k_cache = Tensor::refVectorData(k_phy_cache_indices, {kv_cache_len}, kInt64, kCPU);
+    auto v_cache = Tensor::refVectorData(v_phy_cache_indices, {kv_cache_len}, kInt64, kCPU);
 
-    return {output};
+    // TODO Do Radix Attention
+
+    return {};
   }
 
   int layer_idx_;
@@ -224,10 +246,12 @@ class Qwen3Decoder final : public nn::Module {
   std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
     auto llm_embedding_sin = inputs[1];
     auto llm_embedding_cos = inputs[2];
-    auto& kv_cache = args[0];
+    auto& k_cache_addr = args[0];
+    auto& v_cache_addr = args[1];
+    auto& prefix_cache_context = args[2];
 
     auto x = input_layer_norm_(inputs[0]);
-    x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0];
+    x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, k_cache_addr, v_cache_addr, prefix_cache_context)[0];
     auto tmp = x + inputs[0];
     x = post_attention_layer_norm_(tmp);
     x = mlp_(x)[0];
@@ -259,9 +283,13 @@ class Qwen3Text final : public nn::Module {
     auto llm_embedding_sin = inputs[1];
     auto llm_embedding_cos = inputs[2];
 
-    auto& kv_cache = args[0];
+    auto& k_cache_addr = args[0];
+    auto& v_cache_addr = args[1];
+    auto& prefix_cache_context = args[2];
 
-    for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; }
+    for (auto& block : blocks) {
+      x = block(x, llm_embedding_sin, llm_embedding_cos, k_cache_addr, v_cache_addr, prefix_cache_context)[0];
+    }
 
     x = norm_(x);
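[Editorial aside, not part of the patch: the per-token cache plumbing that Qwen3Attention::forward acquires above follows one fixed pattern: allocate a virtual page per new token, resolve it to a physical address, wrap the address list in a rank-1 kInt64 tensor, and scatter the K/V slices into those shards. A condensed sketch of just that pattern follows. alloc, physicalAddr, and vp_addr_t are used exactly as in the hunk above; the helper function, its name, and the assumption that physicalAddr yields an int64-compatible address are illustrative only.]

#include <cstdint>
#include <vector>
#include "mllm/core/Tensor.hpp"
#include "mllm/nn/Functional.hpp"
#include "mllm/engine/prefix_cache/Cache.hpp"

// Sketch: allocate S fresh K shards, resolve them to physical addresses, and
// scatter key_states ([B, H, S, D]) into them along the sequence dimension.
// The returned virtual handles are what the caller later hands to promote().
std::vector<mllm::prefix_cache::vp_addr_t> cacheNewTokens(mllm::prefix_cache::Cache* ctx,
                                                          mllm::Tensor& key_states /* [B, H, S, D] */) {
  const int32_t S = key_states.shape()[2];
  std::vector<mllm::prefix_cache::vp_addr_t> vaddrs;  // virtual handles, kept for promote()
  std::vector<int64_t> paddrs;                        // raw addresses consumed by the kernels
  vaddrs.reserve(S);
  paddrs.reserve(S);
  for (int32_t s = 0; s < S; ++s) {
    vaddrs.push_back(ctx->alloc(mllm::DeviceTypes::kCPU));
    paddrs.push_back(ctx->physicalAddr(vaddrs.back()));
  }
  // refVectorData references paddrs without copying, so the vector must stay
  // alive while the tensor is used; here it outlives the scatter call below.
  auto shards = mllm::Tensor::refVectorData(paddrs, {S}, mllm::kInt64, mllm::kCPU);
  mllm::nn::functional::scatter2Shards(key_states, shards, 2);  // dim 2 == S in [B, H, S, D]
  return vaddrs;
}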
@@ -272,21 +300,13 @@ class Qwen3Text final : public nn::Module {
 class Qwen3ForCausalLM : public ARGeneration, public nn::Module {
  public:
   explicit Qwen3ForCausalLM(const Qwen3Config& cfg) : cfg(cfg) {
-    kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers,
-                                cfg.num_attention_heads,  // q_heads
-                                cfg.num_key_value_heads,  // kv_heads
-                                cfg.head_dim,             // kv_dim
-                                kFloat32,                 // k_dtype
-                                kFloat32,                 // v_dtype
-                                kCPU,                     // device_type
-                                false                     // use_fa2
-    );
     eos_token_id_ = cfg.end_of_text_token_id;
     max_length_ = cfg.max_cache_length;
     tie_word_embeddings_ = cfg.tie_word_embeddings;
 
     llm = reg<Qwen3Text>("model", cfg);
 
+    // Qwen3 0.6B's lm_head is tied to embed_tokens, but ModelScope's official weights ship them separately.
     if (cfg.tie_word_embeddings) {
       // NOTE:
       // model.lm_head.weight is quantization weights of model.embed_tokens.weight
@@ -317,6 +337,9 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module {
         *position_ids.offsettedPtr<int64_t>({0, 0}) = last_pos + 1;
       }
     } else {
+      // NOTE: A service Session must never reach this branch!
+      MLLM_RT_ASSERT(false);
+
       // Generate position_ids for prefill phase
       position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc();
       auto position_ids_ptr = position_ids.ptr<int64_t>();
@@ -328,7 +351,8 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module {
     // Generate RoPE embeddings using the inv_freq buffer
     auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f);
 
-    sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0];
+    sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, args.at("k_cache_addrs"), args.at("v_cache_addrs"),
+                   args.at("prefix_cache_context"))[0];
 
     // clip x to one seq length
     {
@@ -348,12 +372,238 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module {
   Qwen3Text llm;
   nn::Linear lm_head_;
   bool tie_word_embeddings_;
-  nn::StaticCache kv_cache_;
 };
 
-// class Qwen3Session final : public ::mllm::service::Session {
-//  public:
-//  private:
-// };
+class Qwen3Session final : public ::mllm::service::Session {
+ public:
+  Qwen3Session();
+
+  void streamGenerate(const nlohmann::json& request,
+                      const std::function& callback) override {
+    const auto& messages = request["messages"];
+    auto inputs = applyChatTemplate(messages, nullptr, true, request.value("enable_thinking", false));
+
+    Dbg("prompt", inputs);
+
+    auto full_seq_idx = tokenizer_->convert2Ids(tokenizer_->tokenize(inputs)).toVector<int64_t>();
+    ARGenerationArgs args;
+    ARGenerationOutputPast input;
+
+    // Search the radix cache to find the tokens we really need to compute.
+    auto prefix_cache_result = cache_->find(full_seq_idx);
+    std::span<int64_t> reduced_seq_idx(full_seq_idx.data() + prefix_cache_result.matched_length,
+                                       full_seq_idx.size() - prefix_cache_result.matched_length);
+    std::vector<int64_t> position_ids;
+    {
+      auto start = prefix_cache_result.matched_length;
+      auto end = full_seq_idx.size();
+      position_ids.reserve(end - start);
+      std::ranges::copy(std::views::iota(static_cast<int64_t>(start), static_cast<int64_t>(end)),
+                        std::back_inserter(position_ids));
+    }
+    MLLM_RT_ASSERT_EQ(reduced_seq_idx.size(), position_ids.size());
+    input["sequence"] = Tensor::fromVector(reduced_seq_idx, {(int32_t)reduced_seq_idx.size()}, kInt64, kCPU);
+    input["position_ids"] = Tensor::fromVector(position_ids, {(int32_t)position_ids.size()}, kInt64, kCPU);
+
+    // Setup session context
+    k_cache_addrs_ = prefix_cache_result.k_cache_addresses;
+    v_cache_addrs_ = prefix_cache_result.v_cache_addresses;
+    args["k_cache_addrs"] = &k_cache_addrs_;
+    args["v_cache_addrs"] = &v_cache_addrs_;
+    args["prefix_cache_context"] = cache_.get();
+
+    // TODO: other sampling args should be read from the request.
+
+    // Iteration start
+    model_->streamGenerate(input, args, [](int64_t idx) {
+      // TODO Callback
+    });
+
+    // TODO: process full_seq_idx and k_cache_addrs_/v_cache_addrs_. Only non-thinking content should be inserted
+    // into the radix tree.
+
+    // Insert generated tokens to the cache.
+    cache_->promote(full_seq_idx, k_cache_addrs_, v_cache_addrs_);
+
+    // Clean up session context
+    k_cache_addrs_ = {};
+    v_cache_addrs_ = {};
+  }
+
+  void fromPreTrain(const std::string& model_path) override {
+    namespace fs = std::filesystem;
+    fs::path root = fs::path(model_path).lexically_normal();
+    fs::path config_file = root / "config.json";
+    fs::path model_file = root / "model.mllm";
+    fs::path tokenizer_file = root / "tokenizer.json";
+    if (!fs::exists(config_file)) throw std::runtime_error("config.json not found");
+    if (!fs::exists(model_file)) throw std::runtime_error("model.mllm not found");
+    if (!fs::exists(tokenizer_file)) throw std::runtime_error("tokenizer.json not found");
+
+    auto cfg = Qwen3Config(config_file.string());
+    model_ = std::make_shared<Qwen3ForCausalLM>(cfg);
+    model_->load(load(model_file.string()));
+    tokenizer_ = std::make_shared<Qwen3Tokenizer>(tokenizer_file.string());
+
+    cache_ = std::make_shared<prefix_cache::Cache>(prefix_cache::CacheOptions{
+        .radix_tree_options = {.enable_lru_eviction = false,
+                               .eviction_threshold = 0.9f,
+                               .enable_path_compression = false,
+                               .min_compression_length = 2,
+                               .transformer_blocks_num = cfg.num_hidden_layers},
+        .allocator_options = {// Normal things.
+                              .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads),
+                              .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads),
+                              .k_dtype = mllm::kFloat16,
+                              .v_dtype = mllm::kFloat16,
+
+                              // CUDA things.
+                              .enable_cuda = false,
+                              .cuda_mem_base = 0x100000,
+
+                              // CPU things.
+                              .enable_cpu_hierarchy_memory = true,
+                              .zen_fs_options = {
+                                  .record = false,
+                                  .working_dir = ".",
+                                  .blob_bits_size = 20,
+                                  .page_bits = 7,
+                                  .lane_bits = 5,
+                                  .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads),
+                                  .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads),
+                                  .k_dtype = mllm::kFloat16,
+                                  .v_dtype = mllm::kFloat16,
+                                  .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous,
+                              }}});
+  }
+
+  std::string trim(const std::string& s) {
+    auto beg = s.find_first_not_of(" \t\r\n");
+    if (beg == std::string::npos) return "";
+    auto end = s.find_last_not_of(" \t\r\n");
+    return s.substr(beg, end - beg + 1);
+  }
+
+  // [{"role": "user", "content": "Say this is a test!"}]
+  // [{"role": "assistant", "content": "This is a test!", "reasoning": "You are absolutely right!"}]
+  //
+  std::string applyChatTemplate(const json& messages, const json* tools = nullptr, bool add_generation_prompt = true,
+                                bool enable_thinking = true, const std::string& bos_token = "",
+                                const std::string& eos_token = "<|im_end|>") {
+    std::ostringstream out;
+
+    if (tools && tools->is_array() && !tools->empty()) {
+      out << "<|im_start|>system\n";
+      if (messages.is_array() && !messages.empty() && messages[0].contains("role") && messages[0]["role"] == "system") {
+        out << messages[0]["content"].get<std::string>() << "\n\n";
+      }
+      out << "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+             "You are provided with function signatures within <tools></tools> XML tags:\n<tools>";
+      for (auto& t : *tools) out << "\n" << t.dump();
+      out << "\n</tools>\n\n"
+             "For each function call, return a json object with function name and arguments "
+             "within <tool_call></tool_call> XML tags:\n<tool_call>\n"
+             "{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
+             "<|im_end|>\n";
+    } else {
+      if (messages.is_array() && !messages.empty() && messages[0].contains("role") && messages[0]["role"] == "system") {
+        out << "<|im_start|>system\n" << messages[0]["content"].get<std::string>() << "<|im_end|>\n";
+      }
+    }
+
+    size_t last_query_index = messages.size() - 1;
+    bool multi_step_tool = true;
+    if (messages.is_array()) {
+      for (int i = static_cast<int>(messages.size()) - 1; i >= 0; --i) {
+        auto& m = messages[i];
+        if (multi_step_tool && m.contains("role") && m["role"] == "user" && m.contains("content") &&
+            m["content"].is_string()) {
+          std::string c = m["content"];
+          bool is_tr = c.starts_with("<tool_response>") && c.ends_with("</tool_response>");
+          if (!is_tr) {
+            multi_step_tool = false;
+            last_query_index = i;
+          }
+        }
+      }
+    }
+
+    if (!messages.is_array()) return out.str();
+    for (size_t idx = 0; idx < messages.size(); ++idx) {
+      auto& m = messages[idx];
+      std::string role, content;
+      if (m.contains("role")) role = m["role"];
+      if (m.contains("content") && m["content"].is_string()) content = m["content"];
+
+      if (role == "user" || (role == "system" && idx != 0)) {
+        out << "<|im_start|>" << role << "\n" << content << eos_token << "\n";
+        continue;
+      }
+
+      if (role == "assistant") {
+        std::string reasoning_content;
+        if (m.contains("reasoning_content") && m["reasoning_content"].is_string()) {
+          reasoning_content = m["reasoning_content"];
+        } else {
+          size_t pos_end = content.find("</think>");
+          if (pos_end != std::string::npos) {
+            size_t pos_beg = content.rfind("<think>", pos_end);
+            if (pos_beg != std::string::npos) {
+              reasoning_content = trim(content.substr(pos_beg + 7, pos_end - pos_beg - 7));
+              content.erase(pos_beg, pos_end - pos_beg + 8);
+              content = trim(content);
+            }
+          }
+        }
+
+        bool in_last_turn = (idx > last_query_index);
+        bool need_think = in_last_turn && (idx + 1 == messages.size() || !reasoning_content.empty());
+
+        out << "<|im_start|>assistant\n";
+        if (need_think) { out << "<think>\n" << reasoning_content << "\n</think>\n\n"; }
+        out << content;
+
+        if (m.contains("tool_calls") && m["tool_calls"].is_array()) {
+          for (auto& tc : m["tool_calls"]) {
+            json fn = tc.contains("function") ? tc["function"] : tc;
+            std::string name = fn.value("name", "");
+            std::string args = fn.value("arguments", "");
+            if (fn["arguments"].is_object() || fn["arguments"].is_array()) args = fn["arguments"].dump();
+            out << "\n<tool_call>\n"
+                << R"({"name": ")" << name << R"(", "arguments": )" << args << "}\n"
+                << "</tool_call>";
+          }
+        }
+        out << eos_token << "\n";
+        continue;
+      }
+
+      if (role == "tool") {
+        bool first_tool = (idx == 0) || !messages[idx - 1].contains("role") || messages[idx - 1]["role"] != "tool";
+        bool last_tool =
+            (idx + 1 == messages.size()) || !messages[idx + 1].contains("role") || messages[idx + 1]["role"] != "tool";
+        if (first_tool) out << "<|im_start|>user";
+        out << "\n<tool_response>\n" << content << "\n</tool_response>";
+        if (last_tool) out << eos_token << "\n";
+      }
+    }
+
+    if (add_generation_prompt) {
+      out << "<|im_start|>assistant\n";
+      if (!enable_thinking) { out << "<think>\n\n</think>\n\n"; }
+    }
+
+    return out.str();
+  }
+
+ private:
+  // States
+  std::vector<std::vector<prefix_cache::vp_addr_t>> k_cache_addrs_;
+  std::vector<std::vector<prefix_cache::vp_addr_t>> v_cache_addrs_;
+
+  // Owned data
+  std::shared_ptr<Qwen3ForCausalLM> model_;
+  std::shared_ptr<Qwen3Tokenizer> tokenizer_;
+  std::shared_ptr<prefix_cache::Cache> cache_;
+};
 
 }  // namespace mllm::models::qwen3
diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp
index e4697d7ee..c6803af9c 100644
--- a/mllm/nn/Functional.cpp
+++ b/mllm/nn/Functional.cpp
@@ -7,6 +7,7 @@
 #include "mllm/core/aops/FlashAttention2Op.hpp"
 #include "mllm/core/aops/MatMulOp.hpp"
 #include "mllm/core/aops/ReduceOps.hpp"
+#include "mllm/core/aops/Scatter2ShardsOp.hpp"
 #include "mllm/core/aops/SoftmaxOp.hpp"
 #include "mllm/core/aops/ElewiseOps.hpp"
 #include "mllm/core/aops/SplitOp.hpp"
@@ -112,4 +113,9 @@ Tensor silu_(const Tensor& x) {
   return Context::instance().buildOpAndSubmitTask(OpTypes::kSiLU, opt, {x})[0];
 }
 
+void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim) {
+  Context::instance().buildOpAndSubmitTask(OpTypes::kScatter2Shards, aops::Scatter2ShardsOpOptions{.dim = dim},
+                                           {src, shards_pointer});
+}
+
 }  // namespace mllm::nn::functional
diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp
index bd88ab836..9efad0ab2 100644
--- a/mllm/nn/Functional.hpp
+++ b/mllm/nn/Functional.hpp
@@ -127,4 +127,6 @@ Tensor mean(const Tensor& x, int32_t dim = std::numeric_limits<int32_t>::max(),
 Tensor silu(const Tensor& x);
 Tensor silu_(const Tensor& x);
 
+void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim);
+
 }  // namespace mllm::nn::functional
diff --git a/mllm/nn/layers/PagedAttn.cpp b/mllm/nn/layers/PagedAttn.cpp
index 84484588c..5aa664135 100644
--- a/mllm/nn/layers/PagedAttn.cpp
+++ b/mllm/nn/layers/PagedAttn.cpp
@@ -1,7 +1,6 @@
 // Copyright (c) MLLM Team.
 // Licensed under the MIT License.
#include "mllm/core/aops/PagedAttnOp.hpp" -#include "mllm/utils/Common.hpp" #include "mllm/nn/layers/PagedAttn.hpp" namespace mllm::nn { @@ -18,14 +17,4 @@ PagedAttn::PagedAttn(int32_t head_repeat_times, bool high_precision_exp, bool fu .need_attn_weights = need_attn_weights, .impl_type = impl_type}) {} -PagedAttn::PagedAttn(void* ctx, bool high_precision_exp, bool fuse_rope, aops::PagedAttnImplType impl_type) - : Layer(OpTypes::kPagedAttn, aops::PagedAttnOpOptions{.head_repeat_times = -1, - .high_precision_exp = high_precision_exp, - .fuse_rope = fuse_rope, - .need_attn_weights = false, - .impl_type = impl_type, - .prefix_cache_ctx = ctx}) { - if (ctx == nullptr) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "prefix_cache_ctx is empty."); } -} - } // namespace mllm::nn diff --git a/mllm/nn/layers/PagedAttn.hpp b/mllm/nn/layers/PagedAttn.hpp index 4aab7396d..b4b1be4a8 100644 --- a/mllm/nn/layers/PagedAttn.hpp +++ b/mllm/nn/layers/PagedAttn.hpp @@ -17,10 +17,6 @@ class PagedAttn : public Layer { explicit PagedAttn(int32_t head_repeat_times, bool high_precision_exp = false, bool fuse_rope = false, bool need_attn_weights = false, aops::PagedAttnImplType impl_type = aops::PagedAttnImplType::kAllFp32); - // Prefixed Cache Paged Attn Constructor - explicit PagedAttn(void* ctx, bool high_precision_exp = false, bool fuse_rope = false, - aops::PagedAttnImplType impl_type = aops::PagedAttnImplType::kPrefixCache); - MLLM_LAYER_ANY_INPUTS_2_OUTPUTS_FORWARD }; From 60580ab8f72d1290981f9ab5c1da645e63b39812 Mon Sep 17 00:00:00 2001 From: chenghuaWang Date: Tue, 7 Oct 2025 07:28:12 +0000 Subject: [PATCH 2/8] feat(cpu): add RadixAttnOp and Scatter2ShardsOp support - Register `RadixAttnOp` and `Scatter2ShardsOp` in CPU backend - Implement forward logic for `RadixAttnOp` with architecture-specific kernels - Update `RadixAttnOpOptions` to include head count parameters - Fix const-correctness and template usage in radix attention kernel - Remove unnecessary `.to(kFloat16)` calls in Qwen3 attention module - Adjust function signatures for better type safety in radix attention kernel - Refactor kernel calls to use namespaced static dispatch --- mllm/backends/cpu/CPUBackend.cpp | 4 +- .../kernels/common/radix_attn/fwd_bshd.hpp | 10 ++--- mllm/backends/cpu/ops/RadixAttnOp.cpp | 38 ++++++++++++++++++- mllm/core/aops/RadixAttnOp.hpp | 5 ++- mllm/models/qwen3/modeling_qwen3_service.hpp | 6 +-- 5 files changed, 52 insertions(+), 11 deletions(-) diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index b2794df34..c94871b90 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -17,6 +17,7 @@ #include "mllm/backends/cpu/ops/FillOp.hpp" #include "mllm/backends/cpu/ops/FlashAttention2Op.hpp" #include "mllm/backends/cpu/ops/GELUOp.hpp" +#include "mllm/backends/cpu/ops/RadixAttnOp.hpp" #include "mllm/backends/cpu/ops/ReLUOp.hpp" #include "mllm/backends/cpu/ops/GraphOps.hpp" #include "mllm/backends/cpu/ops/ISTFTOp.hpp" @@ -35,6 +36,7 @@ #include "mllm/backends/cpu/ops/RepeatOp.hpp" #include "mllm/backends/cpu/ops/RoPEOp.hpp" #include "mllm/backends/cpu/ops/STFTOp.hpp" +#include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" #include "mllm/backends/cpu/ops/SiLUOp.hpp" #include "mllm/backends/cpu/ops/SliceOp.hpp" #include "mllm/backends/cpu/ops/SoftmaxOp.hpp" @@ -58,7 +60,7 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, 
                                 CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory,
                                 CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory,
-                                CPUKVCacheOpFactory, CPUPagedAttnOpFactory>();
+                                CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory,
+                                CPURadixAttnOpFactory>();
 }
 
 std::shared_ptr<CPUBackend> createCPUBackend() { return std::make_shared<CPUBackend>(); }
diff --git a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
index 1312146ec..11c42157f 100644
--- a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
+++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
@@ -29,7 +29,7 @@ namespace mllm::cpu::radix_attn {
 template
 void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
-              const __KDType** __k, const __VDType** __v, __ODType* __restrict__ __out, int32_t thread_count) {
+              __KDType** __k, __VDType** __v, __ODType* __restrict__ __out, int32_t thread_count) {
   int32_t head_repeat_times = H_Q / H_KV;
   __AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;
@@ -42,7 +42,7 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
     // FA2's Loop
     for (int s_q_idx = 0; s_q_idx < S_Q; ++s_q_idx) {
-      __QDType* q_token = __q + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
+      const __QDType* q_token = __q + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
       __ODType* acc_o = __out + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
 
       for (int d_idx = 0; d_idx < D; ++d_idx) { acc_o[d_idx] = 0; }
@@ -78,14 +78,14 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
         logsum = logsum * scores_scale + scores_sum;
 
         // 3. Scale
-        MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, scores_scale, D);
+        details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D);
 
         // 4. MMA1.
-        FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>(acc_o, acc_s, v_token, D);
+        details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(acc_o, acc_s, v_token, D);
       }
 
       // 5. Final Rescale.
-        MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, (1.f / logsum), D);
+        details::MulFromConst<__ArchTag, __ODType, __AccDType>::run(acc_o, (1.f / logsum), D);
     }
   });
 }
diff --git a/mllm/backends/cpu/ops/RadixAttnOp.cpp b/mllm/backends/cpu/ops/RadixAttnOp.cpp
index 413dc9068..5f7ff6c6f 100644
--- a/mllm/backends/cpu/ops/RadixAttnOp.cpp
+++ b/mllm/backends/cpu/ops/RadixAttnOp.cpp
@@ -3,12 +3,48 @@
 
 #include "mllm/core/Tensor.hpp"
 #include "mllm/backends/cpu/ops/RadixAttnOp.hpp"
+#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp"
 #include "mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp"
+#include "mllm/utils/Common.hpp"
 
 namespace mllm::cpu {
 
 CPURadixAttnOp::CPURadixAttnOp(const aops::RadixAttnOpOptions& options) : aops::RadixAttnOp(options) {}
 
-void CPURadixAttnOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {}
+void CPURadixAttnOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  const auto& Q = inputs[0];
+  const auto& K_PTR = inputs[1];
+  const auto& V_PTR = inputs[2];
+  const auto& OUT = outputs[0];
+
+  MLLM_RT_ASSERT(K_PTR.dtype() == kInt64 && V_PTR.dtype() == kInt64 && K_PTR.rank() == 1 && V_PTR.rank() == 1);
+  auto B = Q.shape()[0];
+  auto H_Q = Q.shape()[1];
+  auto S_Q = Q.shape()[2];
+  auto D = Q.shape()[3];
+  MLLM_RT_ASSERT_EQ(H_Q, options_.H_Q);
+  auto S_KV = K_PTR.shape()[0];
+  MLLM_RT_ASSERT_EQ(S_KV, V_PTR.shape()[0]);
+
+  switch (Q.dtype()) {
+    case kFloat32: {
+#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH)
+      cpu::radix_attn::fwd_bhsd(B, options_.H_Q, options_.H_KV, S_Q, S_KV, D, Q.ptr<mllm_fp32_t>(),
+                                K_PTR.ptr<mllm_fp32_t*>(), V_PTR.ptr<mllm_fp32_t*>(), OUT.ptr<mllm_fp32_t>(),
+                                options_.getThreads());
+#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64)
+      cpu::radix_attn::fwd_bhsd(B, options_.H_Q, options_.H_KV, S_Q, S_KV, D, Q.ptr<mllm_fp32_t>(),
+                                K_PTR.ptr<mllm_fp32_t*>(), V_PTR.ptr<mllm_fp32_t*>(), OUT.ptr<mllm_fp32_t>(),
+                                options_.getThreads());
+#endif
+      break;
+    }
+    default: {
+      NYI("RadixAttnOp not supported for this data type");
+    }
+  }
+}
 
 }  // namespace mllm::cpu
diff --git a/mllm/core/aops/RadixAttnOp.hpp b/mllm/core/aops/RadixAttnOp.hpp
index 6b1e2428b..dbadffc59 100644
--- a/mllm/core/aops/RadixAttnOp.hpp
+++ b/mllm/core/aops/RadixAttnOp.hpp
@@ -8,7 +8,10 @@
 
 namespace mllm::aops {
 
-struct RadixAttnOpOptions : public BaseOpOptions {};
+struct RadixAttnOpOptions : public BaseOpOptions {
+  int32_t H_Q;
+  int32_t H_KV;
+};
 
 class RadixAttnOp : public BaseOp {
  public:
diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp
index da9a01ddc..7fd138fb1 100644
--- a/mllm/models/qwen3/modeling_qwen3_service.hpp
+++ b/mllm/models/qwen3/modeling_qwen3_service.hpp
@@ -174,11 +174,11 @@ class Qwen3Attention final : public nn::Module {
     // [B, H, S, D]
     query_states = query_states.transpose(1, 2);
     key_states = key_states.transpose(1, 2);
-    value_states = value_states.transpose(1, 2).to(kFloat16);
+    value_states = value_states.transpose(1, 2);
 
     // [B, H, S, D]
-    query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos).to(kFloat16);
-    key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos).to(kFloat16);
+    query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos);
+    key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos);
 
     // FIXME: Reconsider whether applying RoPE before caching is OK.
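[Editorial aside between the patches, not part of either: the contract that Scatter2ShardsOp and Tensor::refVectorData establish above is easiest to see in isolation. Below is a minimal, self-contained sketch; the shapes, shard buffers, and the standalone main are illustrative, and only empty/alloc, refVectorData, and scatter2Shards come from the patches. For dim = 2 on a [B, H, S, D] source, the op copies, for each token s, B * H slices of D elements into the s-th shard, so each shard buffer must hold B * H * D elements.]

#include <cstdint>
#include <vector>
#include "mllm/core/Tensor.hpp"
#include "mllm/nn/Functional.hpp"

using namespace mllm;

int main() {
  const int32_t B = 1, H = 2, S = 4, D = 8;
  Tensor key_states = Tensor::empty({B, H, S, D}, kFloat32, kCPU).alloc();

  // One destination buffer per token; each holds the B * H slices of D
  // elements that Scatter2ShardsOp copies for that token.
  std::vector<std::vector<float>> shard_storage(S, std::vector<float>(B * H * D));
  std::vector<int64_t> shard_addrs;
  shard_addrs.reserve(S);
  for (auto& buf : shard_storage) { shard_addrs.push_back(reinterpret_cast<int64_t>(buf.data())); }

  // Rank-1 kInt64 tensor of raw addresses. refVectorData references
  // shard_addrs rather than copying it, so the vector must outlive the tensor.
  Tensor shards = Tensor::refVectorData(shard_addrs, {S}, kInt64, kCPU);

  // Scatter along dim 2, the sequence dimension of [B, H, S, D].
  nn::functional::scatter2Shards(key_states, shards, 2);
  return 0;
}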
From 7426d9cea1812afe51b59d563e0450642aa5916d Mon Sep 17 00:00:00 2001
From: chenghuaWang <2923277184@qq.com>
Date: Tue, 7 Oct 2025 22:09:58 +0800
Subject: [PATCH 3/8] fix(cpu): correct Q tensor layout in radix attention
 kernel

The Q tensor layout was incorrectly specified as [B, H_Q, S_Q, D] when it
should be [B, S_Q, H_Q, D]. This change updates the documentation and adjusts
the indexing logic accordingly. A FIXME comment is added to indicate potential
performance improvements. Also adds a TODO comment indicating that the
kernel's layout needs further review.

feat(cpu): add input layout support for RoPE operation

Introduces support for different input layouts (BHSD and BSHD) in the RoPE
operation. This change modifies the forward method to accept an
input_layout_type parameter, allowing the RoPE operation to correctly process
tensors with varying memory layouts.

The implementation includes updates to both float32 and float16 data types,
ensuring compatibility across different architectures (x86 and ARM).
Vectorized processing is preserved for both layout types.

Additionally, this commit:
- Adds enum class RoPEOpOptionsInputType to specify layout types
- Updates RoPEOpOptions to include input_type configuration
- Modifies Qwen3Attention to use BSHD layout for RoPE operations
- Adjusts tensor scattering indices in Qwen3Attention from 2 to 1
- Implements RadixAttn layer and registers it in Qwen3Attention
- Updates Qwen3Session::applyChatTemplate to use nlohmann::json

fix(include): update prefix cache include directive

Updates the include directive in PrefixCache.hpp to properly export the cache
header, ensuring correct usage throughout the codebase.
---
 .../kernels/common/radix_attn/fwd_bshd.hpp   |  13 +-
 mllm/backends/cpu/ops/RoPEOp.cpp             | 358 +++++++++++++-----
 mllm/backends/cpu/ops/RoPEOp.hpp             |   5 +-
 mllm/core/aops/RoPEOp.hpp                    |  11 +-
 mllm/models/qwen3/modeling_qwen3_service.hpp |  51 +--
 mllm/nn/Nn.hpp                               |   1 +
 mllm/nn/layers/RadixAttn.cpp                 |  16 +
 mllm/nn/layers/RadixAttn.hpp                 |  23 ++
 mllm/nn/layers/RoPE.cpp                      |   8 +-
 mllm/nn/layers/RoPE.hpp                      |   6 +-
 mllm/nn/lmcache/PrefixCache.hpp              |   2 +-
 tests/cpu/CMakeLists.txt                     |   1 -
 tests/cpu/KernelTest.cpp                     |  33 ++
 tests/cpu/RadixAttnKernel.hpp                | 156 ++++++++
 14 files changed, 539 insertions(+), 145 deletions(-)
 create mode 100644 mllm/nn/layers/RadixAttn.cpp
 create mode 100644 mllm/nn/layers/RadixAttn.hpp
 create mode 100644 tests/cpu/RadixAttnKernel.hpp

diff --git a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
index 11c42157f..2fc9f5cd3 100644
--- a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
+++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
@@ -23,7 +23,11 @@ namespace mllm::cpu::radix_attn {
 // BHSD
 // K: [S_KV], address, not contiguous
 // V: [S_KV], address, not contiguous
-// Q: [B, H_Q, S_Q, D], contiguous
+// Q: [B, S_Q, H_Q, D], contiguous
+//
+// After the KV tokens are found, KV is [B, S_KV, H_KV, D]
+//
+// TODO: this kernel's layout is still wrong
 //
 // H_KV should <= H_Q
 template::lowest();
       __AccDType scores_max_prev = std::numeric_limits<__AccDType>::lowest();
       __AccDType logsum = 0;
@@ -56,6 +62,7 @@
       int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);
       for (int s_kv_idx = 0; s_kv_idx < S_KV_BOUND; ++s_kv_idx) {
+        // k_token and v_token shape is [B, 1, H, D]
         __KDType* k_token = __k[s_kv_idx];
         __VDType* v_token = __v[s_kv_idx];
 
diff --git a/mllm/backends/cpu/ops/RoPEOp.cpp 
b/mllm/backends/cpu/ops/RoPEOp.cpp
index 57a2b1542..326ed96f9 100644
--- a/mllm/backends/cpu/ops/RoPEOp.cpp
+++ b/mllm/backends/cpu/ops/RoPEOp.cpp
@@ -4,6 +4,7 @@
 #include
 
 #include "mllm/backends/cpu/ops/RoPEOp.hpp"
+#include "mllm/core/aops/RoPEOp.hpp"
 #include "mllm/utils/CPUArchHelper.hpp"
 
 #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)
 // Include AVX, SSE.
@@ -13,7 +14,8 @@
 
 namespace mllm::cpu {
 
-void RoPEOpImpl::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Tensor& sin, Tensor& cos) {
+void RoPEOpImpl::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Tensor& sin, Tensor& cos,
+                         aops::RoPEOpOptionsInputType input_layout_type) {
   auto activation = inputs[0];
   auto out = outputs[0];
 
@@ -30,67 +32,138 @@ void RoPEOpImpl::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>&
   switch (activation.dtype()) {
     case kFloat32: {
 #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)
-      for (int n = 0; n < B; ++n) {
-        for (int h = 0; h < H; ++h) {
-          for (int s = 0; s < S; ++s) {
-            mllm_fp32_t* act_ptr = activation.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
-            mllm_fp32_t* out_ptr = out.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
-            const mllm_fp32_t* sin_ptr = sin.offsettedPtr<mllm_fp32_t>({n, s, 0});
-            const mllm_fp32_t* cos_ptr = cos.offsettedPtr<mllm_fp32_t>({n, s, 0});
-
-            for (int d = 0; d < half; ++d) {
-              mllm_fp32_t in_val = act_ptr[d];
-              mllm_fp32_t in_val2 = act_ptr[d + half];
-              mllm_fp32_t sin_val = sin_ptr[d];
-              mllm_fp32_t cos_val = cos_ptr[d];
-
-              out_ptr[d] = in_val * cos_val - in_val2 * sin_val;
-              out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val;
+      switch (input_layout_type) {
+        case aops::RoPEOpOptionsInputType::kBHSD: {
+          for (int n = 0; n < B; ++n) {
+            for (int h = 0; h < H; ++h) {
+              for (int s = 0; s < S; ++s) {
+                mllm_fp32_t* act_ptr = activation.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
+                mllm_fp32_t* out_ptr = out.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
+                const mllm_fp32_t* sin_ptr = sin.offsettedPtr<mllm_fp32_t>({n, s, 0});
+                const mllm_fp32_t* cos_ptr = cos.offsettedPtr<mllm_fp32_t>({n, s, 0});
+
+                for (int d = 0; d < half; ++d) {
+                  mllm_fp32_t in_val = act_ptr[d];
+                  mllm_fp32_t in_val2 = act_ptr[d + half];
+                  mllm_fp32_t sin_val = sin_ptr[d];
+                  mllm_fp32_t cos_val = cos_ptr[d];
+
+                  out_ptr[d] = in_val * cos_val - in_val2 * sin_val;
+                  out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val;
+                }
+              }
+            }
+          }
+          break;
+        }
+        case aops::RoPEOpOptionsInputType::kBSHD: {
+          for (int n = 0; n < B; ++n) {
+            for (int s = 0; s < S; ++s) {
+              const mllm_fp32_t* sin_ptr = sin.offsettedPtr<mllm_fp32_t>({n, s, 0});
+              const mllm_fp32_t* cos_ptr = cos.offsettedPtr<mllm_fp32_t>({n, s, 0});
+              for (int h = 0; h < H; ++h) {
+                mllm_fp32_t* act_ptr = activation.offsettedPtr<mllm_fp32_t>({n, s, h, 0});
+                mllm_fp32_t* out_ptr = out.offsettedPtr<mllm_fp32_t>({n, s, h, 0});
+                for (int d = 0; d < half; ++d) {
+                  mllm_fp32_t in_val = act_ptr[d];
+                  mllm_fp32_t in_val2 = act_ptr[d + half];
+                  mllm_fp32_t sin_val = sin_ptr[d];
+                  mllm_fp32_t cos_val = cos_ptr[d];
+                  out_ptr[d] = in_val * cos_val - in_val2 * sin_val;
+                  out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val;
+                }
+              }
             }
           }
+          break;
         }
       }
 #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)
-      for (int n = 0; n < B; ++n) {
-        for (int h = 0; h < H; ++h) {
-          for (int s = 0; s < S; ++s) {
-            mllm_fp32_t* act_ptr = activation.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
-            mllm_fp32_t* out_ptr = out.offsettedPtr<mllm_fp32_t>({n, h, s, 0});
-            const mllm_fp32_t* sin_ptr = sin.offsettedPtr<mllm_fp32_t>({n, s, 0});
-            const mllm_fp32_t* cos_ptr = cos.offsettedPtr<mllm_fp32_t>({n, s, 0});
-
-            // Vectorized processing (4 elements per iteration)
-            int d = 0;
-            constexpr int step = 4;
-            for (; d <= half - step; d += step) {
{ - // Load activation blocks - float32x4_t act_front = vld1q_f32(act_ptr + d); - float32x4_t act_back = vld1q_f32(act_ptr + d + half); - - // Load sin/cos values - float32x4_t sin_vec = vld1q_f32(sin_ptr + d); - float32x4_t cos_vec = vld1q_f32(cos_ptr + d); - - // Compute rotated values - float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); - float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); - - // Store results - vst1q_f32(out_ptr + d, out_front); - vst1q_f32(out_ptr + d + half, out_back); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + // Vectorized processing (4 elements per iteration) + int d = 0; + constexpr int step = 4; + for (; d <= half - step; d += step) { + // Load activation blocks + float32x4_t act_front = vld1q_f32(act_ptr + d); + float32x4_t act_back = vld1q_f32(act_ptr + d + half); + + // Load sin/cos values + float32x4_t sin_vec = vld1q_f32(sin_ptr + d); + float32x4_t cos_vec = vld1q_f32(cos_ptr + d); + + // Compute rotated values + float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); + float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); + + // Store results + vst1q_f32(out_ptr + d, out_front); + vst1q_f32(out_ptr + d + half, out_back); + } + + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = act_ptr[d]; + mllm_fp32_t in_val2 = act_ptr[d + half]; + mllm_fp32_t sin_val = sin_ptr[d]; + mllm_fp32_t cos_val = cos_ptr[d]; + + out_ptr[d] = in_val * cos_val - in_val2 * sin_val; + out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + } + } } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + const int half = D / 2; + constexpr int step = 4; + + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({b, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({b, s, 0}); - // Process remaining elements - for (; d < half; ++d) { - mllm_fp32_t in_val = act_ptr[d]; - mllm_fp32_t in_val2 = act_ptr[d + half]; - mllm_fp32_t sin_val = sin_ptr[d]; - mllm_fp32_t cos_val = cos_ptr[d]; + for (int h = 0; h < H; ++h) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({b, s, h, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({b, s, h, 0}); - out_ptr[d] = in_val * cos_val - in_val2 * sin_val; - out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + int d = 0; + for (; d <= half - step; d += step) { + float32x4_t act_front = vld1q_f32(act_ptr + d); // [d, d+1, d+2, d+3] + float32x4_t act_back = vld1q_f32(act_ptr + d + half); // [d+half, ...] 
+ float32x4_t sin_vec = vld1q_f32(sin_ptr + d); + float32x4_t cos_vec = vld1q_f32(cos_ptr + d); + + float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); + float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); + + vst1q_f32(out_ptr + d, out_front); + vst1q_f32(out_ptr + d + half, out_back); + } + for (; d < half; ++d) { + mllm_fp32_t in0 = act_ptr[d]; + mllm_fp32_t in1 = act_ptr[d + half]; + mllm_fp32_t s = sin_ptr[d]; + mllm_fp32_t c = cos_ptr[d]; + + out_ptr[d] = in0 * c - in1 * s; + out_ptr[d + half] = in0 * s + in1 * c; + } + } } } + break; } } #endif @@ -98,67 +171,144 @@ void RoPEOpImpl::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - for (int d = 0; d < half; ++d) { - mllm_fp32_t in_val = static_cast(act_ptr[d]); - mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); - mllm_fp32_t sin_val = static_cast(sin_ptr[d]); - mllm_fp32_t cos_val = static_cast(cos_ptr[d]); - - out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); - out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + for (int d = 0; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } + } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + for (int n = 0; n < B; ++n) { + for (int s = 0; s < S; ++s) { + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + for (int h = 0; h < H; ++h) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, s, h, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, s, h, 0}); + for (int d = 0; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } } + break; } } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - // 
Vectorized processing (8 elements per iteration) - int d = 0; - constexpr int step = 8; - for (; d <= half - step; d += step) { - // Load activation blocks - float16x8_t act_front = vld1q_f16(act_ptr + d); - float16x8_t act_back = vld1q_f16(act_ptr + d + half); - - // Load sin/cos values - float16x8_t sin_vec = vld1q_f16(sin_ptr + d); - float16x8_t cos_vec = vld1q_f16(cos_ptr + d); - - // Compute rotated values - float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); - float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); - - // Store results - vst1q_f16(out_ptr + d, out_front); - vst1q_f16(out_ptr + d + half, out_back); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + // Vectorized processing (8 elements per iteration) + int d = 0; + constexpr int step = 8; + for (; d <= half - step; d += step) { + // Load activation blocks + float16x8_t act_front = vld1q_f16(act_ptr + d); + float16x8_t act_back = vld1q_f16(act_ptr + d + half); + + // Load sin/cos values + float16x8_t sin_vec = vld1q_f16(sin_ptr + d); + float16x8_t cos_vec = vld1q_f16(cos_ptr + d); + + // Compute rotated values + float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); + float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); + + // Store results + vst1q_f16(out_ptr + d, out_front); + vst1q_f16(out_ptr + d + half, out_back); + } + + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + for (int n = 0; n < B; ++n) { + for (int s = 0; s < S; ++s) { + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + for (int h = 0; h < H; ++h) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, s, h, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, s, h, 0}); + // Vectorized processing (8 elements per iteration) + int d = 0; + constexpr int step = 8; + for (; d <= half - step; d += step) { + // Load activation blocks + float16x8_t act_front = vld1q_f16(act_ptr + d); + float16x8_t act_back = vld1q_f16(act_ptr + d + half); + + // Load sin/cos values + float16x8_t sin_vec = vld1q_f16(sin_ptr + d); + float16x8_t cos_vec = vld1q_f16(cos_ptr + d); + + // Compute rotated values + float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); + float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); + + // Store results + vst1q_f16(out_ptr + d, out_front); + vst1q_f16(out_ptr + d + half, out_back); + } - // Process remaining elements - for (; d < half; ++d) { - mllm_fp32_t in_val = static_cast(act_ptr[d]); - mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); - 
mllm_fp32_t sin_val = static_cast(sin_ptr[d]); - mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); - out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); - out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } } + break; } } #endif @@ -189,7 +339,7 @@ void CPURoPEOp::forward(const std::vector& inputs, std::vector& auto out = outputs[0]; auto impl = RoPEOpImpl(); - impl.forward(inputs, outputs, sin, cos); + impl.forward(inputs, outputs, sin, cos, options_.input_type); } } // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/RoPEOp.hpp b/mllm/backends/cpu/ops/RoPEOp.hpp index 5c7b10c78..efeb9ebda 100644 --- a/mllm/backends/cpu/ops/RoPEOp.hpp +++ b/mllm/backends/cpu/ops/RoPEOp.hpp @@ -9,7 +9,8 @@ namespace mllm::cpu { struct RoPEOpImpl { - void forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos); + void forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos, + aops::RoPEOpOptionsInputType input_layout_type); }; class CPURoPEOp final : public aops::RoPEOp { @@ -26,4 +27,4 @@ class CPURoPEOpFactory : public TypedOpFactory { float rope_theta = 10000.0F; int32_t max_position_embeddings = 16384; - - RoPEOpOptions() = default; - explicit RoPEOpOptions(float theta, int32_t max_pos_embed) : rope_theta(theta), max_position_embeddings(max_pos_embed) {} + RoPEOpOptionsInputType input_type = RoPEOpOptionsInputType::kBHSD; }; class RoPEOp : public BaseOp { @@ -36,4 +39,4 @@ class RoPEOp : public BaseOp { RoPEOpOptions options_; }; -} // namespace mllm::aops \ No newline at end of file +} // namespace mllm::aops diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp index 7fd138fb1..ba6cbb549 100644 --- a/mllm/models/qwen3/modeling_qwen3_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_service.hpp @@ -113,6 +113,7 @@ class Qwen3Attention final : public nn::Module { nn::RMSNorm rms_norm_k_; nn::RoPE q_rope_; nn::RoPE k_rope_; + nn::RadixAttn attn_; int hidden_size_; int head_dim_; @@ -130,20 +131,20 @@ class Qwen3Attention final : public nn::Module { head_dim_ = cfg.head_dim; num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; - q_proj_ = - reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); - k_proj_ = - reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); - v_proj_ = - reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); - o_proj_ = - reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, 
cfg.linear_impl_type); + // clang-format on rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps); rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps); - q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings); - k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings); + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); + + attn_ = reg("attn", num_attention_heads_, num_key_value_heads_); } std::vector forward(const std::vector& inputs, const std::vector& args) override { @@ -171,12 +172,8 @@ class Qwen3Attention final : public nn::Module { query_states = rms_norm_q_(query_states); key_states = rms_norm_k_(key_states); - // [B, H, S, D] - query_states = query_states.transpose(1, 2); - key_states = key_states.transpose(1, 2); - value_states = value_states.transpose(1, 2); - - // [B, H, S, D] + // Different from original [B, H, S, D] rope. + // [B, S, H, D] query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); @@ -186,8 +183,8 @@ class Qwen3Attention final : public nn::Module { std::vector k_addr_wait_for_promote; std::vector v_addr_wait_for_promote; for (int s_idx; s_idx < S; ++s_idx) { - k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(DeviceTypes::kCPU)); - v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(DeviceTypes::kCPU)); + k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); + v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); } // Prepare indicies cache. sizeof(char*) == 8 == sizeof(int64_t) @@ -201,8 +198,8 @@ class Qwen3Attention final : public nn::Module { auto v_wait_for_promote = Tensor::refVectorData(v_phy_addr_wait_for_promote, {S}, kInt64, kCPU); // Copy key_states and value_states to cache - nn::functional::scatter2Shards(key_states, k_wait_for_promote, 2); - nn::functional::scatter2Shards(value_states, v_wait_for_promote, 2); + nn::functional::scatter2Shards(key_states, k_wait_for_promote, 1); + nn::functional::scatter2Shards(value_states, v_wait_for_promote, 1); // Gather all cache to indicies tensor std::ranges::copy((*k_cache_addr)[layer_idx_], std::back_inserter(k_addr_wait_for_promote)); @@ -219,9 +216,13 @@ class Qwen3Attention final : public nn::Module { auto k_cache = Tensor::refVectorData(k_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); auto v_cache = Tensor::refVectorData(v_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); - // TODO Do Radix Attention + // Do Radix Attention + // output is [B, S, H, D] + auto output = attn_(query_states, k_cache, v_cache); + output = output.view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); - return {}; + return {output}; } int layer_idx_; @@ -486,9 +487,9 @@ class Qwen3Session final : public ::mllm::service::Session { // [{"role": "user", "content": "Say this is a test!"}] // [{"role": "assistant", "content": "This is a test!", "reasoning": "You are absolutely right!"}] // - std::string applyChatTemplate(const json& messages, const json* tools = nullptr, bool add_generation_prompt = true, - bool enable_thinking = true, const std::string& bos_token = "", - const std::string& eos_token = "<|im_end|>") { + std::string applyChatTemplate(const nlohmann::json& messages, const nlohmann::json* tools = nullptr, + bool add_generation_prompt = true, bool enable_thinking = true, + 
const std::string& bos_token = "", const std::string& eos_token = "<|im_end|>") { std::ostringstream out; if (tools && tools->is_array() && !tools->empty()) { diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index fa9074355..7bab5ebaa 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -26,3 +26,4 @@ #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export +#include "mllm/nn/layers/RadixAttn.hpp" // IWYU pragma: export diff --git a/mllm/nn/layers/RadixAttn.cpp b/mllm/nn/layers/RadixAttn.cpp new file mode 100644 index 000000000..cf0c891b8 --- /dev/null +++ b/mllm/nn/layers/RadixAttn.cpp @@ -0,0 +1,16 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/core/aops/RadixAttnOp.hpp" +#include "mllm/nn/layers/RadixAttn.hpp" + +namespace mllm::nn { + +RadixAttn::RadixAttn() : Layer(OpTypes::kRadixAttn, aops::RadixAttnOpOptions{}) {} + +RadixAttn::RadixAttn(const aops::RadixAttnOpOptions& options) : Layer(OpTypes::kRadixAttn, options) {} + +RadixAttn::RadixAttn(int32_t H_Q, int32_t H_KV) + : Layer(OpTypes::kRadixAttn, aops::RadixAttnOpOptions{.H_Q = H_Q, .H_KV = H_KV}) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/RadixAttn.hpp b/mllm/nn/layers/RadixAttn.hpp new file mode 100644 index 000000000..a641ccb7d --- /dev/null +++ b/mllm/nn/layers/RadixAttn.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/RadixAttnOp.hpp" + +namespace mllm::nn { + +class RadixAttn : public Layer { + public: + RadixAttn(); + + explicit RadixAttn(const aops::RadixAttnOpOptions& options); + + RadixAttn(int32_t H_Q, int32_t H_KV); + + // Q, K, V in and one output out + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/RoPE.cpp b/mllm/nn/layers/RoPE.cpp index 8659aa017..3b44c529b 100644 --- a/mllm/nn/layers/RoPE.cpp +++ b/mllm/nn/layers/RoPE.cpp @@ -8,7 +8,9 @@ namespace mllm::nn { RoPE::RoPE() : Layer(OpTypes::kRoPE, aops::RoPEOpOptions{}) {} -RoPE::RoPE(float theta, int32_t max_position_embeddings) - : Layer(OpTypes::kRoPE, aops::RoPEOpOptions(theta, max_position_embeddings)) {} +RoPE::RoPE(float theta, int32_t max_position_embeddings, aops::RoPEOpOptionsInputType input_type) + : Layer(OpTypes::kRoPE, + aops::RoPEOpOptions{ + .rope_theta = theta, .max_position_embeddings = max_position_embeddings, .input_type = input_type}) {} -} // namespace mllm::nn \ No newline at end of file +} // namespace mllm::nn diff --git a/mllm/nn/layers/RoPE.hpp b/mllm/nn/layers/RoPE.hpp index abc38c680..2ca512a41 100644 --- a/mllm/nn/layers/RoPE.hpp +++ b/mllm/nn/layers/RoPE.hpp @@ -11,7 +11,9 @@ namespace mllm::nn { class RoPE : public Layer { public: RoPE(); - explicit RoPE(float theta, int32_t max_position_embeddings); + + RoPE(float theta, int32_t max_position_embeddings, + aops::RoPEOpOptionsInputType input_type = aops::RoPEOpOptionsInputType::kBHSD); MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD @@ -20,4 +22,4 @@ class RoPE : public Layer { int32_t max_position_embeddings_ = 128; }; -} // namespace mllm::nn \ No newline at end of file +} // namespace mllm::nn diff --git a/mllm/nn/lmcache/PrefixCache.hpp b/mllm/nn/lmcache/PrefixCache.hpp index b32f7234d..051654626 100644 --- a/mllm/nn/lmcache/PrefixCache.hpp +++ b/mllm/nn/lmcache/PrefixCache.hpp @@ -3,4 +3,4 @@ #pragma once -#include "mllm/engine/prefix_cache/Cache.hpp" 
+#include "mllm/engine/prefix_cache/Cache.hpp" // IWYU pragma: export diff --git a/tests/cpu/CMakeLists.txt b/tests/cpu/CMakeLists.txt index 940544be4..478f867d9 100644 --- a/tests/cpu/CMakeLists.txt +++ b/tests/cpu/CMakeLists.txt @@ -1,4 +1,3 @@ - add_executable(Mllm-Test-CPUKernel KernelTest.cpp) target_link_libraries(Mllm-Test-CPUKernel PRIVATE gtest_main MllmRT MllmCPUBackend) target_include_directories(Mllm-Test-CPUKernel PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 06ddf432c..1d8e1853a 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -792,6 +792,7 @@ TEST_F(ReduceKernelTest, SumFloat32) { //===----------------------------------------------------------------------===// // Paged Attn //===----------------------------------------------------------------------===// +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) #include "PagedAttnTest.hpp" TEST_F(PagedAttnTest, fwd_bshd) { EXPECT_EQ(manyCases({ @@ -803,6 +804,38 @@ TEST_F(PagedAttnTest, fwd_bshd) { }), true); } +#endif + +//===----------------------------------------------------------------------===// +// Radix Attn +//===----------------------------------------------------------------------===// +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#include "RadixAttnKernel.hpp" +TEST_F(RadixAttnKernelTest, fwd_bshd) { + EXPECT_EQ(testRadixAttn({{ + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 10}, + {"S_KV", 10}, + {"D", 128}, + }, + { + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 10}, + {"S_KV", 20}, + {"D", 128}, + }, + { + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 1}, + {"S_KV", 20}, + {"D", 128}, + }}), + true); +} +#endif int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/tests/cpu/RadixAttnKernel.hpp b/tests/cpu/RadixAttnKernel.hpp new file mode 100644 index 000000000..2bfe63854 --- /dev/null +++ b/tests/cpu/RadixAttnKernel.hpp @@ -0,0 +1,156 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/PrefixCache.hpp" + +#include "KernelTestHelper.hpp" + +using namespace mllm; // NOLINT + +class RadixAttnModule : public nn::Module { + nn::RadixAttn attn_; + + public: + RadixAttnModule() = default; + + RadixAttnModule(int H_Q, int H_KV) : nn::Module() { attn_ = reg("attn", H_Q, H_KV); } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // inputs is Q, K_indices, V_indices + return {attn_(inputs[0], inputs[1], inputs[2])}; + } +}; + +class EagerModule : public nn::Module { + nn::CausalMask mask_; + + public: + EagerModule() : nn::Module() { mask_ = reg("mask"); } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // inputs is Q, K_indices, V_indices + // Q, K, V is [B, S, H, D] + auto Q = inputs[0]; + auto K = inputs[1]; + auto V = inputs[2]; + + auto h_q = Q.shape()[2]; + auto h_kv = K.shape()[2]; + auto head_dim = Q.shape()[3]; + + // Q, K, V is [B, H, S, D] + Q = Q.transpose(1, 2); + K = K.transpose(1, 2).repeat(h_q / h_kv, 1); + V = V.transpose(1, 2).repeat(h_q / h_kv, 1); + + // Attention Weight + // [B, H, S, S] + auto attn = nn::functional::matmul(Q, K, false, true) * (1.f / sqrtf(head_dim)); + attn = mask_(attn); + attn = nn::functional::softmax(attn, -1); + // [B, H, S, D] + auto output = nn::functional::matmul(attn, V); + // [B, S, H, D] + output = output.transpose(1, 2); + + return {output}; + } +}; + +class RadixAttnKernelTest : public KernelTest { + public: + RadixAttnKernelTest() = default; + ~RadixAttnKernelTest() override = default; + + bool testRadixAttnOnce(const std::unordered_map& cfg) { + int B = 1; + int H_Q = cfg.at("H_Q"); + int H_KV = cfg.at("H_KV"); + int S_Q = cfg.at("S_Q"); + int S_KV = cfg.at("S_KV"); + int D = cfg.at("D"); + + mllm::prefix_cache::CacheOptions opt{ + .radix_tree_options = {.enable_lru_eviction = false, + .eviction_threshold = 0.9f, + .enable_path_compression = false, + .min_compression_length = 2, + .transformer_blocks_num = 1}, + .allocator_options = {// Normal things. + .per_k_token_ele = static_cast(H_KV * D), + .per_v_token_ele = static_cast(H_KV * D), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + + // CUDA things. + .enable_cuda = false, + .cuda_mem_base = 0x100000, + + // CPU things. 
+ .enable_cpu_hierarchy_memory = true, + .zen_fs_options = { + .record = false, + .working_dir = ".", + .blob_bits_size = 20, + .page_bits = 7, + .lane_bits = 5, + .per_k_token_ele = static_cast(H_KV * D), + .per_v_token_ele = static_cast(H_KV * D), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous, + }}}; + EagerModule eager_attn; + RadixAttnModule radix_attn; + prefix_cache::Cache cache(opt); + + // Create Q, K, V + auto Q = Tensor::random({B, S_Q, H_Q, D}, -10.f, 10.f); + auto K = Tensor::random({B, S_KV, H_KV, D}, -10.f, 10.f); + auto V = Tensor::random({B, S_KV, H_KV, D}, -10.f, 10.f); + + // Insert K and V into Cache and Radix Tree + std::vector k_cache_addrs; + std::vector v_cache_addrs; + for (int i = 0; i < S_KV; i++) { + k_cache_addrs.push_back(cache.alloc(kCPU)); + v_cache_addrs.push_back(cache.alloc(kCPU)); + } + std::vector k_cache_ptrs; + std::vector v_cache_ptrs; + for (int i = 0; i < S_KV; i++) { + k_cache_ptrs.push_back(cache.physicalAddr(k_cache_addrs[i])); + v_cache_ptrs.push_back(cache.physicalAddr(v_cache_addrs[i])); + } + auto k_cache_indices = Tensor::refVectorData(k_cache_ptrs, {S_KV}, kInt64); + auto v_cache_indices = Tensor::refVectorData(v_cache_ptrs, {S_KV}, kInt64); + + nn::functional::scatter2Shards(K, k_cache_indices, 1); + nn::functional::scatter2Shards(K, v_cache_indices, 1); + + // Compute eager + Tensor gt = eager_attn(Q, K, V)[0]; + Tensor predict = radix_attn(Q, k_cache_indices, v_cache_indices)[0]; + Dbg(); + + // Compare + auto result = test::allClose(gt, predict); + if (!result) { + print(result); + return false; + } + return true; + } + + bool testRadixAttn(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testRadixAttnOnce(cfg)) { return false; } + } + return true; + } +}; From c3273b4b6d6cec8b5e0dc03545aac9803529e24b Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 8 Oct 2025 17:20:50 +0800 Subject: [PATCH 4/8] feat(cpu): implement FilledWithConst for radix attention kernels - Add FilledWithConst struct template in arch.hpp - Implement FilledWithConst specializations for __AnyArchTag and __ArmArchTag - Replace manual loop initialization with FilledWithConst in fwd_bshd.hpp - Update Q tensor shape indexing in RadixAttnOp.cpp - Improve VectorDotProduct, MulFromConst and FMAConstArray ARM implementations - Fix softmax computation logic in radix attention forward pass - Add proper mask handling in RadixAttnKernel test - Add Scatter2ShardsKernelTest for shard scattering validation - Fix random state management in Context class - Add RandomStatesTest for verifying random seed behavior --- .../cpu/kernels/common/radix_attn/arch.hpp | 5 + .../kernels/common/radix_attn/fwd_bshd.hpp | 17 +- .../kernels/common/radix_attn/impl-any.hpp | 44 +++++ .../kernels/common/radix_attn/impl-arm.hpp | 159 ++++++++++++------ mllm/backends/cpu/ops/RadixAttnOp.cpp | 6 +- mllm/core/Tensor.cpp | 2 +- mllm/engine/Context.cpp | 30 ++-- mllm/engine/Context.hpp | 3 + mllm/mllm.cpp | 2 + mllm/mllm.hpp | 2 + tests/cpu/KernelTest.cpp | 42 +---- tests/cpu/RadixAttnKernel.hpp | 48 +++++- tests/cpu/Scatter2ShardsKernelTest.hpp | 28 +++ tests/engine/CMakeLists.txt | 4 + tests/engine/RandomStatesTest.cpp | 12 ++ 15 files changed, 279 insertions(+), 125 deletions(-) create mode 100644 tests/cpu/Scatter2ShardsKernelTest.hpp create mode 100644 tests/engine/RandomStatesTest.cpp diff --git a/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp 
b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp index 5fcd32f8f..dae6e8e76 100644 --- a/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp @@ -32,4 +32,9 @@ struct FMAConstArray { static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {} }; +template +struct FilledWithConst { + static MLLM_FORCE_INLINE void run(T* __restrict__ a, const T v, size_t len) {} +}; + } // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp index 2fc9f5cd3..d562b1a6a 100644 --- a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp @@ -25,9 +25,7 @@ namespace mllm::cpu::radix_attn { // V: [S_KV], address, not contiguous // Q: [B, S_Q, H_Q, D], contiguous // -// After find KV Tokens, KV is [B, S_KV, H_KV, D] -// -// TODO This kernel's layout is error +// After find KV Tokens, KV is [B, 1, H_KV, D] // // H_KV should <= H_Q template::run(acc_o, 0, D); - __AccDType scores_max = std::numeric_limits<__AccDType>::lowest(); - __AccDType scores_max_prev = std::numeric_limits<__AccDType>::lowest(); + __AccDType scores_max = -std::numeric_limits<__AccDType>::infinity(); + __AccDType scores_max_prev = -std::numeric_limits<__AccDType>::infinity(); __AccDType logsum = 0; - __AccDType scores_scale = 0; __AccDType scores_sum = 0; + __AccDType scores_scale = 0; int __delta = S_KV - S_Q; int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV); @@ -77,11 +75,10 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i // 2. Do softmax stuff. scores_max_prev = scores_max; - scores_max = std::numeric_limits<__AccDType>::lowest(); - scores_max = std::max(scores_max, acc_s); + scores_max = std::max(scores_max_prev, acc_s); scores_scale = std::exp2(scores_max_prev * scale - scores_max * scale); acc_s = std::exp2(acc_s * scale - scores_max * scale); - scores_sum += acc_s; // TODO This line may be error. + scores_sum = acc_s; logsum = logsum * scores_scale + scores_sum; // 3. Scale diff --git a/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp index e69de29bb..a5905c3b8 100644 --- a/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp @@ -0,0 +1,44 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/DataTypes.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp" + +#include + +namespace mllm::cpu::radix_attn::details { +template<> +struct VectorDotProduct<__AnyArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs, + mllm_fp32_t* __out, size_t len) { + mllm_fp32_t ret = 0; + for (size_t i = 0; i < len; ++i) { ret += __lhs[i] * __rhs[i]; } + *__out = ret; + } +}; + +template<> +struct MulFromConst<__AnyArchTag, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ __from, const mllm_fp32_t const_v, size_t len) { + for (int i = 0; i < len; ++i) { __from[i] *= const_v; } + } +}; + +template<> +struct FMAConstArray<__AnyArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ acc_o, const mllm_fp32_t acc_s, + const mllm_fp32_t* __restrict__ v_token, size_t len) { + for (int i = 0; i < len; ++i) { acc_o[i] += acc_s * v_token[i]; } + } +}; + +template<> +struct FilledWithConst<__AnyArchTag, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ a, const mllm_fp32_t v, size_t len) { + for (int i = 0; i < len; ++i) { a[i] = v; } + } +}; + +} // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp index 962916bbb..859688da4 100644 --- a/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp @@ -13,44 +13,37 @@ template<> struct VectorDotProduct<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs, mllm_fp32_t* __out, size_t len) { - float32x4_t sum_vec0 = vdupq_n_f32(0.0f); - float32x4_t sum_vec1 = vdupq_n_f32(0.0f); - float32x4_t sum_vec2 = vdupq_n_f32(0.0f); - float32x4_t sum_vec3 = vdupq_n_f32(0.0f); + float32x4_t sum_vec = vdupq_n_f32(0.0f); size_t i = 0; - const size_t main_loop_bound = len - 15; - for (; i < main_loop_bound; i += 16) { - const float32x4_t lhs0 = vld1q_f32(__lhs + i); - const float32x4_t rhs0 = vld1q_f32(__rhs + i); - sum_vec0 = vfmaq_f32(sum_vec0, lhs0, rhs0); - - const float32x4_t lhs1 = vld1q_f32(__lhs + i + 4); - const float32x4_t rhs1 = vld1q_f32(__rhs + i + 4); - sum_vec1 = vfmaq_f32(sum_vec1, lhs1, rhs1); - - const float32x4_t lhs2 = vld1q_f32(__lhs + i + 8); - const float32x4_t rhs2 = vld1q_f32(__rhs + i + 8); - sum_vec2 = vfmaq_f32(sum_vec2, lhs2, rhs2); - - const float32x4_t lhs3 = vld1q_f32(__lhs + i + 12); - const float32x4_t rhs3 = vld1q_f32(__rhs + i + 12); - sum_vec3 = vfmaq_f32(sum_vec3, lhs3, rhs3); + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + float32x4_t lhs_vec0 = vld1q_f32(__lhs + i); + float32x4_t lhs_vec1 = vld1q_f32(__lhs + i + 4); + float32x4_t lhs_vec2 = vld1q_f32(__lhs + i + 8); + float32x4_t lhs_vec3 = vld1q_f32(__lhs + i + 12); + + float32x4_t rhs_vec0 = vld1q_f32(__rhs + i); + float32x4_t rhs_vec1 = vld1q_f32(__rhs + i + 4); + float32x4_t rhs_vec2 = vld1q_f32(__rhs + i + 8); + float32x4_t rhs_vec3 = vld1q_f32(__rhs + i + 12); + + sum_vec = vfmaq_f32(sum_vec, lhs_vec0, rhs_vec0); + sum_vec = vfmaq_f32(sum_vec, lhs_vec1, rhs_vec1); + sum_vec = vfmaq_f32(sum_vec, lhs_vec2, rhs_vec2); + sum_vec = 
vfmaq_f32(sum_vec, lhs_vec3, rhs_vec3);
     }
 
-    const size_t unroll4_bound = len - 3;
-    for (; i < unroll4_bound; i += 4) {
-      const float32x4_t lhs_vec = vld1q_f32(__lhs + i);
-      const float32x4_t rhs_vec = vld1q_f32(__rhs + i);
-      sum_vec0 = vfmaq_f32(sum_vec0, lhs_vec, rhs_vec);
+    for (; i + 3 < len; i += 4) {
+      float32x4_t lhs_vec = vld1q_f32(__lhs + i);
+      float32x4_t rhs_vec = vld1q_f32(__rhs + i);
+      sum_vec = vfmaq_f32(sum_vec, lhs_vec, rhs_vec);
     }
 
-    sum_vec0 = vaddq_f32(sum_vec0, sum_vec1);
-    sum_vec2 = vaddq_f32(sum_vec2, sum_vec3);
-    sum_vec0 = vaddq_f32(sum_vec0, sum_vec2);
+    float result = vaddvq_f32(sum_vec);
 
-    // Reduce
-    float result = vaddvq_f32(sum_vec0);
     for (; i < len; ++i) { result += __lhs[i] * __rhs[i]; }
 
     *__out = result;
@@ -60,17 +53,36 @@ struct VectorDotProduct<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> {
 template<>
 struct MulFromConst<__ArmArchTag, mllm_fp32_t, mllm_fp32_t> {
   static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ __from, const mllm_fp32_t const_v, size_t len) {
+    float32x4_t const_vec = vdupq_n_f32(const_v);
+
     size_t i = 0;
-    const size_t simd_width = 4;
-    if (len >= simd_width) {
-      float32x4_t const_vec = vdupq_n_f32(const_v);
-      size_t simd_len = (len / simd_width) * simd_width;
-      for (; i < simd_len; i += simd_width) {
-        float32x4_t data_vec = vld1q_f32(&__from[i]);
-        data_vec = vmulq_f32(data_vec, const_vec);
-        vst1q_f32(&__from[i], data_vec);
-      }
+    const size_t block_size = 16;
+    const size_t len_aligned = len & ~(block_size - 1);
+
+    for (; i < len_aligned; i += block_size) {
+      float32x4_t vec0 = vld1q_f32(__from + i);
+      float32x4_t vec1 = vld1q_f32(__from + i + 4);
+      float32x4_t vec2 = vld1q_f32(__from + i + 8);
+      float32x4_t vec3 = vld1q_f32(__from + i + 12);
+
+      // FIXME: FMA may be faster than MUL
+      vec0 = vmulq_f32(vec0, const_vec);
+      vec1 = vmulq_f32(vec1, const_vec);
+      vec2 = vmulq_f32(vec2, const_vec);
+      vec3 = vmulq_f32(vec3, const_vec);
+
+      vst1q_f32(__from + i, vec0);
+      vst1q_f32(__from + i + 4, vec1);
+      vst1q_f32(__from + i + 8, vec2);
+      vst1q_f32(__from + i + 12, vec3);
+    }
+
+    for (; i + 3 < len; i += 4) {
+      float32x4_t vec = vld1q_f32(__from + i);
+      vec = vmulq_f32(vec, const_vec);
+      vst1q_f32(__from + i, vec);
     }
+
     for (; i < len; ++i) { __from[i] *= const_v; }
   }
 };
@@ -79,22 +91,65 @@ template<>
 struct FMAConstArray<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> {
   static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ acc_o, const mllm_fp32_t acc_s,
                                     const mllm_fp32_t* __restrict__ v_token, size_t len) {
+    float32x4_t acc_vec = vdupq_n_f32(acc_s);
+
     size_t i = 0;
-    const size_t simd_width = 4;
-
-    if (len >= simd_width) {
-      float32x4_t acc_s_vec = vdupq_n_f32(acc_s);
-
-      size_t simd_len = (len / simd_width) * simd_width;
-      for (; i < simd_len; i += simd_width) {
-        float32x4_t v_token_vec = vld1q_f32(&v_token[i]);
-        float32x4_t acc_o_vec = vld1q_f32(&acc_o[i]);
-        acc_o_vec = vfmaq_f32(acc_o_vec, acc_s_vec, v_token_vec);
-        vst1q_f32(&acc_o[i], acc_o_vec);
-      }
+    const size_t block_size = 16;
+    const size_t len_aligned = len & ~(block_size - 1);
+
+    for (; i < len_aligned; i += block_size) {
+      float32x4_t acc0 = vld1q_f32(acc_o + i);
+      float32x4_t token0 = vld1q_f32(v_token + i);
+      acc0 = vfmaq_f32(acc0, token0, acc_vec);
+      vst1q_f32(acc_o + i, acc0);
+
+      float32x4_t acc1 = vld1q_f32(acc_o + i + 4);
+      float32x4_t token1 = vld1q_f32(v_token + i + 4);
+      acc1 = vfmaq_f32(acc1, token1, acc_vec);
+      vst1q_f32(acc_o + i + 4, acc1);
+
+      float32x4_t acc2 = vld1q_f32(acc_o + i + 8);
+      float32x4_t token2 = vld1q_f32(v_token
+ i + 8); + acc2 = vfmaq_f32(acc2, token2, acc_vec); + vst1q_f32(acc_o + i + 8, acc2); + + float32x4_t acc3 = vld1q_f32(acc_o + i + 12); + float32x4_t token3 = vld1q_f32(v_token + i + 12); + acc3 = vfmaq_f32(acc3, token3, acc_vec); + vst1q_f32(acc_o + i + 12, acc3); } + + for (; i + 3 < len; i += 4) { + float32x4_t acc = vld1q_f32(acc_o + i); + float32x4_t token = vld1q_f32(v_token + i); + acc = vfmaq_f32(acc, token, acc_vec); + vst1q_f32(acc_o + i, acc); + } + for (; i < len; ++i) { acc_o[i] += acc_s * v_token[i]; } } }; +template<> +struct FilledWithConst<__ArmArchTag, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ a, const mllm_fp32_t v, size_t len) { + float32x4_t const_vec = vdupq_n_f32(v); + + size_t i = 0; + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + vst1q_f32(a + i, const_vec); + vst1q_f32(a + i + 4, const_vec); + vst1q_f32(a + i + 8, const_vec); + vst1q_f32(a + i + 12, const_vec); + } + + for (; i + 3 < len; i += 4) { vst1q_f32(a + i, const_vec); } + + for (; i < len; ++i) { a[i] = v; } + } +}; + } // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/ops/RadixAttnOp.cpp b/mllm/backends/cpu/ops/RadixAttnOp.cpp index 5f7ff6c6f..cca4ace3e 100644 --- a/mllm/backends/cpu/ops/RadixAttnOp.cpp +++ b/mllm/backends/cpu/ops/RadixAttnOp.cpp @@ -19,9 +19,9 @@ void CPURadixAttnOp::forward(const std::vector& inputs, std::vector& shape, float start, float end, return Context::instance().buildOpAndSubmitTask( OpTypes::kFill, aops::FillOpOptions{ - .type = aops::FillOpTypes::kRandom, .start = start, .end = end, .seed = Context::instance().getRandomSeed()}, + .type = aops::FillOpTypes::kRandom, .start = start, .end = end, .seed = Context::instance().getRandomState()}, {i})[0]; } diff --git a/mllm/engine/Context.cpp b/mllm/engine/Context.cpp index 17363564c..5c9f0ffb6 100644 --- a/mllm/engine/Context.cpp +++ b/mllm/engine/Context.cpp @@ -94,10 +94,20 @@ SessionTCB::ptr_t Context::thisThread() { SessionTCB::ptr_t Context::mainThread() { return main_thread_; } -void Context::setRandomSeed(uint64_t seed) { random_seed_ = seed; } +void Context::setRandomSeed(uint64_t seed) { + random_seed_ = seed; + random_state_ = seed; +} uint64_t Context::getRandomSeed() { return random_seed_; } +uint64_t Context::getRandomState() { + auto ret = random_state_; + std::mt19937 gen(random_state_); + random_state_ = gen(); + return ret; +} + uint64_t Context::curTime() { auto now = std::chrono::high_resolution_clock::now(); auto duration = now.time_since_epoch(); @@ -106,20 +116,12 @@ uint64_t Context::curTime() { std::unordered_map Context::refSessionThreads() { return session_threads_; } -void Context::setPrintPrecision(int precision) { - print_precision_ = precision; -} +void Context::setPrintPrecision(int precision) { print_precision_ = precision; } -int Context::getPrintPrecision() const { - return print_precision_; -} +int Context::getPrintPrecision() const { return print_precision_; } -void Context::setPrintMaxElementsPerDim(int max_elements) { - print_max_elements_per_dim_ = max_elements; -} +void Context::setPrintMaxElementsPerDim(int max_elements) { print_max_elements_per_dim_ = max_elements; } -int Context::getPrintMaxElementsPerDim() const { - return print_max_elements_per_dim_; -} +int Context::getPrintMaxElementsPerDim() const { return print_max_elements_per_dim_; } -} // namespace mllm \ No newline at end of file +} // namespace mllm diff --git 
a/mllm/engine/Context.hpp b/mllm/engine/Context.hpp index 93fb1bea1..fe0b0bc84 100644 --- a/mllm/engine/Context.hpp +++ b/mllm/engine/Context.hpp @@ -46,6 +46,8 @@ class Context { uint64_t getRandomSeed(); + uint64_t getRandomState(); + uint64_t curTime(); std::unordered_map refSessionThreads(); @@ -64,6 +66,7 @@ class Context { Context(); uint64_t random_seed_ = 42; + uint64_t random_state_ = 42; SessionTCB::ptr_t main_thread_; std::unordered_map session_threads_; diff --git a/mllm/mllm.cpp b/mllm/mllm.cpp index b93f17cbf..b6851c425 100644 --- a/mllm/mllm.cpp +++ b/mllm/mllm.cpp @@ -23,6 +23,8 @@ void setLogLevel(const LogLevel& level) { ::mllm::Logger::level() = level; } void setRandomSeed(uint64_t seed) { Context::instance().setRandomSeed(seed); } +int64_t getRandomState() { return Context::instance().getRandomState(); } + void setMaximumNumThreads(uint32_t num_threads) { // TODO } diff --git a/mllm/mllm.hpp b/mllm/mllm.hpp index c8026955a..d1268b7c4 100644 --- a/mllm/mllm.hpp +++ b/mllm/mllm.hpp @@ -156,6 +156,8 @@ void setLogLevel(const LogLevel& level); void setRandomSeed(uint64_t seed); +int64_t getRandomState(); + void setMaximumNumThreads(uint32_t num_threads); void setPrintPrecision(int precision); diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 1d8e1853a..9a2d19803 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -253,42 +253,6 @@ TEST_F(ElementwiseKernelTest, DivFloat16) { true); } -TEST_F(ElementwiseKernelTest, DivInt8) { - EXPECT_EQ(DivInt8Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - -TEST_F(ElementwiseKernelTest, DivInt16) { - EXPECT_EQ(DivInt16Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - -TEST_F(ElementwiseKernelTest, DivInt32) { - EXPECT_EQ(DivInt32Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - //===----------------------------------------------------------------------===// // Element wise ADD Scalar. // @@ -789,6 +753,12 @@ TEST_F(ReduceKernelTest, SumFloat32) { // } #endif +//===----------------------------------------------------------------------===// +// Scatter 2 Shards Attn +//===----------------------------------------------------------------------===// +#include "Scatter2ShardsKernelTest.hpp" +TEST_F(Scatter2ShardsKernelTest, one) { EXPECT_EQ(testScatter2Shards(), true); } + //===----------------------------------------------------------------------===// // Paged Attn //===----------------------------------------------------------------------===// diff --git a/tests/cpu/RadixAttnKernel.hpp b/tests/cpu/RadixAttnKernel.hpp index 2bfe63854..d547d43ff 100644 --- a/tests/cpu/RadixAttnKernel.hpp +++ b/tests/cpu/RadixAttnKernel.hpp @@ -1,6 +1,7 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. 
+#include #include #include "mllm/mllm.hpp" @@ -27,10 +28,8 @@ class RadixAttnModule : public nn::Module { }; class EagerModule : public nn::Module { - nn::CausalMask mask_; - public: - EagerModule() : nn::Module() { mask_ = reg("mask"); } + EagerModule() : nn::Module() {} std::vector forward(const std::vector& inputs, const std::vector& args) override { // inputs is Q, K_indices, V_indices @@ -51,8 +50,23 @@ class EagerModule : public nn::Module { // Attention Weight // [B, H, S, S] auto attn = nn::functional::matmul(Q, K, false, true) * (1.f / sqrtf(head_dim)); - attn = mask_(attn); - attn = nn::functional::softmax(attn, -1); + + // Make mask + auto S_Q = Q.shape()[2]; + auto S_KV = K.shape()[2]; + auto mask = Tensor::zeros({1, 1, S_Q, S_KV}); + { + auto ptr = mask.ptr(); + int __delta = S_KV - S_Q; + for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) { + int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV); + for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) { + ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits::infinity(); + } + } + } + + attn = nn::functional::softmax(attn + mask, -1); // [B, H, S, D] auto output = nn::functional::matmul(attn, V); // [B, S, H, D] @@ -106,7 +120,7 @@ class RadixAttnKernelTest : public KernelTest { .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous, }}}; EagerModule eager_attn; - RadixAttnModule radix_attn; + RadixAttnModule radix_attn(H_Q, H_KV); prefix_cache::Cache cache(opt); // Create Q, K, V @@ -131,17 +145,33 @@ class RadixAttnKernelTest : public KernelTest { auto v_cache_indices = Tensor::refVectorData(v_cache_ptrs, {S_KV}, kInt64); nn::functional::scatter2Shards(K, k_cache_indices, 1); - nn::functional::scatter2Shards(K, v_cache_indices, 1); + nn::functional::scatter2Shards(V, v_cache_indices, 1); + + // loop check if shards is correct + for (int i = 0; i < S_KV; ++i) { + auto prd_ptr = (float*)k_cache_ptrs[i]; + auto gt_ptr = K.offsettedPtr({0, i, 0, 0}); + for (int j = 0; j < H_KV * D; ++j) { + if (prd_ptr[j] != gt_ptr[j]) { + print("Error at: ", i, j); + print("prd: ", prd_ptr[j], " gt: ", gt_ptr[j]); + return false; + } + } + } // Compute eager Tensor gt = eager_attn(Q, K, V)[0]; Tensor predict = radix_attn(Q, k_cache_indices, v_cache_indices)[0]; - Dbg(); // Compare - auto result = test::allClose(gt, predict); + // rtol and atol set to 1e-2f is because: + // 1. The eager softmax is approximate, but radix is not. + auto result = test::allClose(gt, predict, 1e-2f, 1e-2f); if (!result) { print(result); + print("S_Q and S_KV is", S_Q, S_KV); + print(predict); return false; } return true; diff --git a/tests/cpu/Scatter2ShardsKernelTest.hpp b/tests/cpu/Scatter2ShardsKernelTest.hpp new file mode 100644 index 000000000..d575dd529 --- /dev/null +++ b/tests/cpu/Scatter2ShardsKernelTest.hpp @@ -0,0 +1,28 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/mllm.hpp" + +#include "KernelTestHelper.hpp" + +class Scatter2ShardsKernelTest : public KernelTest { + public: + Scatter2ShardsKernelTest() = default; + ~Scatter2ShardsKernelTest() override = default; + + bool testScatter2Shards() { + // B, S, H, D + auto Q = mllm::Tensor::random({1, 100, 8, 64}, -10, 10); + auto P = mllm::Tensor::zeros({1, 100, 8, 64}); + + std::vector indices; + indices.reserve(100); + for (int i = 0; i < 100; ++i) { indices.push_back((char*)P.offsettedPtr({0, i, 0, 0})); } + + auto tensor_indices = mllm::Tensor::refVectorData(indices, {100}, mllm::kInt64); + + mllm::nn::functional::scatter2Shards(Q, tensor_indices, 1); + + return mllm::test::allClose(Q, P).is_close; + } +}; diff --git a/tests/engine/CMakeLists.txt b/tests/engine/CMakeLists.txt index e0d673314..9d9b48161 100644 --- a/tests/engine/CMakeLists.txt +++ b/tests/engine/CMakeLists.txt @@ -17,3 +17,7 @@ target_include_directories(Mllm-Test-Engine-PrefixCache PRIVATE ${MLLM_INCLUDE_D add_executable(Mllm-Test-Engine-Service ServiceTest.cpp) target_link_libraries(Mllm-Test-Engine-Service PRIVATE MllmRT MllmCPUBackend) target_include_directories(Mllm-Test-Engine-Service PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(Mllm-Test-Engine-RandomStates RandomStatesTest.cpp) +target_link_libraries(Mllm-Test-Engine-RandomStates PRIVATE MllmRT MllmCPUBackend) +target_include_directories(Mllm-Test-Engine-RandomStates PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/tests/engine/RandomStatesTest.cpp b/tests/engine/RandomStatesTest.cpp new file mode 100644 index 000000000..446432eed --- /dev/null +++ b/tests/engine/RandomStatesTest.cpp @@ -0,0 +1,12 @@ +#include + +#include "mllm/mllm.hpp" + +using namespace mllm; // NOLINT + +int main() { + mllm::initializeContext(); + mllm::setRandomSeed(42); + for (int i = 0; i < 10; i++) { print(mllm::getRandomState()); } + mllm::memoryReport(); +} From 785498240763e9436c3c96aff5892f8f4e8da25c Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 9 Oct 2025 21:17:39 +0800 Subject: [PATCH 5/8] feat(service): add qwen3 service example and enhance session management - Add new `qwen3_service` example with CMake build support - Introduce `startService`, `stopService`, and `insertSession` APIs for simplified service control - Refactor `Session` class to remove dependency on `ARGeneration` and improve flexibility - Enhance `Qwen3Session` with thinking token handling and improved cache management - Improve error messages in model loading and extend rotary positional embedding support - Update radix attention kernel comments and parallelization hints - Remove obsolete PyPI README content The changes enable better service deployment and model session handling, especially for Qwen3-based models with thinking token capabilities. 
--- .github/workflows/pymllm-nightly-macos.yml | 0 .github/workflows/pymllm-publish-macos.yml | 0 .github/workflows/pymllm-publish-x86.yml | 0 examples/CMakeLists.txt | 1 + examples/qwen3_service/CMakeLists.txt | 3 + examples/qwen3_service/main.cpp | 35 +++++++++ .../kernels/common/radix_attn/fwd_bshd.hpp | 6 +- mllm/engine/service/Service.cpp | 8 ++ mllm/engine/service/Service.hpp | 8 +- mllm/engine/service/Session.cpp | 4 +- mllm/engine/service/Session.hpp | 6 +- mllm/models/qwen3/configuration_qwen3.hpp | 2 + mllm/models/qwen3/modeling_qwen3_service.hpp | 74 ++++++++++++++++--- pymllm/README.md | 23 ------ tox.ini | 0 15 files changed, 126 insertions(+), 44 deletions(-) create mode 100644 .github/workflows/pymllm-nightly-macos.yml create mode 100644 .github/workflows/pymllm-publish-macos.yml create mode 100644 .github/workflows/pymllm-publish-x86.yml create mode 100644 tox.ini diff --git a/.github/workflows/pymllm-nightly-macos.yml b/.github/workflows/pymllm-nightly-macos.yml new file mode 100644 index 000000000..e69de29bb diff --git a/.github/workflows/pymllm-publish-macos.yml b/.github/workflows/pymllm-publish-macos.yml new file mode 100644 index 000000000..e69de29bb diff --git a/.github/workflows/pymllm-publish-x86.yml b/.github/workflows/pymllm-publish-x86.yml new file mode 100644 index 000000000..e69de29bb diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4dd546641..f193f34c6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(qwen2_5vl_tracer) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(qwen3) +add_subdirectory(qwen3_service) diff --git a/examples/qwen3_service/CMakeLists.txt b/examples/qwen3_service/CMakeLists.txt index e69de29bb..31faa395e 100644 --- a/examples/qwen3_service/CMakeLists.txt +++ b/examples/qwen3_service/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-qwen3-service main.cpp) +target_link_libraries(mllm-qwen3-service PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-service PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_service/main.cpp b/examples/qwen3_service/main.cpp index e69de29bb..d9bed079c 100644 --- a/examples/qwen3_service/main.cpp +++ b/examples/qwen3_service/main.cpp @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +MLLM_MAIN({ + auto& model_path = mllm::Argparse::add("-m|--model_path").help("Model path").required(true); + mllm::Argparse::parse(argc, argv); + + auto qwen3_session = std::make_shared(); + qwen3_session->fromPreTrain(model_path.get()); + mllm::service::insertSession("mllmTeam/Qwen3-0.6B-w4a32kai", qwen3_session); + mllm::service::startService(); + + // Build Request + mllm::service::RequestPayload req; + req["model"] = "mllmTeam/Qwen3-0.6B-w4a32kai"; + mllm::service::RequestPayload one_msg; + one_msg["role"] = "user"; + one_msg["content"] = "Say Hello in Chinese, English, French and German."; + req["messages"] = json::array({one_msg}); + req["id"] = "chat-01"; + req["enable_thinking"] = true; + mllm::service::sendRequest(req.dump()); + + // getResponse will block until the response is ready, while(true) is ok in this case. 
+  while (true) {
+    std::string resp = mllm::service::getResponse("chat-01");
+    auto j = nlohmann::json::parse(resp);
+    std::cout << j["data"].get() << std::flush;
+    if (j["finished"]) break;
+  }
+
+  mllm::service::stopService();
+})
diff --git a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
index d562b1a6a..1834618b0 100644
--- a/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
+++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp
@@ -38,7 +38,10 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i
 
   // Loop on batch size.
   for (int b_idx = 0; b_idx < B; ++b_idx) {
-    // Loop on head dim, should be made parallel
+    // FIXME: Looping on SEQUENCE may be faster?
+    // seq_q [ head_q [ seq_kv ] ]
+
+    // Loop on HEAD dim, should be made parallel
     MLLM_CONDITIONAL_PARALLEL_FOR(thread_count > 1, thread_count, h_q_idx, 0, H_Q, 1, {
       int h_kv_id = h_q_idx / head_repeat_times;
 
@@ -47,7 +50,6 @@
       const __QDType* q_token = __q + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D;
       __ODType* acc_o = __out + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D;
 
-      // FIXME: Maybe we can make this loop faster.
       details::FilledWithConst<__ArchTag, __ODType>::run(acc_o, 0, D);
 
       __AccDType scores_max = -std::numeric_limits<__AccDType>::infinity();
diff --git a/mllm/engine/service/Service.cpp b/mllm/engine/service/Service.cpp
index 4c9635b6f..ca6efab45 100644
--- a/mllm/engine/service/Service.cpp
+++ b/mllm/engine/service/Service.cpp
@@ -121,6 +121,14 @@ void Service::workerLoop() {
   }
 }
 
+void startService(size_t worker_threads) { Service::instance().start(worker_threads); }
+
+void stopService() { Service::instance().stop(); }
+
+void insertSession(const std::string& session_id, const std::shared_ptr& session) {
+  Service::instance().sessionPool().registerSession(session_id, session);
+}
+
 int sendRequest(const std::string& json_str) {
   if (json_str.empty()) return -1;
   try {
diff --git a/mllm/engine/service/Service.hpp b/mllm/engine/service/Service.hpp
index 54321678f..5726b7311 100644
--- a/mllm/engine/service/Service.hpp
+++ b/mllm/engine/service/Service.hpp
@@ -107,8 +107,14 @@ class Service {
   std::vector workers_;
 };
 
-// Golang, Python SDK should bind those two function and provide high level Network API.
+// Golang and Python SDKs should bind these 5 functions and provide a high-level Network API.
 // We just focus on the server side.
+void startService(size_t worker_threads = 1); + +void stopService(); + +void insertSession(const std::string& session_id, const std::shared_ptr& session); + int sendRequest(const std::string& json_str); Response getResponse(const std::string& id); diff --git a/mllm/engine/service/Session.cpp b/mllm/engine/service/Session.cpp index 818fc4b07..9fb649e1e 100644 --- a/mllm/engine/service/Session.cpp +++ b/mllm/engine/service/Session.cpp @@ -10,11 +10,9 @@ namespace mllm::service { -Session::Session(std::shared_ptr model) : model_(std::move(model)) {} - void Session::fromPreTrain(const std::string& model_path) { MLLM_EMPTY_SCOPE; } -NoneSession::NoneSession() : Session(nullptr) {} +NoneSession::NoneSession() : Session() {} void NoneSession::streamGenerate(const nlohmann::json& request, const std::function& callback) { diff --git a/mllm/engine/service/Session.hpp b/mllm/engine/service/Session.hpp index 1f1efe1bf..f8e41084e 100644 --- a/mllm/engine/service/Session.hpp +++ b/mllm/engine/service/Session.hpp @@ -4,7 +4,6 @@ #include #include -#include "mllm/models/ARGeneration.hpp" namespace mllm::service { @@ -12,15 +11,12 @@ class Session { public: using ptr_t = std::shared_ptr; - explicit Session(std::shared_ptr model); + Session() = default; virtual void streamGenerate(const nlohmann::json& request, const std::function& callback) = 0; virtual void fromPreTrain(const std::string& model_path); - - private: - std::shared_ptr model_; }; class NoneSession final : public Session { diff --git a/mllm/models/qwen3/configuration_qwen3.hpp b/mllm/models/qwen3/configuration_qwen3.hpp index 63c7d0cdb..7df5137ca 100644 --- a/mllm/models/qwen3/configuration_qwen3.hpp +++ b/mllm/models/qwen3/configuration_qwen3.hpp @@ -51,6 +51,8 @@ struct Qwen3Config : protected ConfigFile { bool tie_word_embeddings = true; int32_t max_cache_length = 2048; int32_t end_of_text_token_id = 151645; + int32_t thinking_start_token_id = 151667; + int32_t thinking_end_token_id = 151668; aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; }; diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp index ba6cbb549..c55b8e03b 100644 --- a/mllm/models/qwen3/modeling_qwen3_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_service.hpp @@ -14,6 +14,7 @@ #include "mllm/nn/Functional.hpp" #include "mllm/utils/Enumerate.hpp" #include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/models/qwen3/tokenization_qwen3.hpp" #include "mllm/models/qwen3/configuration_qwen3.hpp" @@ -368,8 +369,9 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module { }; } - private: const Qwen3Config& cfg; + + private: Qwen3Text llm; nn::Linear lm_head_; bool tie_word_embeddings_; @@ -377,7 +379,12 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module { class Qwen3Session final : public ::mllm::service::Session { public: - Qwen3Session(); + Qwen3Session() = default; + + std::size_t findThinkStartToken(const std::vector& output_ids) { + auto it = std::find(output_ids.begin(), output_ids.end(), model_->cfg.thinking_start_token_id); + return std::distance(output_ids.begin(), it); + } void streamGenerate(const nlohmann::json& request, const std::function& callback) override { @@ -413,14 +420,61 @@ class Qwen3Session final : public ::mllm::service::Session { args["v_cache_addrs"] = &v_cache_addrs_; args["prefix_cache_context"] = cache_.get(); - // TODO other args for sampling should be read from request. 
+ // Supported sampling args: temperature, top_k, top_p, max_length, do_sample. + args["temperature"] = request.value("temperature", 1.0f); + args["top_k"] = request.value("top_k", 0); + args["top_p"] = request.value("top_p", 1.0f); + args["max_length"] = request.value("max_length", 1024); + args["do_sample"] = request.value("do_sample", true); // Iteration start - model_->streamGenerate(input, args, [](int64_t idx) { - // TODO Callback + int64_t package_cnt = 0; + model_->streamGenerate(input, args, [this, &full_seq_idx, &package_cnt, &callback](int64_t idx) { + bool finished = false; + std::string ret_token; + if (idx == model_->cfg.end_of_text_token_id) { + finished = true; + ret_token = ""; + } else { + finished = false; + std::string t = preprocessor::wideString2Utf8String(tokenizer_->detokenize(idx)); + + // Update full_seq_idx to include the new token for Radix Tree to use. + full_seq_idx.push_back(idx); + } + + // Callback will send this json to the response pool for user to consume. + callback(nlohmann::json{}, finished); + + package_cnt++; }); - // TODO: process full_seq_idx and k_cache_addrs_/v_cache_addrs_. Only none thinking budget should be insert in radix tree. + // Post process full_seq_idx and k_cache_addrs_/v_cache_addrs_. Only the non-thinking content should be inserted into the radix tree. + // + // NOTE: We will drop everything after the thinking_start_token_idx (inclusive). + // Suppose only one <think> token appears in the sequence. + // + // e.g.: + // <|im_start|>user + // hello<|im_end|> + // <|im_start|>assistant + // <think> + // ... + // </think> + // hello! + // <|endoftext|> + // + // In the radix tree, we will only save: + // <|im_start|>user + // hello<|im_end|> + // <|im_start|>assistant + // + // Explanation: Qwen3 and other CoT models remove the thinking budget from the history, which means the RoPE positions of + // the answer "hello" change in the 2nd turn. + auto thinking_end_token_idx = findThinkStartToken(full_seq_idx); + full_seq_idx.resize(thinking_end_token_idx); + for (auto& k_vec : k_cache_addrs_) k_vec.resize(thinking_end_token_idx); + for (auto& v_vec : v_cache_addrs_) v_vec.resize(thinking_end_token_idx); // Insert generated tokens to the cache. cache_->promote(full_seq_idx, k_cache_addrs_, v_cache_addrs_); @@ -436,13 +490,13 @@ class Qwen3Session final : public ::mllm::service::Session { fs::path config_file = root / "config.json"; fs::path model_file = root / "model.mllm"; fs::path tokenizer_file = root / "tokenizer.json"; - if (!fs::exists(config_file)) throw std::runtime_error("config.json not found"); - if (!fs::exists(model_file)) throw std::runtime_error("model.mllm not found"); - if (!fs::exists(tokenizer_file)) throw std::runtime_error("tokenizer.json not found"); + if (!fs::exists(config_file)) throw std::runtime_error(config_file.string() + " not found"); + if (!fs::exists(model_file)) throw std::runtime_error(model_file.string() + " not found"); + if (!fs::exists(tokenizer_file)) throw std::runtime_error(tokenizer_file.string() + " not found"); auto cfg = Qwen3Config(config_file.string()); model_ = std::make_shared<Qwen3ForCausalLM>(cfg); - model_->load(load(model_file.string())); + model_->load(load(model_file.string(), ModelFileVersion::kV2)); tokenizer_ = std::make_shared<Qwen3Tokenizer>(tokenizer_file.string()); cache_ = std::make_shared(prefix_cache::CacheOptions{ diff --git a/pymllm/README.md b/pymllm/README.md index 5240a5fa4..e69de29bb 100644 --- a/pymllm/README.md +++ b/pymllm/README.md @@ -1,23 +0,0 @@ -This file is used to request a PyPI organization. Should not be packaged in production. - -**PyPI Staff** - -Hello ubios!
Upon reviewing this organization application request, we were unable to determine your affiliation with the provided URL (https://ubiquitouslearning.github.io/mllm/). Please let us know if there is some other way for us to verify this, otherwise we will decline this request. - -**chenghuaWang** - -Hello, -Thank you for your response. - -I am the maintainer of the project hosted at https://ubiquitouslearning.github.io/mllm/ and a contributor/administrator of this GitHub repository. I am currently developing the next major version (v2) of mllm; you can see my commits at https://github.com/UbiquitousLearning/mllm/commits/v2/ under the username chenghuaWang. The “mllm” organization is tied to this project, and I would like to create it on PyPI to publish and manage the related packages. - -For verification, please find below: - -1. My GitHub profile: https://github.com/chenghuaWang -2. Repository link: https://github.com/ubiquitouslearning/mllm -3. I have added a temporary note in the repository to confirm this request: https://github.com/UbiquitousLearning/mllm/blob/v2/pymllm/README.md (our conversation is quoted there). - -Let me know if any further proof is needed. - -Best regards, -chenghua.Wang diff --git a/tox.ini b/tox.ini new file mode 100644 index 000000000..e69de29bb From b112dbb89faffff69d01bcc3061aad5dd31452c7 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 9 Oct 2025 22:05:03 +0800 Subject: [PATCH 6/8] fix(prefix_cache): initialize k/v cache addresses with correct size Ensure that `k_cache_addresses` and `v_cache_addresses` are properly resized to match the number of transformer blocks when a radix tree search fails. fix(qwen3): correct cache address gathering and physical address mapping Replace the reversed `std::ranges::copy` (which appended existing cache addresses to the newly allocated list) with `insert`, so that newly allocated addresses are appended to the per-layer cache address lists. Fix incorrect use of `k_cache_addr` where `v_cache_addr` should be used when mapping physical addresses for the value cache.
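P.S. (editor's illustration, not part of the patch) A minimal sketch of the corrected gathering step, reusing the names from the diff below; `physicalAddr`, the `*_wait_for_promote` vectors, and the address types are assumptions taken from context rather than verified API:

// Append the newly allocated virtual addresses to this layer's cache list,
// then map every virtual address to its physical counterpart. K must be
// mapped through k_cache_addr and V through v_cache_addr; the old code
// reused the K addresses for V, yielding wrong physical indices for the
// value cache.
auto& k_layer = (*k_cache_addr)[layer_idx_];
auto& v_layer = (*v_cache_addr)[layer_idx_];
k_layer.insert(k_layer.end(), k_addr_wait_for_promote.begin(), k_addr_wait_for_promote.end());
v_layer.insert(v_layer.end(), v_addr_wait_for_promote.begin(), v_addr_wait_for_promote.end());
std::vector<int64_t> k_phy, v_phy;
for (size_t i = 0; i < k_layer.size(); ++i) {
  k_phy.push_back(prefix_cache_context->physicalAddr(k_layer[i]));
  v_phy.push_back(prefix_cache_context->physicalAddr(v_layer[i]));  // not k_layer[i]
}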
--- mllm/engine/prefix_cache/RadixTree.cpp | 4 ++-- mllm/models/qwen3/modeling_qwen3_service.hpp | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/mllm/engine/prefix_cache/RadixTree.cpp b/mllm/engine/prefix_cache/RadixTree.cpp index 770170282..d9049c521 100644 --- a/mllm/engine/prefix_cache/RadixTree.cpp +++ b/mllm/engine/prefix_cache/RadixTree.cpp @@ -157,8 +157,8 @@ RadixSearchResult RadixTree::search(const RadixTreeNodeKey& key) { result.success = false; result.path = path; result.matched_length = 0; - result.k_cache_addresses = {}; - result.v_cache_addresses = {}; + result.k_cache_addresses.resize(options_.transformer_blocks_num); + result.v_cache_addresses.resize(options_.transformer_blocks_num); } return result; diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp index c55b8e03b..799f69820 100644 --- a/mllm/models/qwen3/modeling_qwen3_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_service.hpp @@ -203,8 +203,14 @@ class Qwen3Attention final : public nn::Module { nn::functional::scatter2Shards(value_states, v_wait_for_promote, 1); // Gather all cache to indicies tensor - std::ranges::copy((*k_cache_addr)[layer_idx_], std::back_inserter(k_addr_wait_for_promote)); - std::ranges::copy((*v_cache_addr)[layer_idx_], std::back_inserter(v_addr_wait_for_promote)); + { + auto& dst = (*k_cache_addr)[layer_idx_]; + dst.insert(dst.end(), k_addr_wait_for_promote.begin(), k_addr_wait_for_promote.end()); + } + { + auto& dst = (*v_cache_addr)[layer_idx_]; + dst.insert(dst.end(), v_addr_wait_for_promote.begin(), v_addr_wait_for_promote.end()); + } std::vector<int64_t> k_phy_cache_indicies; std::vector<int64_t> v_phy_cache_indicies; int32_t kv_cache_len = (*k_cache_addr)[layer_idx_].size(); @@ -212,7 +218,7 @@ v_phy_cache_indicies.reserve(kv_cache_len); for (int i = 0; i < kv_cache_len; ++i) { k_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*k_cache_addr)[layer_idx_][i])); - v_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*k_cache_addr)[layer_idx_][i])); + v_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*v_cache_addr)[layer_idx_][i])); } auto k_cache = Tensor::refVectorData(k_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); auto v_cache = Tensor::refVectorData(v_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); From 33b901e6dd744469fefba712c7a26dd53b992739 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 10 Oct 2025 15:51:08 +0800 Subject: [PATCH 7/8] feat(service): implement interactive chat loop in qwen3 service - Replace hardcoded request with interactive user input loop - Add support for multi-turn conversations with history tracking - Integrate thinking state visualization using fmt library - Handle graceful exit with /exit or /quit commands - Improve response formatting with proper JSON structure fix(cpu): support BSHD layout in RoPE operation - Extend RoPE operator to handle both BHSD and BSHD tensor layouts - Add layout type checking and dimension mapping logic - Fix assertion to allow multiple input layout types refactor(service): improve request pool shutdown handling - Change RequestPool::pop() to return std::optional<RequestItem> - Update service worker loop to handle optional requests - Ensure worker threads join properly during shutdown - Add timestamp to response payloads - Modify response format to follow chat completion chunk structure fix(cache): correct ZenFS blob size calculation - Fix the bit
calculation: use 64 (bits per uint64_t) instead of sizeof(uint64_t), which is 8 (bytes) - Correct blob size computation by using consistent dtype lanes fix(qwen3): correct loop initialization and template processing - Initialize loop counters properly in attention mechanisms - Refine chat template processing for better tool and thinking support - Adjust tensor shapes for sequence and position IDs - Update default sampling parameters and EOS token ID - Change cache data types from float16 to float32 for better precision style(qwen3): reformat function signature for readability - Break long function declaration into multiple lines - Improve code formatting and parameter alignment --- examples/qwen3_service/main.cpp | 90 ++++++-- mllm/backends/cpu/ops/RoPEOp.cpp | 27 ++- mllm/engine/prefix_cache/ZenFS.cpp | 4 +- mllm/engine/service/Service.cpp | 61 ++++-- mllm/engine/service/Service.hpp | 2 +- mllm/models/qwen3/modeling_qwen3_service.hpp | 214 ++++++++++--------- tools/mllm-llm-benchmark/README.md | 0 tools/mllm-vlm-benchmark/README.md | 0 8 files changed, 251 insertions(+), 147 deletions(-) create mode 100644 tools/mllm-llm-benchmark/README.md create mode 100644 tools/mllm-vlm-benchmark/README.md diff --git a/examples/qwen3_service/main.cpp b/examples/qwen3_service/main.cpp index d9bed079c..26239064d 100644 --- a/examples/qwen3_service/main.cpp +++ b/examples/qwen3_service/main.cpp @@ -1,9 +1,15 @@ -#include +#include +#include + +#include +#include + #include #include #include MLLM_MAIN({ + mllm::setLogLevel(mllm::LogLevel::kError); auto& model_path = mllm::Argparse::add("-m|--model_path").help("Model path").required(true); mllm::Argparse::parse(argc, argv); @@ -12,24 +18,74 @@ MLLM_MAIN({ mllm::service::insertSession("mllmTeam/Qwen3-0.6B-w4a32kai", qwen3_session); mllm::service::startService(); - // Build Request - mllm::service::RequestPayload req; - req["model"] = "mllmTeam/Qwen3-0.6B-w4a32kai"; - mllm::service::RequestPayload one_msg; - one_msg["role"] = "user"; - one_msg["content"] = "Say Hello in Chinese, English, French and German."; - req["messages"] = json::array({one_msg}); - req["id"] = "chat-01"; - req["enable_thinking"] = true; - mllm::service::sendRequest(req.dump()); - - // getResponse will block until the response is ready, while(true) is ok in this case.
+ std::vector<nlohmann::json> history; + const std::string model_name = "mllmTeam/Qwen3-0.6B-w4a32kai"; + + std::cout << "Enter /exit or /quit to exit this program\n"; + while (true) { - std::string resp = mllm::service::getResponse("chat-01"); - auto j = nlohmann::json::parse(resp); - std::cout << j["data"].get<std::string>() << std::flush; - if (j["finished"]) break; + std::cout << "\nUser: "; + std::string user_input; + std::getline(std::cin, user_input); + if (user_input == "/exit" || user_input == "/quit") break; + + nlohmann::json user_msg; + user_msg["role"] = "user"; + user_msg["content"] = user_input; + history.push_back(user_msg); + + nlohmann::json req; + req["model"] = model_name; + req["messages"] = history; + req["id"] = "chat-multi"; + req["enable_thinking"] = true; + mllm::service::sendRequest(req.dump()); + std::string assistant_content; + + bool thinking_states = false; + + while (true) { + std::string resp = mllm::service::getResponse("chat-multi"); + auto j = nlohmann::json::parse(resp); + + if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("delta") + && j["choices"][0]["delta"].contains("content")) { + std::string delta = j["choices"][0]["delta"]["content"].get<std::string>(); + + if (delta == "<think>") { + thinking_states = true; + fmt::print(fmt::fg(fmt::color::gray) | fmt::emphasis::bold | fmt::emphasis::underline, "Thinking...:"); + continue; + } + if (delta == "</think>") { + thinking_states = false; + fmt::print("\n"); + continue; + } + + if (thinking_states) { + fmt::print(fmt::fg(fmt::color::gray), "{}", delta); + std::fflush(stdout); + } else { + fmt::print("{}", delta); + std::fflush(stdout); + } + + assistant_content += delta; + } + + if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("finish_reason") + && j["choices"][0]["finish_reason"].is_string() && j["choices"][0]["finish_reason"].get<std::string>() == "stop") { + break; + } + } + + nlohmann::json assistant_msg; + assistant_msg["role"] = "assistant"; + assistant_msg["content"] = assistant_content; + history.push_back(assistant_msg); } mllm::service::stopService(); + return 0; +}) diff --git a/mllm/backends/cpu/ops/RoPEOp.cpp b/mllm/backends/cpu/ops/RoPEOp.cpp index 326ed96f9..cb9f98131 100644 --- a/mllm/backends/cpu/ops/RoPEOp.cpp +++ b/mllm/backends/cpu/ops/RoPEOp.cpp @@ -19,13 +19,30 @@ void RoPEOpImpl::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& auto activation = inputs[0]; auto out = outputs[0]; - // Activation must in BHSD layout + // Activation must be in BHSD or BSHD layout MLLM_RT_ASSERT_EQ(activation.shape().size(), 4); - auto B = activation.shape()[0]; - auto H = activation.shape()[1]; - auto S = activation.shape()[2]; - auto D = activation.shape()[3]; + auto B = 0; + auto H = 0; + auto S = 0; + auto D = 0; + + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + B = activation.shape()[0]; + H = activation.shape()[1]; + S = activation.shape()[2]; + D = activation.shape()[3]; + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + B = activation.shape()[0]; + S = activation.shape()[1]; + H = activation.shape()[2]; + D = activation.shape()[3]; + break; + } + } int32_t half = D / 2; diff --git a/mllm/engine/prefix_cache/ZenFS.cpp b/mllm/engine/prefix_cache/ZenFS.cpp index 8a5409129..604be3f8f 100644 --- a/mllm/engine/prefix_cache/ZenFS.cpp +++ b/mllm/engine/prefix_cache/ZenFS.cpp @@ -672,13 +672,13 @@ void ZenFileSystem::_createBlobOnDisk() { void ZenFileSystem::_createBlobOnAnonymousFile() { // Calculate blob size.
size_t total_bits = (1 << (options_.page_bits + options_.lane_bits)); - size_t uint64_count = (total_bits + sizeof(uint64_t) - 1) / sizeof(uint64_t); + size_t uint64_count = (total_bits + 64 - 1) / 64; // Blob size is not K and V. // Is 1 << (options_.page_bits + options_.lane_bits) * elements * sizeof(dtype). // K and V shared one blob. size_t blob_size = (1 << (options_.page_bits + options_.lane_bits)) * options_.per_k_token_ele * bytesOfType(options_.k_dtype) - / lanesOfType(options_.v_dtype); + / lanesOfType(options_.k_dtype); std::error_code ec; auto mmap_file = ZenFSBlobMMAPFile::create(blob_size, ZenFSMMAPMode::kAnonymous, "", ec); diff --git a/mllm/engine/service/Service.cpp b/mllm/engine/service/Service.cpp index ca6efab45..5400bcdb2 100644 --- a/mllm/engine/service/Service.cpp +++ b/mllm/engine/service/Service.cpp @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include +#include +#include + #include "mllm/engine/service/Session.hpp" #include "mllm/engine/service/Service.hpp" @@ -18,10 +21,12 @@ void RequestPool::push(RequestItem item) { cv_.notify_one(); } -RequestItem RequestPool::pop() { +std::optional<RequestItem> RequestPool::pop() { std::unique_lock lk(mtx_); cv_.wait(lk, [this] { return !queue_.empty() || stop_; }); - if (stop_) { throw std::runtime_error("RequestPool stopped"); } + + if (stop_ && queue_.empty()) { return std::nullopt; } + auto item = std::move(queue_.front()); queue_.pop(); return item; @@ -92,29 +97,52 @@ void Service::start(size_t worker_threads) { } void Service::stop() { - running_ = false; req_pool_.shutdown(); + running_ = false; + + for (auto& t : workers_) { + if (t.joinable()) { t.join(); } + } + resp_pool_.shutdown(); - for (auto& t : workers_) t.join(); } RequestPool& Service::requestPool() { return req_pool_; } + ResponsePool& Service::responsePool() { return resp_pool_; } + SessionPool& Service::sessionPool() { return sess_pool_; } void Service::workerLoop() { while (running_) { try { - RequestItem req = req_pool_.pop(); - auto session = sess_pool_.get(req.payload.value("model", "")); - - session->streamGenerate(req.payload, [this, req_id = req.id](const std::string& token, bool finished) { - ResponseItem item; - item.id = req_id; - item.finished = finished; - item.raw = token; - resp_pool_.push(req_id, std::move(item)); - }); + if (auto req_opt = req_pool_.pop(); req_opt) { + RequestItem& req = *req_opt; + auto session = sess_pool_.get(req.payload.value("model", "")); + + session->streamGenerate( + req.payload, [this, req_payload = req.payload, req_id = req.id](const std::string& token, bool finished) { + ResponseItem item; + item.id = req_id; + item.finished = finished; + item.raw = token; + + std::time_t now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + + // Make payload + item.payload = { + {"model", req_payload["model"]}, + {"created", now}, + {"choices", nlohmann::json::array({{{"index", 0}, + {"delta", {{"content", token}}}, + {"finish_reason", finished ? "stop" : nlohmann::json(nullptr)}}})}}; + + resp_pool_.push(req_id, std::move(item)); + }); + } else { + // pop returned empty, which means the service is stopped and the queue is empty. + break; + } } catch (...)
{ // TODO } @@ -160,8 +188,9 @@ Response getResponse(const std::string& id) { j["finished"] = false; return j.dump(); } - nlohmann::json j; - j["data"] = opt->raw; + nlohmann::json j = opt->payload; + j["id"] = opt->id; + j["object"] = "chat.completion.chunk"; j["finished"] = opt->finished; std::string s = j.dump(); return s; diff --git a/mllm/engine/service/Service.hpp b/mllm/engine/service/Service.hpp index 5726b7311..2ae6e8a12 100644 --- a/mllm/engine/service/Service.hpp +++ b/mllm/engine/service/Service.hpp @@ -38,7 +38,7 @@ class RequestPool { public: void push(RequestItem item); - RequestItem pop(); + std::optional<RequestItem> pop(); void shutdown(); diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp index 799f69820..372095ca4 100644 --- a/mllm/models/qwen3/modeling_qwen3_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_service.hpp @@ -183,7 +183,7 @@ class Qwen3Attention final : public nn::Module { // Acquire cache std::vector<prefix_cache::vp_addr_t> k_addr_wait_for_promote; std::vector<prefix_cache::vp_addr_t> v_addr_wait_for_promote; - for (int s_idx; s_idx < S; ++s_idx) { + for (int s_idx = 0; s_idx < S; ++s_idx) { k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); } @@ -191,7 +191,7 @@ // Prepare indicies cache. sizeof(char*) == 8 == sizeof(int64_t) std::vector<int64_t> k_phy_addr_wait_for_promote; std::vector<int64_t> v_phy_addr_wait_for_promote; - for (int s_idx; s_idx < S; ++s_idx) { + for (int s_idx = 0; s_idx < S; ++s_idx) { k_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(k_addr_wait_for_promote[s_idx])); v_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(v_addr_wait_for_promote[s_idx])); } @@ -395,9 +395,7 @@ class Qwen3Session final : public ::mllm::service::Session { void streamGenerate(const nlohmann::json& request, const std::function<void(const std::string&, bool)>& callback) override { const auto& messages = request["messages"]; - auto inputs = applyChatTemplate(messages, nullptr, true, request.value("enable_thinking", false)); - - Dbg("prompt", inputs); + auto inputs = applyChatTemplate(messages, {}, true, request.value("enable_thinking", false)); auto full_seq_idx = tokenizer_->convert2Ids(tokenizer_->tokenize(inputs)).toVector(); ARGenerationArgs args; @@ -416,8 +414,8 @@ std::back_inserter(position_ids)); } MLLM_RT_ASSERT_EQ(reduced_seq_idx.size(), position_ids.size()); - input["sequence"] = Tensor::fromVector(reduced_seq_idx, {(int32_t)reduced_seq_idx.size()}, kInt64, kCPU); - input["position_ids"] = Tensor::fromVector(position_ids, {(int32_t)position_ids.size()}, kInt64, kCPU); + input["sequence"] = Tensor::fromVector(reduced_seq_idx, {1, (int32_t)reduced_seq_idx.size()}, kInt64, kCPU); + input["position_ids"] = Tensor::fromVector(position_ids, {1, (int32_t)position_ids.size()}, kInt64, kCPU); // Setup session context k_cache_addrs_ = prefix_cache_result.k_cache_addresses; @@ -429,31 +427,33 @@ class Qwen3Session final : public ::mllm::service::Session { // Supported sampling args: temperature, top_k, top_p, max_length, do_sample.
args["temperature"] = request.value("temperature", 1.0f); args["top_k"] = request.value("top_k", 0); - args["top_p"] = request.value("top_p", 1.0f); + args["top_p"] = request.value("top_p", 0.0f); args["max_length"] = request.value("max_length", 1024); - args["do_sample"] = request.value("do_sample", true); + args["do_sample"] = request.value("do_sample", false); // Iteration start int64_t package_cnt = 0; - model_->streamGenerate(input, args, [this, &full_seq_idx, &package_cnt, &callback](int64_t idx) { + model_->streamGenerate(input, args, [this, &request, &full_seq_idx, &package_cnt, &callback](int64_t idx) { bool finished = false; std::string ret_token; - if (idx == model_->cfg.end_of_text_token_id) { + if (idx == model_->cfg.eos_token_id) { finished = true; ret_token = ""; } else { finished = false; - std::string t = preprocessor::wideString2Utf8String(tokenizer_->detokenize(idx)); + ret_token = preprocessor::wideString2Utf8String(tokenizer_->detokenize(idx)); // Update full_seq_idx to include the new token for Radix Tree to use. full_seq_idx.push_back(idx); } // Callback will send this json to the response pool for user to consume. - callback(nlohmann::json{}, finished); + callback(ret_token, finished); package_cnt++; }); + // Callback a finish token + callback("", true); // Post process full_seq_idx and k_cache_addrs_/v_cache_addrs_. Only none thinking budget should be insert in radix tree. // @@ -514,8 +514,8 @@ class Qwen3Session final : public ::mllm::service::Session { .allocator_options = {// Normal things. .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), - .k_dtype = mllm::kFloat16, - .v_dtype = mllm::kFloat16, + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, // CUDA things. .enable_cuda = false, @@ -531,129 +531,131 @@ class Qwen3Session final : public ::mllm::service::Session { .lane_bits = 5, .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), - .k_dtype = mllm::kFloat16, - .v_dtype = mllm::kFloat16, + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous, }}}); } - std::string trim(const std::string& s) { - auto beg = s.find_first_not_of(" \t\r\n"); - if (beg == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\r\n"); - return s.substr(beg, end - beg + 1); + std::string ltrim(const std::string& s) { + size_t start = s.find_first_not_of(" \n\r\t\f\v"); + return (start == std::string::npos) ? 
"" : s.substr(start); } - // [{"role": "user", "content": "Say this is a test!"}] - // [{"role": "assistant", "content": "This is a test!", "reasoning": "You are absolutely right!"}] - // - std::string applyChatTemplate(const nlohmann::json& messages, const nlohmann::json* tools = nullptr, - bool add_generation_prompt = true, bool enable_thinking = true, - const std::string& bos_token = "", const std::string& eos_token = "<|im_end|>") { - std::ostringstream out; - - if (tools && tools->is_array() && !tools->empty()) { - out << "<|im_start|>system\n"; - if (messages.is_array() && !messages.empty() && messages[0].contains("role") && messages[0]["role"] == "system") { - out << messages[0]["content"].get() << "\n\n"; - } - out << "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" - "You are provided with function signatures within XML tags:\n"; - for (auto& t : *tools) out << "\n" << t.dump(); - out << "\n\n\n" - "For each function call, return a json object with function name and arguments " - "within XML tags:\n\n" - "{\"name\": , \"arguments\": }\n" - "<|im_end|>\n"; + std::string rtrim(const std::string& s) { + size_t end = s.find_last_not_of(" \n\r\t\f\v"); + return (end == std::string::npos) ? "" : s.substr(0, end + 1); + } + + std::string trim(const std::string& s) { return rtrim(ltrim(s)); } + + std::string applyChatTemplate(const json& messages, const std::vector& tools = {}, bool add_generation_prompt = true, + bool enable_thinking = true, const std::string& bos_token = "", + const std::string& eos_token = "<|im_end|>") { + std::ostringstream oss; + + if (!tools.empty()) { + oss << "<|im_start|>system\n"; + if (!messages.empty() && messages[0].value("role", "") == "system") { oss << messages[0].value("content", "") << "\n\n"; } + oss << "# Tools\n\nYou may call one or more functions to assist with the user query.\n"; + oss << "You are provided with function signatures within XML tags:\n"; + for (const auto& tool : tools) { oss << "\n" << tool.dump(); } + oss << "\n\n\nFor each function call, return a json object with function name and arguments within " + " XML tags:\n\n{\"name\": , \"arguments\": " + "}\n<|im_end|>\n"; } else { - if (messages.is_array() && !messages.empty() && messages[0].contains("role") && messages[0]["role"] == "system") { - out << "<|im_start|>system\n" << messages[0]["content"].get() << "<|im_end|>\n"; + if (!messages.empty() && messages[0].value("role", "") == "system") { + oss << "<|im_start|>system\n" << messages[0].value("content", "") << "<|im_end|>\n"; } } - size_t last_query_index = messages.size() - 1; - bool multi_step_tool = true; - if (messages.is_array()) { - for (int i = static_cast(messages.size()) - 1; i >= 0; --i) { - auto& m = messages[i]; - if (multi_step_tool && m.contains("role") && m["role"] == "user" && m.contains("content") && m["content"].is_string()) { - std::string c = m["content"]; - bool is_tr = - (c.starts_with("")) && (c.size() > 17 && c.compare(c.size() - 18, 18, "") == 0); - if (!is_tr) { - multi_step_tool = false; + size_t last_query_index = messages.empty() ? 
0 : messages.size() - 1; + bool found_last_query = false; + if (!messages.empty()) { + for (int i = messages.size() - 1; i >= 0; --i) { + const auto& msg = messages[i]; + if (msg.value("role", "") == "user" && msg.contains("content") && msg["content"].is_string()) { + std::string content_str = msg["content"].get<std::string>(); + if (!(content_str.starts_with("<tool_response>") + && content_str.find("</tool_response>") == content_str.length() - std::string("</tool_response>").length())) { last_query_index = i; + found_last_query = true; + break; } } } } + if (messages.empty()) { found_last_query = false; } - if (!messages.is_array()) return out.str(); - for (size_t idx = 0; idx < messages.size(); ++idx) { - auto& m = messages[idx]; - std::string role, content; - if (m.contains("role")) role = m["role"]; - if (m.contains("content") && m["content"].is_string()) content = m["content"]; - - if (role == "user" || (role == "system" && idx != 0)) { - out << "<|im_start|>" << role << "\n" << content << eos_token << "\n"; - continue; - } + for (size_t i = 0; i < messages.size(); ++i) { + const auto& message = messages[i]; + std::string role = message.value("role", ""); + std::string content; + if (message.contains("content") && message["content"].is_string()) { content = message["content"].get<std::string>(); } - if (role == "assistant") { + if (role == "user" || (role == "system" && i > 0)) { + oss << "<|im_start|>" << role << "\n" << content << "<|im_end|>\n"; + } else if (role == "assistant") { std::string reasoning_content; - if (m.contains("reasoning_content") && m["reasoning_content"].is_string()) { - reasoning_content = m["reasoning_content"]; + if (message.contains("reasoning_content") && message["reasoning_content"].is_string()) { + reasoning_content = message["reasoning_content"].get<std::string>(); } else { - size_t pos_end = content.find("</think>"); - if (pos_end != std::string::npos) { - size_t pos_beg = content.rfind("<think>", pos_end); - if (pos_beg != std::string::npos) { - reasoning_content = trim(content.substr(pos_beg + 7, pos_end - pos_beg - 7)); - content.erase(pos_beg, pos_end + 8); - content = trim(content); + auto think_end_pos = content.find("</think>"); + if (think_end_pos != std::string::npos) { + auto think_start_pos = content.rfind("<think>", think_end_pos); + if (think_start_pos != std::string::npos) { + reasoning_content = content.substr(think_start_pos + 7, think_end_pos - (think_start_pos + 7)); + content = content.substr(think_end_pos + 8); } } } - bool in_last_turn = (idx > last_query_index); - bool need_think = in_last_turn && (idx + 1 == messages.size() || !reasoning_content.empty()); - - out << "<|im_start|>assistant\n"; - if (need_think) { out << "<think>\n" << reasoning_content << "\n</think>\n\n"; } - out << content; - - if (m.contains("tool_calls") && m["tool_calls"].is_array()) { - for (auto& tc : m["tool_calls"]) { - json fn = tc.contains("function") ?
tc["function"] : tc; - std::string name = fn.value("name", ""); - std::string args = fn.value("arguments", ""); - if (fn["arguments"].is_object() || fn["arguments"].is_array()) args = fn["arguments"].dump(); - out << "\n\n" - << R"({"name": ")" << name << R"(", "arguments": )" << args << "}\n" - << ""; + oss << "<|im_start|>" << role << "\n"; + if (found_last_query && i > last_query_index) { + if ((i == messages.size() - 1) || !reasoning_content.empty()) { + oss << "\n" << trim(reasoning_content) << "\n\n\n" << ltrim(content); + } else { + oss << content; } + } else { + oss << content; } - out << eos_token << "\n"; - continue; - } - if (role == "tool") { - bool first_tool = (idx == 0) || !messages[idx - 1].contains("role") || messages[idx - 1]["role"] != "tool"; - bool last_tool = - (idx + 1 == messages.size()) || !messages[idx + 1].contains("role") || messages[idx + 1]["role"] != "tool"; - if (first_tool) out << "<|im_start|>user"; - out << "\n\n" << content << "\n"; - if (last_tool) out << eos_token << "\n"; + if (message.contains("tool_calls")) { + bool is_first_tool = true; + for (const auto& tool_call_item : message["tool_calls"]) { + if ((is_first_tool && !content.empty()) || !is_first_tool) { oss << "\n"; } + is_first_tool = false; + + const json* tool_call_ptr = &tool_call_item; + if (tool_call_item.contains("function")) { tool_call_ptr = &tool_call_item["function"]; } + const json& tool_call = *tool_call_ptr; + + oss << "\n{\"name\": \"" << tool_call.value("name", "") << R"(", "arguments": )"; + const auto& args = tool_call["arguments"]; + if (args.is_string()) { + oss << args.get(); + } else { + oss << args.dump(); + } + oss << "}\n"; + } + } + oss << "<|im_end|>\n"; + + } else if (role == "tool") { + if (i == 0 || messages[i - 1].value("role", "") != "tool") { oss << "<|im_start|>user"; } + oss << "\n\n" << content << "\n"; + if (i == messages.size() - 1 || messages[i + 1].value("role", "") != "tool") { oss << "<|im_end|>\n"; } } } if (add_generation_prompt) { - out << "<|im_start|>assistant\n"; - if (!enable_thinking) { out << "\n\n\n\n"; } + oss << "<|im_start|>assistant\n"; + if (!enable_thinking) { oss << "\n\n\n\n"; } } - return out.str(); + return oss.str(); } private: diff --git a/tools/mllm-llm-benchmark/README.md b/tools/mllm-llm-benchmark/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/tools/mllm-vlm-benchmark/README.md b/tools/mllm-vlm-benchmark/README.md new file mode 100644 index 000000000..e69de29bb From f1e00ed414cbdd8ada7e70863ccfb5352c4314a0 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 10 Oct 2025 15:52:13 +0800 Subject: [PATCH 8/8] ci(workflows): remove unused macos and x86 publish workflows Delete outdated GitHub Actions workflows for macOS and x86 builds that are no longer needed in the current CI/CD pipeline configuration. 
--- .github/workflows/pymllm-nightly-macos.yml | 0 .github/workflows/pymllm-publish-macos.yml | 0 .github/workflows/pymllm-publish-x86.yml | 0 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .github/workflows/pymllm-nightly-macos.yml delete mode 100644 .github/workflows/pymllm-publish-macos.yml delete mode 100644 .github/workflows/pymllm-publish-x86.yml diff --git a/.github/workflows/pymllm-nightly-macos.yml b/.github/workflows/pymllm-nightly-macos.yml deleted file mode 100644 index e69de29bb..000000000 diff --git a/.github/workflows/pymllm-publish-macos.yml b/.github/workflows/pymllm-publish-macos.yml deleted file mode 100644 index e69de29bb..000000000 diff --git a/.github/workflows/pymllm-publish-x86.yml b/.github/workflows/pymllm-publish-x86.yml deleted file mode 100644 index e69de29bb..000000000
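Editor's note: for SDK authors binding against the service API above, one streamed response chunk — as assembled in Service::workerLoop() and finalized in getResponse() — has roughly the following shape. This is an illustrative sketch using nlohmann::json; the id, timestamp, and token text are made-up example values, not output from a real run.

#include <nlohmann/json.hpp>

// "model", "created", and "choices" are filled by the worker loop;
// "id", "object", and "finished" are appended in getResponse().
nlohmann::json example_chunk = {
    {"id", "chat-multi"},
    {"object", "chat.completion.chunk"},
    {"model", "mllmTeam/Qwen3-0.6B-w4a32kai"},
    {"created", 1733000000},
    {"finished", false},
    {"choices", nlohmann::json::array({{{"index", 0},
                                        {"delta", {{"content", "Hello"}}},
                                        {"finish_reason", nullptr}}})}};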