diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4dd546641..f193f34c6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(qwen2_5vl_tracer) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(qwen3) +add_subdirectory(qwen3_service) diff --git a/examples/qwen3_service/CMakeLists.txt b/examples/qwen3_service/CMakeLists.txt index e69de29bb..31faa395e 100644 --- a/examples/qwen3_service/CMakeLists.txt +++ b/examples/qwen3_service/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-qwen3-service main.cpp) +target_link_libraries(mllm-qwen3-service PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-service PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_service/main.cpp b/examples/qwen3_service/main.cpp index e69de29bb..26239064d 100644 --- a/examples/qwen3_service/main.cpp +++ b/examples/qwen3_service/main.cpp @@ -0,0 +1,91 @@ +#include +#include + +#include +#include + +#include +#include +#include + +MLLM_MAIN({ + mllm::setLogLevel(mllm::LogLevel::kError); + auto& model_path = mllm::Argparse::add("-m|--model_path").help("Model path").required(true); + mllm::Argparse::parse(argc, argv); + + auto qwen3_session = std::make_shared(); + qwen3_session->fromPreTrain(model_path.get()); + mllm::service::insertSession("mllmTeam/Qwen3-0.6B-w4a32kai", qwen3_session); + mllm::service::startService(); + + std::vector history; + const std::string model_name = "mllmTeam/Qwen3-0.6B-w4a32kai"; + + std::cout << "Enter /exit or /quit to exit this program\n"; + + while (true) { + std::cout << "\nUser: "; + std::string user_input; + std::getline(std::cin, user_input); + if (user_input == "/exit" || user_input == "/quit") break; + + nlohmann::json user_msg; + user_msg["role"] = "user"; + user_msg["content"] = user_input; + history.push_back(user_msg); + + nlohmann::json req; + req["model"] = model_name; + req["messages"] = history; + req["id"] = "chat-multi"; + 
req["enable_thinking"] = true; + mllm::service::sendRequest(req.dump()); + std::string assistant_content; + + bool thinking_states = false; + + while (true) { + std::string resp = mllm::service::getResponse("chat-multi"); + auto j = nlohmann::json::parse(resp); + + if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("delta") + && j["choices"][0]["delta"].contains("content")) { + std::string delta = j["choices"][0]["delta"]["content"].get<std::string>(); + + if (delta == "<think>") { + thinking_states = true; + fmt::print(fmt::fg(fmt::color::gray) | fmt::emphasis::bold | fmt::emphasis::underline, "Thinking...:"); + continue; + } + if (delta == "</think>") { + thinking_states = false; + fmt::print("\n"); + continue; + } + + if (thinking_states) { + fmt::print(fmt::fg(fmt::color::gray), "{}", delta); + std::fflush(stdout); + } else { + fmt::print("{}", delta); + std::fflush(stdout); + } + + assistant_content += delta; + } + + if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("finish_reason") + && j["choices"][0]["finish_reason"].is_string() && j["choices"][0]["finish_reason"].get<std::string>() == "stop") { + break; + } + } + + nlohmann::json assistant_msg; + assistant_msg["role"] = "assistant"; + assistant_msg["content"] = assistant_content; + history.push_back(assistant_msg); + } + + mllm::service::stopService(); + return 0; +}) diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index b2794df34..c94871b90 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -17,6 +17,7 @@ #include "mllm/backends/cpu/ops/FillOp.hpp" #include "mllm/backends/cpu/ops/FlashAttention2Op.hpp" #include "mllm/backends/cpu/ops/GELUOp.hpp" +#include "mllm/backends/cpu/ops/RadixAttnOp.hpp" #include "mllm/backends/cpu/ops/ReLUOp.hpp" #include "mllm/backends/cpu/ops/GraphOps.hpp" #include "mllm/backends/cpu/ops/ISTFTOp.hpp" @@ -35,6 +36,7 @@ #include "mllm/backends/cpu/ops/RepeatOp.hpp" #include 
"mllm/backends/cpu/ops/RoPEOp.hpp" #include "mllm/backends/cpu/ops/STFTOp.hpp" +#include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" #include "mllm/backends/cpu/ops/SiLUOp.hpp" #include "mllm/backends/cpu/ops/SliceOp.hpp" #include "mllm/backends/cpu/ops/SoftmaxOp.hpp" @@ -58,7 +60,7 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, - CPUKVCacheOpFactory, CPUPagedAttnOpFactory>(); + CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory>(); } std::shared_ptr createCPUBackend() { return std::make_shared(); } diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp b/mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp deleted file mode 100644 index 6327b45d2..000000000 --- a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) MLLM Team. -// Licensed under the MIT License. 
- -#pragma once - -#include "mllm/core/DataTypes.hpp" -#include "mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp" - -#include - -namespace mllm::cpu::paged_attn_x::details { -template<> -struct VectorDotProduct<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { - static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs, - mllm_fp32_t* __out, size_t len) { - float32x4_t sum_vec0 = vdupq_n_f32(0.0f); - float32x4_t sum_vec1 = vdupq_n_f32(0.0f); - float32x4_t sum_vec2 = vdupq_n_f32(0.0f); - float32x4_t sum_vec3 = vdupq_n_f32(0.0f); - - size_t i = 0; - const size_t main_loop_bound = len - 15; - for (; i < main_loop_bound; i += 16) { - const float32x4_t lhs0 = vld1q_f32(__lhs + i); - const float32x4_t rhs0 = vld1q_f32(__rhs + i); - sum_vec0 = vfmaq_f32(sum_vec0, lhs0, rhs0); - - const float32x4_t lhs1 = vld1q_f32(__lhs + i + 4); - const float32x4_t rhs1 = vld1q_f32(__rhs + i + 4); - sum_vec1 = vfmaq_f32(sum_vec1, lhs1, rhs1); - - const float32x4_t lhs2 = vld1q_f32(__lhs + i + 8); - const float32x4_t rhs2 = vld1q_f32(__rhs + i + 8); - sum_vec2 = vfmaq_f32(sum_vec2, lhs2, rhs2); - - const float32x4_t lhs3 = vld1q_f32(__lhs + i + 12); - const float32x4_t rhs3 = vld1q_f32(__rhs + i + 12); - sum_vec3 = vfmaq_f32(sum_vec3, lhs3, rhs3); - } - - const size_t unroll4_bound = len - 3; - for (; i < unroll4_bound; i += 4) { - const float32x4_t lhs_vec = vld1q_f32(__lhs + i); - const float32x4_t rhs_vec = vld1q_f32(__rhs + i); - sum_vec0 = vfmaq_f32(sum_vec0, lhs_vec, rhs_vec); - } - - sum_vec0 = vaddq_f32(sum_vec0, sum_vec1); - sum_vec2 = vaddq_f32(sum_vec2, sum_vec3); - sum_vec0 = vaddq_f32(sum_vec0, sum_vec2); - - // Reduce - float result = vaddvq_f32(sum_vec0); - for (; i < len; ++i) { result += __lhs[i] * __rhs[i]; } - - *__out = result; - } -}; - -template<> -struct MulFromConst<__ArmArchTag, mllm_fp32_t, mllm_fp32_t> { - static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ __from, const mllm_fp32_t 
const_v, size_t len) { - size_t i = 0; - const size_t simd_width = 4; - if (len >= simd_width) { - float32x4_t const_vec = vdupq_n_f32(const_v); - size_t simd_len = (len / simd_width) * simd_width; - for (; i < simd_len; i += simd_width) { - float32x4_t data_vec = vld1q_f32(&__from[i]); - data_vec = vmulq_f32(data_vec, const_vec); - vst1q_f32(&__from[i], data_vec); - } - } - for (; i < len; ++i) { __from[i] *= const_v; } - } -}; - -template<> -struct FMAConstArray<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { - static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ acc_o, const mllm_fp32_t acc_s, - const mllm_fp32_t* __restrict__ v_token, size_t len) { - size_t i = 0; - const size_t simd_width = 4; - - if (len >= simd_width) { - float32x4_t acc_s_vec = vdupq_n_f32(acc_s); - - size_t simd_len = (len / simd_width) * simd_width; - for (; i < simd_len; i += simd_width) { - float32x4_t v_token_vec = vld1q_f32(&v_token[i]); - float32x4_t acc_o_vec = vld1q_f32(&acc_o[i]); - acc_o_vec = vfmaq_f32(acc_o_vec, acc_s_vec, v_token_vec); - vst1q_f32(&acc_o[i], acc_o_vec); - } - } - for (; i < len; ++i) { acc_o[i] += acc_s * v_token[i]; } - } -}; - -} // namespace mllm::cpu::paged_attn_x::details diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/README.md b/mllm/backends/cpu/kernels/common/radix_attn/README.md similarity index 100% rename from mllm/backends/cpu/kernels/common/paged_attn_x/README.md rename to mllm/backends/cpu/kernels/common/radix_attn/README.md diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp similarity index 81% rename from mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp rename to mllm/backends/cpu/kernels/common/radix_attn/arch.hpp index e88f4be54..dae6e8e76 100644 --- a/mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/arch.hpp @@ -5,7 +5,7 @@ #include #include "mllm/utils/Common.hpp" -namespace 
mllm::cpu::paged_attn_x::details { +namespace mllm::cpu::radix_attn::details { struct __AnyArchTag {}; using any_arch_tag = __AnyArchTag; @@ -32,4 +32,9 @@ struct FMAConstArray { static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {} }; -} // namespace mllm::cpu::paged_attn_x::details +template +struct FilledWithConst { + static MLLM_FORCE_INLINE void run(T* __restrict__ a, const T v, size_t len) {} +}; + +} // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp similarity index 53% rename from mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp rename to mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp index d34e7676d..1834618b0 100644 --- a/mllm/backends/cpu/kernels/common/paged_attn_x/fwd_bshd.hpp +++ b/mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp @@ -9,63 +9,62 @@ #include #include "mllm/core/Parallel.hpp" #include "mllm/utils/CPUArchHelper.hpp" -#include "mllm/engine/prefix_cache/TLB.hpp" -#include "mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp" #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) -#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp" #else -#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp" #endif -#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp" -namespace mllm::cpu::paged_attn_x { +namespace mllm::cpu::radix_attn { // BHSD // K: [S_KV], address, not contiguous // V: [S_KV], address, not contiguous -// Q: [B, H_Q, S_Q, D], contiguous +// Q: [B, S_Q, H_Q, D], contiguous +// +// After find 
KV Tokens, KV is [B, 1, H_KV, D] // // H_KV should <= H_Q template void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q, - const mllm::prefix_cache::vp_addr_t* __k, const mllm::prefix_cache::vp_addr_t* __v, __ODType* __restrict__ __out, - void* ctx, int32_t thread_count) { + __KDType** __k, __VDType** __v, __ODType* __restrict__ __out, int32_t thread_count) { int32_t head_repeat_times = H_Q / H_KV; __AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e; // Loop on batch size. for (int b_idx = 0; b_idx < B; ++b_idx) { - // Loop on head dim, should be made parallel + // FIXME: Loop on SEQUENCE may faster? + // seq_q [ head_q [ seq_kv ] ] + + // Loop on HEAD dim, should be made parallel MLLM_CONDITIONAL_PARALLEL_FOR(thread_count > 1, thread_count, h_q_idx, 0, H_Q, 1, { int h_kv_id = h_q_idx / head_repeat_times; // FA2's Loop for (int s_q_idx = 0; s_q_idx < S_Q; ++s_q_idx) { - __QDType* q_token = __q + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D; - __ODType* acc_o = __out + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D; + const __QDType* q_token = __q + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D; + __ODType* acc_o = __out + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D; - // FIXME: Boost with SIMD - for (int d_idx = 0; d_idx < D; ++d_idx) { acc_o[d_idx] = 0; } - __AccDType scores_max = std::numeric_limits<__AccDType>::lowest(); - __AccDType scores_max_prev = std::numeric_limits<__AccDType>::lowest(); + details::FilledWithConst<__ArchTag, __ODType>::run(acc_o, 0, D); + + __AccDType scores_max = -std::numeric_limits<__AccDType>::infinity(); + __AccDType scores_max_prev = -std::numeric_limits<__AccDType>::infinity(); __AccDType logsum = 0; - __AccDType scores_scale = 0; __AccDType scores_sum = 0; + __AccDType scores_sum = 0; + __AccDType scores_scale = 0; int __delta = S_KV - S_Q; int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV); for (int s_kv_idx = 0; 
s_kv_idx < S_KV_BOUND; ++s_kv_idx) { - // TODO, prefetch next - - // TODO using context. - // __KDType* k_token = (__KDType*)ctx->access(__k[s_kv_idx]); - // __VDType* v_token = (__VDType*)ctx->access(__v[s_kv_idx]); - __KDType* k_token = NULL; - __VDType* v_token = NULL; + // k_token and v_token shape is [B, 1, H, D] + __KDType* k_token = __k[s_kv_idx]; + __VDType* v_token = __v[s_kv_idx]; // Offset to one head. // k_token and v_token shape is [D] @@ -78,27 +77,24 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i // 2. Do softmax stuff. scores_max_prev = scores_max; - scores_max = std::numeric_limits<__AccDType>::lowest(); - scores_max = std::max(scores_max, acc_s); + scores_max = std::max(scores_max_prev, acc_s); scores_scale = std::exp2(scores_max_prev * scale - scores_max * scale); acc_s = std::exp2(acc_s * scale - scores_max * scale); - scores_sum += acc_s; // TODO This line may be error. + scores_sum = acc_s; logsum = logsum * scores_scale + scores_sum; // 3. Scale - MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, scores_scale, D); + details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D); // 4. MMA1. - FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>(acc_o, acc_s, v_token, D); - - // TODO, drop this mmap in the future. + details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(acc_o, acc_s, v_token, D); } // 5. Final Rescale. 
- MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, (1.f / logsum), D); + details::MulFromConst<__ArchTag, __ODType, __AccDType>::run(acc_o, (1.f / logsum), D); } }); } } -} // namespace mllm::cpu::paged_attn_x +} // namespace mllm::cpu::radix_attn diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp similarity index 100% rename from mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp rename to mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp diff --git a/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp new file mode 100644 index 000000000..a5905c3b8 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp @@ -0,0 +1,44 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/DataTypes.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp" + +#include + +namespace mllm::cpu::radix_attn::details { +template<> +struct VectorDotProduct<__AnyArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs, + mllm_fp32_t* __out, size_t len) { + mllm_fp32_t ret = 0; + for (size_t i = 0; i < len; ++i) { ret += __lhs[i] * __rhs[i]; } + *__out = ret; + } +}; + +template<> +struct MulFromConst<__AnyArchTag, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ __from, const mllm_fp32_t const_v, size_t len) { + for (int i = 0; i < len; ++i) { __from[i] *= const_v; } + } +}; + +template<> +struct FMAConstArray<__AnyArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ acc_o, const mllm_fp32_t acc_s, + const mllm_fp32_t* __restrict__ v_token, size_t len) { + for (int i = 0; i < len; ++i) { acc_o[i] += acc_s * v_token[i]; } + 
} +}; + +template<> +struct FilledWithConst<__AnyArchTag, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ a, const mllm_fp32_t v, size_t len) { + for (int i = 0; i < len; ++i) { a[i] = v; } + } +}; + +} // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp new file mode 100644 index 000000000..859688da4 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp @@ -0,0 +1,155 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/DataTypes.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp" + +#include + +namespace mllm::cpu::radix_attn::details { +template<> +struct VectorDotProduct<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(const mllm_fp32_t* __restrict__ __lhs, const mllm_fp32_t* __restrict__ __rhs, + mllm_fp32_t* __out, size_t len) { + float32x4_t sum_vec = vdupq_n_f32(0.0f); + + size_t i = 0; + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + float32x4_t lhs_vec0 = vld1q_f32(__lhs + i); + float32x4_t lhs_vec1 = vld1q_f32(__lhs + i + 4); + float32x4_t lhs_vec2 = vld1q_f32(__lhs + i + 8); + float32x4_t lhs_vec3 = vld1q_f32(__lhs + i + 12); + + float32x4_t rhs_vec0 = vld1q_f32(__rhs + i); + float32x4_t rhs_vec1 = vld1q_f32(__rhs + i + 4); + float32x4_t rhs_vec2 = vld1q_f32(__rhs + i + 8); + float32x4_t rhs_vec3 = vld1q_f32(__rhs + i + 12); + + sum_vec = vfmaq_f32(sum_vec, lhs_vec0, rhs_vec0); + sum_vec = vfmaq_f32(sum_vec, lhs_vec1, rhs_vec1); + sum_vec = vfmaq_f32(sum_vec, lhs_vec2, rhs_vec2); + sum_vec = vfmaq_f32(sum_vec, lhs_vec3, rhs_vec3); + } + + for (; i + 3 < len; i += 4) { + float32x4_t lhs_vec = vld1q_f32(__lhs + i); + float32x4_t rhs_vec = vld1q_f32(__rhs + i); + sum_vec = 
vfmaq_f32(sum_vec, lhs_vec, rhs_vec); + } + + float result = vaddvq_f32(sum_vec); + + for (; i < len; ++i) { result += __lhs[i] * __rhs[i]; } + + *__out = result; + } +}; + +template<> +struct MulFromConst<__ArmArchTag, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ __from, const mllm_fp32_t const_v, size_t len) { + float32x4_t const_vec = vdupq_n_f32(const_v); + + size_t i = 0; + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + float32x4_t vec0 = vld1q_f32(__from + i); + float32x4_t vec1 = vld1q_f32(__from + i + 4); + float32x4_t vec2 = vld1q_f32(__from + i + 8); + float32x4_t vec3 = vld1q_f32(__from + i + 12); + + // FIXME: FMA may be muster than MUL + vec0 = vmulq_f32(vec0, const_vec); + vec1 = vmulq_f32(vec1, const_vec); + vec2 = vmulq_f32(vec2, const_vec); + vec3 = vmulq_f32(vec3, const_vec); + + vst1q_f32(__from + i, vec0); + vst1q_f32(__from + i + 4, vec1); + vst1q_f32(__from + i + 8, vec2); + vst1q_f32(__from + i + 12, vec3); + } + + for (; i + 3 < len; i += 4) { + float32x4_t vec = vld1q_f32(__from + i); + vec = vmulq_f32(vec, const_vec); + vst1q_f32(__from + i, vec); + } + + for (; i < len; ++i) { __from[i] *= const_v; } + } +}; + +template<> +struct FMAConstArray<__ArmArchTag, mllm_fp32_t, mllm_fp32_t, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ acc_o, const mllm_fp32_t acc_s, + const mllm_fp32_t* __restrict__ v_token, size_t len) { + float32x4_t acc_vec = vdupq_n_f32(acc_s); + + size_t i = 0; + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + float32x4_t acc0 = vld1q_f32(acc_o + i); + float32x4_t token0 = vld1q_f32(v_token + i); + acc0 = vfmaq_f32(acc0, token0, acc_vec); + vst1q_f32(acc_o + i, acc0); + + float32x4_t acc1 = vld1q_f32(acc_o + i + 4); + float32x4_t token1 = vld1q_f32(v_token + i + 4); + acc1 = 
vfmaq_f32(acc1, token1, acc_vec); + vst1q_f32(acc_o + i + 4, acc1); + + float32x4_t acc2 = vld1q_f32(acc_o + i + 8); + float32x4_t token2 = vld1q_f32(v_token + i + 8); + acc2 = vfmaq_f32(acc2, token2, acc_vec); + vst1q_f32(acc_o + i + 8, acc2); + + float32x4_t acc3 = vld1q_f32(acc_o + i + 12); + float32x4_t token3 = vld1q_f32(v_token + i + 12); + acc3 = vfmaq_f32(acc3, token3, acc_vec); + vst1q_f32(acc_o + i + 12, acc3); + } + + for (; i + 3 < len; i += 4) { + float32x4_t acc = vld1q_f32(acc_o + i); + float32x4_t token = vld1q_f32(v_token + i); + acc = vfmaq_f32(acc, token, acc_vec); + vst1q_f32(acc_o + i, acc); + } + + for (; i < len; ++i) { acc_o[i] += acc_s * v_token[i]; } + } +}; + +template<> +struct FilledWithConst<__ArmArchTag, mllm_fp32_t> { + static MLLM_FORCE_INLINE void run(mllm_fp32_t* __restrict__ a, const mllm_fp32_t v, size_t len) { + float32x4_t const_vec = vdupq_n_f32(v); + + size_t i = 0; + const size_t block_size = 16; + const size_t len_aligned = len & ~(block_size - 1); + + for (; i < len_aligned; i += block_size) { + vst1q_f32(a + i, const_vec); + vst1q_f32(a + i + 4, const_vec); + vst1q_f32(a + i + 8, const_vec); + vst1q_f32(a + i + 12, const_vec); + } + + for (; i + 3 < len; i += 4) { vst1q_f32(a + i, const_vec); } + + for (; i < len; ++i) { a[i] = v; } + } +}; + +} // namespace mllm::cpu::radix_attn::details diff --git a/mllm/backends/cpu/ops/RadixAttnOp.cpp b/mllm/backends/cpu/ops/RadixAttnOp.cpp new file mode 100644 index 000000000..cca4ace3e --- /dev/null +++ b/mllm/backends/cpu/ops/RadixAttnOp.cpp @@ -0,0 +1,50 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/Tensor.hpp" +#include "mllm/backends/cpu/ops/RadixAttnOp.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp" +#include "mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPURadixAttnOp::CPURadixAttnOp(const aops::RadixAttnOpOptions& options) : aops::RadixAttnOp(options) {} + +void CPURadixAttnOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& Q = inputs[0]; + const auto& K_PTR = inputs[1]; + const auto& V_PTR = inputs[2]; + const auto& OUT = outputs[0]; + + MLLM_RT_ASSERT(K_PTR.dtype() == kInt64 && V_PTR.dtype() == kInt64 && K_PTR.rank() == 1 && V_PTR.rank() == 1); + auto B = Q.shape()[0]; + auto S_Q = Q.shape()[1]; + auto H_Q = Q.shape()[2]; + auto D = Q.shape()[3]; + MLLM_RT_ASSERT_EQ(H_Q, options_.H_Q); + auto S_KV = K_PTR.shape()[0]; + MLLM_RT_ASSERT_EQ(S_KV, V_PTR.shape()[0]); + + switch (Q.dtype()) { + case kFloat32: { +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + cpu::radix_attn::fwd_bhsd(B, options_.H_Q, options_.H_KV, S_Q, S_KV, D, Q.ptr(), + K_PTR.ptr(), V_PTR.ptr(), OUT.ptr(), + options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64) + cpu::radix_attn::fwd_bhsd(B, options_.H_Q, options_.H_KV, S_Q, S_KV, D, Q.ptr(), + K_PTR.ptr(), V_PTR.ptr(), OUT.ptr(), + options_.getThreads()); +#endif + break; + } + default: { + NYI("RadixAttnOp not supported for this data type"); + } + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/RadixAttnOp.hpp b/mllm/backends/cpu/ops/RadixAttnOp.hpp new file mode 100644 index 000000000..bc7f44130 --- /dev/null +++ b/mllm/backends/cpu/ops/RadixAttnOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/RadixAttnOp.hpp" + +namespace mllm::cpu { + +class CPURadixAttnOp final : public aops::RadixAttnOp { + public: + explicit CPURadixAttnOp(const aops::RadixAttnOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPURadixAttnOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::RadixAttnOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/RoPEOp.cpp b/mllm/backends/cpu/ops/RoPEOp.cpp index 57a2b1542..cb9f98131 100644 --- a/mllm/backends/cpu/ops/RoPEOp.cpp +++ b/mllm/backends/cpu/ops/RoPEOp.cpp @@ -4,6 +4,7 @@ #include #include "mllm/backends/cpu/ops/RoPEOp.hpp" +#include "mllm/core/aops/RoPEOp.hpp" #include "mllm/utils/CPUArchHelper.hpp" #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Include AVX, SSE. @@ -13,84 +14,173 @@ namespace mllm::cpu { -void RoPEOpImpl::forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos) { +void RoPEOpImpl::forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos, + aops::RoPEOpOptionsInputType input_layout_type) { auto activation = inputs[0]; auto out = outputs[0]; - // Activation must in BHSD layout + // Activation must in BHSD or BSHD layout MLLM_RT_ASSERT_EQ(activation.shape().size(), 4); - auto B = activation.shape()[0]; - auto H = activation.shape()[1]; - auto S = activation.shape()[2]; - auto D = activation.shape()[3]; + auto B = 0; + auto H = 0; + auto S = 0; + auto D = 0; + + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + B = activation.shape()[0]; + H = activation.shape()[1]; + S = activation.shape()[2]; + D = activation.shape()[3]; + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + B = activation.shape()[0]; + S = activation.shape()[1]; + H = activation.shape()[2]; + 
D = activation.shape()[3]; + break; + } + } int32_t half = D / 2; switch (activation.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp32_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp32_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - for (int d = 0; d < half; ++d) { - mllm_fp32_t in_val = act_ptr[d]; - mllm_fp32_t in_val2 = act_ptr[d + half]; - mllm_fp32_t sin_val = sin_ptr[d]; - mllm_fp32_t cos_val = cos_ptr[d]; - - out_ptr[d] = in_val * cos_val - in_val2 * sin_val; - out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + for (int d = 0; d < half; ++d) { + mllm_fp32_t in_val = act_ptr[d]; + mllm_fp32_t in_val2 = act_ptr[d + half]; + mllm_fp32_t sin_val = sin_ptr[d]; + mllm_fp32_t cos_val = cos_ptr[d]; + + out_ptr[d] = in_val * cos_val - in_val2 * sin_val; + out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + } + } + } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + for (int n = 0; n < B; ++n) { + for (int s = 0; s < S; ++s) { + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + for (int h = 0; h < H; ++h) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({n, s, h, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({n, s, h, 0}); + for (int d = 0; d < half; ++d) { + mllm_fp32_t 
in_val = act_ptr[d]; + mllm_fp32_t in_val2 = act_ptr[d + half]; + mllm_fp32_t sin_val = sin_ptr[d]; + mllm_fp32_t cos_val = cos_ptr[d]; + out_ptr[d] = in_val * cos_val - in_val2 * sin_val; + out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + } + } } } + break; } } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp32_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp32_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - // Vectorized processing (4 elements per iteration) - int d = 0; - constexpr int step = 4; - for (; d <= half - step; d += step) { - // Load activation blocks - float32x4_t act_front = vld1q_f32(act_ptr + d); - float32x4_t act_back = vld1q_f32(act_ptr + d + half); - - // Load sin/cos values - float32x4_t sin_vec = vld1q_f32(sin_ptr + d); - float32x4_t cos_vec = vld1q_f32(cos_ptr + d); - - // Compute rotated values - float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); - float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); - - // Store results - vst1q_f32(out_ptr + d, out_front); - vst1q_f32(out_ptr + d + half, out_back); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + // Vectorized processing (4 elements per iteration) + int d = 0; + constexpr int step = 4; + for (; d <= half - step; d += step) { + // Load activation blocks + float32x4_t act_front = 
vld1q_f32(act_ptr + d); + float32x4_t act_back = vld1q_f32(act_ptr + d + half); + + // Load sin/cos values + float32x4_t sin_vec = vld1q_f32(sin_ptr + d); + float32x4_t cos_vec = vld1q_f32(cos_ptr + d); + + // Compute rotated values + float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); + float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); + + // Store results + vst1q_f32(out_ptr + d, out_front); + vst1q_f32(out_ptr + d + half, out_back); + } + + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = act_ptr[d]; + mllm_fp32_t in_val2 = act_ptr[d + half]; + mllm_fp32_t sin_val = sin_ptr[d]; + mllm_fp32_t cos_val = cos_ptr[d]; + + out_ptr[d] = in_val * cos_val - in_val2 * sin_val; + out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + } + } } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + const int half = D / 2; + constexpr int step = 4; + + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + const mllm_fp32_t* sin_ptr = sin.offsettedPtr({b, s, 0}); + const mllm_fp32_t* cos_ptr = cos.offsettedPtr({b, s, 0}); + + for (int h = 0; h < H; ++h) { + mllm_fp32_t* act_ptr = activation.offsettedPtr({b, s, h, 0}); + mllm_fp32_t* out_ptr = out.offsettedPtr({b, s, h, 0}); - // Process remaining elements - for (; d < half; ++d) { - mllm_fp32_t in_val = act_ptr[d]; - mllm_fp32_t in_val2 = act_ptr[d + half]; - mllm_fp32_t sin_val = sin_ptr[d]; - mllm_fp32_t cos_val = cos_ptr[d]; + int d = 0; + for (; d <= half - step; d += step) { + float32x4_t act_front = vld1q_f32(act_ptr + d); // [d, d+1, d+2, d+3] + float32x4_t act_back = vld1q_f32(act_ptr + d + half); // [d+half, ...] 
+ float32x4_t sin_vec = vld1q_f32(sin_ptr + d); + float32x4_t cos_vec = vld1q_f32(cos_ptr + d); - out_ptr[d] = in_val * cos_val - in_val2 * sin_val; - out_ptr[d + half] = in_val * sin_val + in_val2 * cos_val; + float32x4_t out_front = vsubq_f32(vmulq_f32(act_front, cos_vec), vmulq_f32(act_back, sin_vec)); + float32x4_t out_back = vaddq_f32(vmulq_f32(act_front, sin_vec), vmulq_f32(act_back, cos_vec)); + + vst1q_f32(out_ptr + d, out_front); + vst1q_f32(out_ptr + d + half, out_back); + } + for (; d < half; ++d) { + mllm_fp32_t in0 = act_ptr[d]; + mllm_fp32_t in1 = act_ptr[d + half]; + mllm_fp32_t s = sin_ptr[d]; + mllm_fp32_t c = cos_ptr[d]; + + out_ptr[d] = in0 * c - in1 * s; + out_ptr[d + half] = in0 * s + in1 * c; + } + } } } + break; } } #endif @@ -98,67 +188,144 @@ void RoPEOpImpl::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - for (int d = 0; d < half; ++d) { - mllm_fp32_t in_val = static_cast(act_ptr[d]); - mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); - mllm_fp32_t sin_val = static_cast(sin_ptr[d]); - mllm_fp32_t cos_val = static_cast(cos_ptr[d]); - - out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); - out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); 
+ const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + for (int d = 0; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } + } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + for (int n = 0; n < B; ++n) { + for (int s = 0; s < S; ++s) { + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + for (int h = 0; h < H; ++h) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, s, h, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, s, h, 0}); + for (int d = 0; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } } + break; } } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - for (int n = 0; n < B; ++n) { - for (int h = 0; h < H; ++h) { - for (int s = 0; s < S; ++s) { - mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); - mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); - const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); - const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); - - // Vectorized processing (8 elements per iteration) - int d = 0; - constexpr int step = 8; - for (; d <= half - step; d += step) { - // Load activation blocks - float16x8_t act_front = vld1q_f16(act_ptr + d); - float16x8_t act_back = vld1q_f16(act_ptr + d + half); - - // Load sin/cos values - float16x8_t sin_vec = vld1q_f16(sin_ptr + 
d); - float16x8_t cos_vec = vld1q_f16(cos_ptr + d); - - // Compute rotated values - float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); - float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); - - // Store results - vst1q_f16(out_ptr + d, out_front); - vst1q_f16(out_ptr + d + half, out_back); + switch (input_layout_type) { + case aops::RoPEOpOptionsInputType::kBHSD: { + for (int n = 0; n < B; ++n) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, h, s, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, h, s, 0}); + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + // Vectorized processing (8 elements per iteration) + int d = 0; + constexpr int step = 8; + for (; d <= half - step; d += step) { + // Load activation blocks + float16x8_t act_front = vld1q_f16(act_ptr + d); + float16x8_t act_back = vld1q_f16(act_ptr + d + half); + + // Load sin/cos values + float16x8_t sin_vec = vld1q_f16(sin_ptr + d); + float16x8_t cos_vec = vld1q_f16(cos_ptr + d); + + // Compute rotated values + float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); + float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); + + // Store results + vst1q_f16(out_ptr + d, out_front); + vst1q_f16(out_ptr + d + half, out_back); + } + + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } + } + break; + } + case aops::RoPEOpOptionsInputType::kBSHD: { + for 
(int n = 0; n < B; ++n) { + for (int s = 0; s < S; ++s) { + const mllm_fp16_t* sin_ptr = sin.offsettedPtr({n, s, 0}); + const mllm_fp16_t* cos_ptr = cos.offsettedPtr({n, s, 0}); + + for (int h = 0; h < H; ++h) { + mllm_fp16_t* act_ptr = activation.offsettedPtr({n, s, h, 0}); + mllm_fp16_t* out_ptr = out.offsettedPtr({n, s, h, 0}); + // Vectorized processing (8 elements per iteration) + int d = 0; + constexpr int step = 8; + for (; d <= half - step; d += step) { + // Load activation blocks + float16x8_t act_front = vld1q_f16(act_ptr + d); + float16x8_t act_back = vld1q_f16(act_ptr + d + half); + + // Load sin/cos values + float16x8_t sin_vec = vld1q_f16(sin_ptr + d); + float16x8_t cos_vec = vld1q_f16(cos_ptr + d); + + // Compute rotated values + float16x8_t out_front = vsubq_f16(vmulq_f16(act_front, cos_vec), vmulq_f16(act_back, sin_vec)); + float16x8_t out_back = vaddq_f16(vmulq_f16(act_front, sin_vec), vmulq_f16(act_back, cos_vec)); + + // Store results + vst1q_f16(out_ptr + d, out_front); + vst1q_f16(out_ptr + d + half, out_back); + } - // Process remaining elements - for (; d < half; ++d) { - mllm_fp32_t in_val = static_cast(act_ptr[d]); - mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); - mllm_fp32_t sin_val = static_cast(sin_ptr[d]); - mllm_fp32_t cos_val = static_cast(cos_ptr[d]); + // Process remaining elements + for (; d < half; ++d) { + mllm_fp32_t in_val = static_cast(act_ptr[d]); + mllm_fp32_t in_val2 = static_cast(act_ptr[d + half]); + mllm_fp32_t sin_val = static_cast(sin_ptr[d]); + mllm_fp32_t cos_val = static_cast(cos_ptr[d]); - out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); - out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + out_ptr[d] = static_cast(in_val * cos_val - in_val2 * sin_val); + out_ptr[d + half] = static_cast(in_val * sin_val + in_val2 * cos_val); + } + } } } + break; } } #endif @@ -189,7 +356,7 @@ void CPURoPEOp::forward(const std::vector& inputs, std::vector& auto out = outputs[0]; auto impl 
= RoPEOpImpl(); - impl.forward(inputs, outputs, sin, cos); + impl.forward(inputs, outputs, sin, cos, options_.input_type); } } // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/RoPEOp.hpp b/mllm/backends/cpu/ops/RoPEOp.hpp index 5c7b10c78..efeb9ebda 100644 --- a/mllm/backends/cpu/ops/RoPEOp.hpp +++ b/mllm/backends/cpu/ops/RoPEOp.hpp @@ -9,7 +9,8 @@ namespace mllm::cpu { struct RoPEOpImpl { - void forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos); + void forward(const std::vector& inputs, std::vector& outputs, Tensor& sin, Tensor& cos, + aops::RoPEOpOptionsInputType input_layout_type); }; class CPURoPEOp final : public aops::RoPEOp { @@ -26,4 +27,4 @@ class CPURoPEOpFactory : public TypedOpFactory +#include "mllm/utils/Common.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" + +namespace mllm::cpu { + +CPUScatter2ShardsOp::CPUScatter2ShardsOp(const aops::Scatter2ShardsOpOptions& options) : aops::Scatter2ShardsOp(options) {} + +void CPUScatter2ShardsOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& src = inputs[0]; + const auto& dst_ptrs = inputs[1]; + + // Validation + MLLM_RT_ASSERT(dst_ptrs.dtype() == kInt64 && dst_ptrs.rank() == 1 && dst_ptrs.shape()[0] == src.shape()[options_.dim]); + + // [B, H, S, D] + // floating_shards_point_left = B * H + // floating_shards_point_right = D + int32_t floating_shards_point_left = 1; + int32_t floating_shards_point_right = 1; + for (int i = 0; i < options_.dim; ++i) { floating_shards_point_left *= src.shape()[i]; } + for (int i = options_.dim + 1; i < src.rank(); ++i) { floating_shards_point_right *= src.shape()[i]; } + int32_t floating_shards_point_stride = 1; + for (int i = options_.dim; i < src.rank(); ++i) { floating_shards_point_stride *= src.shape()[i]; } + if (options_.dim == src.rank() - 1) { floating_shards_point_stride = 1; } + if (options_.dim == 0) { floating_shards_point_left = src.stride()[0]; } + + 
int32_t loop_times = src.shape()[options_.dim]; + for (int lo = 0; lo < loop_times; ++lo) { + for (int ho = 0; ho < floating_shards_point_left; ++ho) { + auto src_ptr = src.ptr() + + (ho * floating_shards_point_stride + lo * floating_shards_point_right) + * (bytesOfType(src.dtype()) / lanesOfType(src.dtype())); + auto dst_ptr = dst_ptrs.ptr()[lo] + + ho * floating_shards_point_right * (bytesOfType(src.dtype()) / lanesOfType(src.dtype())); + memcpy(dst_ptr, src_ptr, floating_shards_point_right * (bytesOfType(src.dtype()) / lanesOfType(src.dtype()))); + } + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp b/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp new file mode 100644 index 000000000..5073cb63b --- /dev/null +++ b/mllm/backends/cpu/ops/Scatter2ShardsOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/Scatter2ShardsOp.hpp" + +namespace mllm::cpu { + +class CPUScatter2ShardsOp final : public aops::Scatter2ShardsOp { + public: + explicit CPUScatter2ShardsOp(const aops::Scatter2ShardsOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUScatter2ShardsOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::Scatter2ShardsOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 58c59dbf2..4adc215e5 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -73,6 +73,8 @@ enum class OpTypes : int32_t { // High-level Op or Fused Op kPagedAttn = 57, + kRadixAttn = 58, + kScatter2Shards = 59, // Dynamic Op Start for user to register there own ops. 
kDynamicOp_Start = 4096, @@ -140,6 +142,7 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kGraphBegin: return "GraphBegin"; case OpTypes::kGraphEnd: return "GraphEnd"; case OpTypes::kPagedAttn: return "PagedAttn"; + case OpTypes::kScatter2Shards: return "Scatter2Shards"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; } diff --git a/mllm/core/Tensor.cpp b/mllm/core/Tensor.cpp index cc92d2d0a..73f80bf23 100644 --- a/mllm/core/Tensor.cpp +++ b/mllm/core/Tensor.cpp @@ -109,7 +109,7 @@ Tensor Tensor::random(const std::vector& shape, float start, float end, return Context::instance().buildOpAndSubmitTask( OpTypes::kFill, aops::FillOpOptions{ - .type = aops::FillOpTypes::kRandom, .start = start, .end = end, .seed = Context::instance().getRandomSeed()}, + .type = aops::FillOpTypes::kRandom, .start = start, .end = end, .seed = Context::instance().getRandomState()}, {i})[0]; } diff --git a/mllm/core/Tensor.hpp b/mllm/core/Tensor.hpp index a7b0a25bb..0011b871f 100644 --- a/mllm/core/Tensor.hpp +++ b/mllm/core/Tensor.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -93,6 +94,47 @@ class Tensor { return tensor; } + template + static inline Tensor fromVector(const std::span& vec, const shape_t& shape, DataTypes dtype = kFloat32, + DeviceTypes device = kCPU) { + Tensor tensor = Tensor::empty(shape, dtype, device).alloc(); + size_t tensor_size = tensor.numel(); + if (vec.size() != tensor_size) { + MLLM_ERROR_EXIT(ExitCode::kShapeError, "Tensor size mismatch with std::vector size"); + return Tensor::nil(); + } + std::copy(vec.begin(), vec.end(), tensor.ptr()); + return tensor; + } + + /** + * @brief Create a tensor from a std::vector, but reference the vector data, not copy it. 
+ * + * @tparam T + * @param vec + * @param shape + * @param dtype + * @param device + * @return Tensor + */ + template + static inline Tensor refVectorData(const std::vector& vec, const shape_t& shape, DataTypes dtype = kFloat32, + DeviceTypes device = kCPU) { + size_t expected_size = 1; + for (auto dim : shape) { expected_size *= dim; } + + if (vec.size() != expected_size) { + MLLM_ERROR_EXIT(ExitCode::kShapeError, "Tensor shape mismatch with std::vector size"); + return Tensor::nil(); + } + + Tensor tensor = Tensor::empty(shape, dtype, device); + tensor.impl_->storage()->ptr_ = const_cast(vec.data()); + tensor.impl_->storage()->mem_type_ = kManual; + + return tensor; + } + template inline std::vector toVector() const { std::vector vec; @@ -522,7 +564,7 @@ class Tensor { * @return Typed base pointer. */ template - T* ptr() const { + [[nodiscard]] T* ptr() const { return impl_->ptr(); } diff --git a/mllm/core/aops/PagedAttnOp.hpp b/mllm/core/aops/PagedAttnOp.hpp index 43d97f5ff..675d6b41f 100644 --- a/mllm/core/aops/PagedAttnOp.hpp +++ b/mllm/core/aops/PagedAttnOp.hpp @@ -11,7 +11,6 @@ namespace mllm::aops { enum class PagedAttnImplType { kDefault = 0, kAllFp32 = 1, - kPrefixCache = 2, }; struct PagedAttnOpOptions : public BaseOpOptions { diff --git a/mllm/core/aops/RadixAttnOp.cpp b/mllm/core/aops/RadixAttnOp.cpp new file mode 100644 index 000000000..4d9f33378 --- /dev/null +++ b/mllm/core/aops/RadixAttnOp.cpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/RadixAttnOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::aops { + +RadixAttnOp::RadixAttnOp(const RadixAttnOpOptions& options) : BaseOp(OpTypes::kRadixAttn), options_(options) {} + +void RadixAttnOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void RadixAttnOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + MLLM_WARN("RadixAttnOp::trace is not supported."); +} + +void RadixAttnOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("RadixAttnOp::forward not implemented in aops base."); +} + +void RadixAttnOp::reshape(const std::vector& inputs, std::vector& outputs) { + auto& query = inputs[0]; + auto& key = inputs[1]; + auto& value = inputs[2]; + + outputs.emplace_back(Tensor::empty(query.shape(), query.dtype(), query.device())); +} + +void RadixAttnOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/RadixAttnOp.hpp b/mllm/core/aops/RadixAttnOp.hpp new file mode 100644 index 000000000..dbadffc59 --- /dev/null +++ b/mllm/core/aops/RadixAttnOp.hpp @@ -0,0 +1,34 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct RadixAttnOpOptions : public BaseOpOptions { + int32_t H_Q; + int32_t H_KV; +}; + +class RadixAttnOp : public BaseOp { + public: + explicit RadixAttnOp(const RadixAttnOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + protected: + RadixAttnOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/RoPEOp.hpp b/mllm/core/aops/RoPEOp.hpp index 901cec799..d8401eeb0 100644 --- a/mllm/core/aops/RoPEOp.hpp +++ b/mllm/core/aops/RoPEOp.hpp @@ -8,12 +8,15 @@ namespace mllm::aops { +enum class RoPEOpOptionsInputType : uint8_t { + kBHSD = 0, + kBSHD = 1, +}; + struct RoPEOpOptions : public BaseOpOptions { float rope_theta = 10000.0F; int32_t max_position_embeddings = 16384; - - RoPEOpOptions() = default; - explicit RoPEOpOptions(float theta, int32_t max_pos_embed) : rope_theta(theta), max_position_embeddings(max_pos_embed) {} + RoPEOpOptionsInputType input_type = RoPEOpOptionsInputType::kBHSD; }; class RoPEOp : public BaseOp { @@ -36,4 +39,4 @@ class RoPEOp : public BaseOp { RoPEOpOptions options_; }; -} // namespace mllm::aops \ No newline at end of file +} // namespace mllm::aops diff --git a/mllm/core/aops/Scatter2ShardsOp.cpp b/mllm/core/aops/Scatter2ShardsOp.cpp new file mode 100644 index 000000000..bfcc8783b --- /dev/null +++ b/mllm/core/aops/Scatter2ShardsOp.cpp @@ -0,0 +1,35 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/Scatter2ShardsOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +Scatter2ShardsOp::Scatter2ShardsOp(const Scatter2ShardsOpOptions& options) + : BaseOp(OpTypes::kScatter2Shards), options_(options) {} + +void Scatter2ShardsOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void Scatter2ShardsOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + MLLM_WARN("Scatter2ShardsOp::trace can't be traced in v2.0.0 right now. Pls send us a feature request issues on github if " + "you need this."); +} + +void Scatter2ShardsOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("Scatter2ShardsOp::forward not implemented in aops base."); +} + +void Scatter2ShardsOp::reshape(const std::vector& inputs, std::vector& outputs) { + // scatter op has no output. It only has inputs. + MLLM_EMPTY_SCOPE; +} + +void Scatter2ShardsOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/Scatter2ShardsOp.hpp b/mllm/core/aops/Scatter2ShardsOp.hpp new file mode 100644 index 000000000..a45c0642f --- /dev/null +++ b/mllm/core/aops/Scatter2ShardsOp.hpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct Scatter2ShardsOpOptions : public BaseOpOptions { + int dim = 0; +}; + +class Scatter2ShardsOp : public BaseOp { + public: + explicit Scatter2ShardsOp(const Scatter2ShardsOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + protected: + Scatter2ShardsOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/engine/Context.cpp b/mllm/engine/Context.cpp index 17363564c..5c9f0ffb6 100644 --- a/mllm/engine/Context.cpp +++ b/mllm/engine/Context.cpp @@ -94,10 +94,20 @@ SessionTCB::ptr_t Context::thisThread() { SessionTCB::ptr_t Context::mainThread() { return main_thread_; } -void Context::setRandomSeed(uint64_t seed) { random_seed_ = seed; } +void Context::setRandomSeed(uint64_t seed) { + random_seed_ = seed; + random_state_ = seed; +} uint64_t Context::getRandomSeed() { return random_seed_; } +uint64_t Context::getRandomState() { + auto ret = random_state_; + std::mt19937 gen(random_state_); + random_state_ = gen(); + return ret; +} + uint64_t Context::curTime() { auto now = std::chrono::high_resolution_clock::now(); auto duration = now.time_since_epoch(); @@ -106,20 +116,12 @@ uint64_t Context::curTime() { std::unordered_map Context::refSessionThreads() { return session_threads_; } -void Context::setPrintPrecision(int precision) { - print_precision_ = precision; -} +void Context::setPrintPrecision(int precision) { print_precision_ = precision; } -int Context::getPrintPrecision() const { - return print_precision_; -} +int Context::getPrintPrecision() const { return print_precision_; } 
-void Context::setPrintMaxElementsPerDim(int max_elements) { - print_max_elements_per_dim_ = max_elements; -} +void Context::setPrintMaxElementsPerDim(int max_elements) { print_max_elements_per_dim_ = max_elements; } -int Context::getPrintMaxElementsPerDim() const { - return print_max_elements_per_dim_; -} +int Context::getPrintMaxElementsPerDim() const { return print_max_elements_per_dim_; } -} // namespace mllm \ No newline at end of file +} // namespace mllm diff --git a/mllm/engine/Context.hpp b/mllm/engine/Context.hpp index 93fb1bea1..fe0b0bc84 100644 --- a/mllm/engine/Context.hpp +++ b/mllm/engine/Context.hpp @@ -46,6 +46,8 @@ class Context { uint64_t getRandomSeed(); + uint64_t getRandomState(); + uint64_t curTime(); std::unordered_map refSessionThreads(); @@ -64,6 +66,7 @@ class Context { Context(); uint64_t random_seed_ = 42; + uint64_t random_state_ = 42; SessionTCB::ptr_t main_thread_; std::unordered_map session_threads_; diff --git a/mllm/engine/prefix_cache/RadixTree.cpp b/mllm/engine/prefix_cache/RadixTree.cpp index 770170282..d9049c521 100644 --- a/mllm/engine/prefix_cache/RadixTree.cpp +++ b/mllm/engine/prefix_cache/RadixTree.cpp @@ -157,8 +157,8 @@ RadixSearchResult RadixTree::search(const RadixTreeNodeKey& key) { result.success = false; result.path = path; result.matched_length = 0; - result.k_cache_addresses = {}; - result.v_cache_addresses = {}; + result.k_cache_addresses.resize(options_.transformer_blocks_num); + result.v_cache_addresses.resize(options_.transformer_blocks_num); } return result; diff --git a/mllm/engine/prefix_cache/ZenFS.cpp b/mllm/engine/prefix_cache/ZenFS.cpp index 8a5409129..604be3f8f 100644 --- a/mllm/engine/prefix_cache/ZenFS.cpp +++ b/mllm/engine/prefix_cache/ZenFS.cpp @@ -672,13 +672,13 @@ void ZenFileSystem::_createBlobOnDisk() { void ZenFileSystem::_createBlobOnAnonymousFile() { // Calculate blob size. 
size_t total_bits = (1 << (options_.page_bits + options_.lane_bits)); - size_t uint64_count = (total_bits + sizeof(uint64_t) - 1) / sizeof(uint64_t); + size_t uint64_count = (total_bits + 64 - 1) / 64; // Blob size is not K and V. // Is 1 << (options_.page_bits + options_.lane_bits) * elements * sizeof(dtype). // K and V shared one blob. size_t blob_size = (1 << (options_.page_bits + options_.lane_bits)) * options_.per_k_token_ele * bytesOfType(options_.k_dtype) - / lanesOfType(options_.v_dtype); + / lanesOfType(options_.k_dtype); std::error_code ec; auto mmap_file = ZenFSBlobMMAPFile::create(blob_size, ZenFSMMAPMode::kAnonymous, "", ec); diff --git a/mllm/engine/service/Service.cpp b/mllm/engine/service/Service.cpp index 4c9635b6f..5400bcdb2 100644 --- a/mllm/engine/service/Service.cpp +++ b/mllm/engine/service/Service.cpp @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include +#include +#include + #include "mllm/engine/service/Session.hpp" #include "mllm/engine/service/Service.hpp" @@ -18,10 +21,12 @@ void RequestPool::push(RequestItem item) { cv_.notify_one(); } -RequestItem RequestPool::pop() { +std::optional RequestPool::pop() { std::unique_lock lk(mtx_); cv_.wait(lk, [this] { return !queue_.empty() || stop_; }); - if (stop_) { throw std::runtime_error("RequestPool stopped"); } + + if (stop_ && queue_.empty()) { return std::nullopt; } + auto item = std::move(queue_.front()); queue_.pop(); return item; @@ -92,35 +97,66 @@ void Service::start(size_t worker_threads) { } void Service::stop() { - running_ = false; req_pool_.shutdown(); + running_ = false; + + for (auto& t : workers_) { + if (t.joinable()) { t.join(); } + } + resp_pool_.shutdown(); - for (auto& t : workers_) t.join(); } RequestPool& Service::requestPool() { return req_pool_; } + ResponsePool& Service::responsePool() { return resp_pool_; } + SessionPool& Service::sessionPool() { return sess_pool_; } void Service::workerLoop() { while (running_) { try { - RequestItem req = req_pool_.pop(); - auto 
session = sess_pool_.get(req.payload.value("model", "")); - - session->streamGenerate(req.payload, [this, req_id = req.id](const std::string& token, bool finished) { - ResponseItem item; - item.id = req_id; - item.finished = finished; - item.raw = token; - resp_pool_.push(req_id, std::move(item)); - }); + if (auto req_opt = req_pool_.pop(); req_opt) { + RequestItem& req = *req_opt; + auto session = sess_pool_.get(req.payload.value("model", "")); + + session->streamGenerate( + req.payload, [this, req_payload = req.payload, req_id = req.id](const std::string& token, bool finished) { + ResponseItem item; + item.id = req_id; + item.finished = finished; + item.raw = token; + + std::time_t now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + + // Make payload + item.payload = { + {"model", req_payload["model"]}, + {"created", now}, + {"choices", nlohmann::json::array({{{"index", 0}, + {"delta", {{"content", token}}}, + {"finish_reason", finished ? "stop" : nlohmann::json(nullptr)}}})}}; + + resp_pool_.push(req_id, std::move(item)); + }); + } else { + // pop return empty, which means service is stopped and queue is empty. + break; + } } catch (...) 
{ // TODO } } } +void startService(size_t worker_threads) { Service::instance().start(worker_threads); } + +void stopService() { Service::instance().stop(); } + +void insertSession(const std::string& session_id, const std::shared_ptr& session) { + Service::instance().sessionPool().registerSession(session_id, session); +} + int sendRequest(const std::string& json_str) { if (json_str.empty()) return -1; try { @@ -152,8 +188,9 @@ Response getResponse(const std::string& id) { j["finished"] = false; return j.dump(); } - nlohmann::json j; - j["data"] = opt->raw; + nlohmann::json j = opt->payload; + j["id"] = opt->id; + j["object"] = "chat.completion.chunk"; j["finished"] = opt->finished; std::string s = j.dump(); return s; diff --git a/mllm/engine/service/Service.hpp b/mllm/engine/service/Service.hpp index 5b87968e2..2ae6e8a12 100644 --- a/mllm/engine/service/Service.hpp +++ b/mllm/engine/service/Service.hpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -37,7 +38,7 @@ class RequestPool { public: void push(RequestItem item); - RequestItem pop(); + std::optional pop(); void shutdown(); @@ -106,8 +107,14 @@ class Service { std::vector workers_; }; -// Golang, Python SDK should bind those two function and provide high level Network API. +// Golang, Python SDK should bind those 5 function and provide high level Network API. // We just focus on the server side. 
+void startService(size_t worker_threads = 1); + +void stopService(); + +void insertSession(const std::string& session_id, const std::shared_ptr& session); + int sendRequest(const std::string& json_str); Response getResponse(const std::string& id); diff --git a/mllm/engine/service/Session.cpp b/mllm/engine/service/Session.cpp index 2089eb5a0..9fb649e1e 100644 --- a/mllm/engine/service/Session.cpp +++ b/mllm/engine/service/Session.cpp @@ -5,13 +5,14 @@ #include #include +#include "mllm/utils/Common.hpp" #include "mllm/engine/service/Session.hpp" namespace mllm::service { -Session::Session(std::shared_ptr model) : model_(std::move(model)) {} +void Session::fromPreTrain(const std::string& model_path) { MLLM_EMPTY_SCOPE; } -NoneSession::NoneSession() : Session(nullptr) {} +NoneSession::NoneSession() : Session() {} void NoneSession::streamGenerate(const nlohmann::json& request, const std::function& callback) { diff --git a/mllm/engine/service/Session.hpp b/mllm/engine/service/Session.hpp index 339d6cb60..f8e41084e 100644 --- a/mllm/engine/service/Session.hpp +++ b/mllm/engine/service/Session.hpp @@ -4,7 +4,6 @@ #include #include -#include "mllm/models/ARGeneration.hpp" namespace mllm::service { @@ -12,13 +11,12 @@ class Session { public: using ptr_t = std::shared_ptr; - explicit Session(std::shared_ptr model); + Session() = default; virtual void streamGenerate(const nlohmann::json& request, const std::function& callback) = 0; - private: - std::shared_ptr model_; + virtual void fromPreTrain(const std::string& model_path); }; class NoneSession final : public Session { diff --git a/mllm/mllm.cpp b/mllm/mllm.cpp index b93f17cbf..b6851c425 100644 --- a/mllm/mllm.cpp +++ b/mllm/mllm.cpp @@ -23,6 +23,8 @@ void setLogLevel(const LogLevel& level) { ::mllm::Logger::level() = level; } void setRandomSeed(uint64_t seed) { Context::instance().setRandomSeed(seed); } +int64_t getRandomState() { return Context::instance().getRandomState(); } + void setMaximumNumThreads(uint32_t 
num_threads) { // TODO } diff --git a/mllm/mllm.hpp b/mllm/mllm.hpp index c8026955a..d1268b7c4 100644 --- a/mllm/mllm.hpp +++ b/mllm/mllm.hpp @@ -156,6 +156,8 @@ void setLogLevel(const LogLevel& level); void setRandomSeed(uint64_t seed); +int64_t getRandomState(); + void setMaximumNumThreads(uint32_t num_threads); void setPrintPrecision(int precision); diff --git a/mllm/models/qwen3/configuration_qwen3.hpp b/mllm/models/qwen3/configuration_qwen3.hpp index 63c7d0cdb..7df5137ca 100644 --- a/mllm/models/qwen3/configuration_qwen3.hpp +++ b/mllm/models/qwen3/configuration_qwen3.hpp @@ -51,6 +51,8 @@ struct Qwen3Config : protected ConfigFile { bool tie_word_embeddings = true; int32_t max_cache_length = 2048; int32_t end_of_text_token_id = 151645; + int32_t thinking_start_token_id = 151667; + int32_t thinking_end_token_id = 151668; aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; }; diff --git a/mllm/models/qwen3/modeling_qwen3_service.hpp b/mllm/models/qwen3/modeling_qwen3_service.hpp index 09b8cf8d2..372095ca4 100644 --- a/mllm/models/qwen3/modeling_qwen3_service.hpp +++ b/mllm/models/qwen3/modeling_qwen3_service.hpp @@ -1,15 +1,26 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include #include "mllm/mllm.hpp" #include "mllm/nn/Nn.hpp" #include "mllm/nn/Module.hpp" #include "mllm/nn/Functional.hpp" -#include "mllm/nn/lmcache/StaticCache.hpp" -#include "mllm/models/qwen3/configuration_qwen3.hpp" #include "mllm/utils/Enumerate.hpp" #include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +// Service related. 
#include "mllm/engine/service/Session.hpp" +#include "mllm/engine/prefix_cache/Cache.hpp" namespace mllm::models::qwen3 { @@ -103,8 +114,7 @@ class Qwen3Attention final : public nn::Module { nn::RMSNorm rms_norm_k_; nn::RoPE q_rope_; nn::RoPE k_rope_; - nn::CausalMask mask_; - nn::Softmax softmax_; + nn::RadixAttn attn_; int hidden_size_; int head_dim_; @@ -122,30 +132,29 @@ class Qwen3Attention final : public nn::Module { head_dim_ = cfg.head_dim; num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; - q_proj_ = - reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); - k_proj_ = - reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); - v_proj_ = - reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); - o_proj_ = - reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + // clang-format on rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps); rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps); - q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings); - k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings); + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); - mask_ = reg("mask"); - softmax_ 
= reg("softmax", -1); + attn_ = reg("attn", num_attention_heads_, num_key_value_heads_); } std::vector forward(const std::vector& inputs, const std::vector& args) override { auto x = inputs[0]; auto llm_embedding_sin = inputs[1]; auto llm_embedding_cos = inputs[2]; - auto past_kv_cache = args[0].get(); + auto k_cache_addr = args[0].get>*>(); + auto v_cache_addr = args[1].get>*>(); + auto prefix_cache_context = args[2].get(); // [B, S, H * D] auto query_states = q_proj_(x); @@ -164,39 +173,60 @@ class Qwen3Attention final : public nn::Module { query_states = rms_norm_q_(query_states); key_states = rms_norm_k_(key_states); - // [B, H, S, D] - query_states = query_states.transpose(1, 2); - key_states = key_states.transpose(1, 2); - value_states = value_states.transpose(1, 2); - - // [B, H, S, D] + // Different from original [B, H, S, D] rope. + // [B, S, H, D] query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); - // [B, H, S, D] - auto [key_states_new, value_states_new] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); - key_states = key_states_new; - value_states = value_states_new; - - Tensor attn; - if (key_states.dtype() == kFloat32) { - // attention weight - // [B, H, S, S] - attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); - attn = mask_(attn); - attn = softmax_(attn); - } else if (key_states.dtype() == kFloat16) { - attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); - attn = mask_(attn); - attn = softmax_(attn); - attn = attn.to(kFloat16); + // FIXME: Think, if rope before cache is ok? 
+ + // Acquire cache + std::vector k_addr_wait_for_promote; + std::vector v_addr_wait_for_promote; + for (int s_idx = 0; s_idx < S; ++s_idx) { + k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); + v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); } - // attn output - // [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D] - auto output = nn::functional::matmul(attn, value_states); - // [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D] - output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + // Prepare indicies cache. sizeof(char*) == 8 == sizeof(int64_t) + std::vector k_phy_addr_wait_for_promote; + std::vector v_phy_addr_wait_for_promote; + for (int s_idx = 0; s_idx < S; ++s_idx) { + k_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(k_addr_wait_for_promote[s_idx])); + v_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(v_addr_wait_for_promote[s_idx])); + } + auto k_wait_for_promote = Tensor::refVectorData(k_phy_addr_wait_for_promote, {S}, kInt64, kCPU); + auto v_wait_for_promote = Tensor::refVectorData(v_phy_addr_wait_for_promote, {S}, kInt64, kCPU); + + // Copy key_states and value_states to cache + nn::functional::scatter2Shards(key_states, k_wait_for_promote, 1); + nn::functional::scatter2Shards(value_states, v_wait_for_promote, 1); + + // Gather all cache to indicies tensor + { + auto& dst = (*k_cache_addr)[layer_idx_]; + dst.insert(dst.end(), k_addr_wait_for_promote.begin(), k_addr_wait_for_promote.end()); + } + { + auto& dst = (*v_cache_addr)[layer_idx_]; + dst.insert(dst.end(), v_addr_wait_for_promote.begin(), v_addr_wait_for_promote.end()); + } + std::vector k_phy_cache_indicies; + std::vector v_phy_cache_indicies; + int32_t kv_cache_len = (*k_cache_addr)[layer_idx_].size(); + k_phy_cache_indicies.reserve(kv_cache_len); + v_phy_cache_indicies.reserve(kv_cache_len); + for (int i = 0; i < kv_cache_len; ++i) { + 
k_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*k_cache_addr)[layer_idx_][i])); + v_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*v_cache_addr)[layer_idx_][i])); + } + auto k_cache = Tensor::refVectorData(k_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); + auto v_cache = Tensor::refVectorData(v_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); + + // Do Radix Attention + // output is [B, S, H, D] + auto output = attn_(query_states, k_cache, v_cache); + output = output.view({B, S, num_attention_heads_ * head_dim_}); output = o_proj_(output); return {output}; @@ -224,10 +254,12 @@ class Qwen3Decoder final : public nn::Module { std::vector forward(const std::vector& inputs, const std::vector& args) override { auto llm_embedding_sin = inputs[1]; auto llm_embedding_cos = inputs[2]; - auto& kv_cache = args[0]; + auto& k_cache_addr = args[0]; + auto& v_cache_addr = args[1]; + auto& prefix_cache_context = args[2]; auto x = input_layer_norm_(inputs[0]); - x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, k_cache_addr, v_cache_addr, prefix_cache_context)[0]; auto tmp = x + inputs[0]; x = post_attention_layer_norm_(tmp); x = mlp_(x)[0]; @@ -259,9 +291,13 @@ class Qwen3Text final : public nn::Module { auto llm_embedding_sin = inputs[1]; auto llm_embedding_cos = inputs[2]; - auto& kv_cache = args[0]; + auto& k_cache_addr = args[0]; + auto& v_cache_addr = args[1]; + auto& prefix_cache_context = args[2]; - for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + for (auto& block : blocks) { + x = block(x, llm_embedding_sin, llm_embedding_cos, k_cache_addr, v_cache_addr, prefix_cache_context)[0]; + } x = norm_(x); @@ -272,21 +308,13 @@ class Qwen3Text final : public nn::Module { class Qwen3ForCausalLM : public ARGeneration, public nn::Module { public: explicit Qwen3ForCausalLM(const Qwen3Config& cfg) : cfg(cfg) { - 
kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, - cfg.num_attention_heads, // q_heads - cfg.num_key_value_heads, // kv_heads - cfg.head_dim, // kv_dim - kFloat32, // k_dtype - kFloat32, // v_dtype - kCPU, // device_type - false // use_fa2 - ); eos_token_id_ = cfg.end_of_text_token_id; max_length_ = cfg.max_cache_length; tie_word_embeddings_ = cfg.tie_word_embeddings; llm = reg("model", cfg); + // Qwen3 0.6B's lm_head is tied with embed_tokens. But ModelScope's official weights separate them. if (cfg.tie_word_embeddings) { // NOTE: // model.lm_head.weight is quantization weights of model.embed_tokens.weight @@ -317,6 +345,9 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module { *position_ids.offsettedPtr({0, 0}) = last_pos + 1; } } else { + // NOTE: Service Session should not go into this branch !!! + MLLM_RT_ASSERT(false); + // Generate position_ids for prefill phase position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); auto position_ids_ptr = position_ids.ptr(); @@ -328,7 +359,8 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module { // Generate RoPE embeddings using the inv_freq buffer auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); - sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, args.at("k_cache_addrs"), args.at("v_cache_addrs"), + args.at("prefix_cache_context"))[0]; // clip x to one seq length { @@ -343,17 +375,298 @@ class Qwen3ForCausalLM : public ARGeneration, public nn::Module { }; } - private: const Qwen3Config& cfg; + + private: Qwen3Text llm; nn::Linear lm_head_; bool tie_word_embeddings_; - nn::StaticCache kv_cache_; }; -// class Qwen3Session final : public ::mllm::service::Session { -// public: -// private: -// }; +class Qwen3Session final : public ::mllm::service::Session { + public: + Qwen3Session() = 
default; + + std::size_t findThinkStartToken(const std::vector& output_ids) { + auto it = std::find(output_ids.begin(), output_ids.end(), model_->cfg.thinking_start_token_id); + return std::distance(output_ids.begin(), it); + } + + void streamGenerate(const nlohmann::json& request, + const std::function& callback) override { + const auto& messages = request["messages"]; + auto inputs = applyChatTemplate(messages, {}, true, request.value("enable_thinking", false)); + + auto full_seq_idx = tokenizer_->convert2Ids(tokenizer_->tokenize(inputs)).toVector(); + ARGenerationArgs args; + ARGenerationOutputPast input; + + // Search in the radix cache. Find the tokens we really need to compute. + auto prefix_cache_result = cache_->find(full_seq_idx); + std::span reduced_seq_idx(full_seq_idx.data() + prefix_cache_result.matched_length, + full_seq_idx.size() - prefix_cache_result.matched_length); + std::vector position_ids; + { + auto start = prefix_cache_result.matched_length; + auto end = full_seq_idx.size(); + position_ids.reserve(end - start); + std::ranges::copy(std::views::iota(static_cast(start), static_cast(end)), + std::back_inserter(position_ids)); + } + MLLM_RT_ASSERT_EQ(reduced_seq_idx.size(), position_ids.size()); + input["sequence"] = Tensor::fromVector(reduced_seq_idx, {1, (int32_t)reduced_seq_idx.size()}, kInt64, kCPU); + input["position_ids"] = Tensor::fromVector(position_ids, {1, (int32_t)position_ids.size()}, kInt64, kCPU); + + // Setup session context + k_cache_addrs_ = prefix_cache_result.k_cache_addresses; + v_cache_addrs_ = prefix_cache_result.v_cache_addresses; + args["k_cache_addrs"] = &k_cache_addrs_; + args["v_cache_addrs"] = &v_cache_addrs_; + args["prefix_cache_context"] = cache_.get(); + + // Has temperature, top_k, top_p, max_length, do_sample. 
+ args["temperature"] = request.value("temperature", 1.0f); + args["top_k"] = request.value("top_k", 0); + args["top_p"] = request.value("top_p", 0.0f); + args["max_length"] = request.value("max_length", 1024); + args["do_sample"] = request.value("do_sample", false); + + // Iteration start + int64_t package_cnt = 0; + model_->streamGenerate(input, args, [this, &request, &full_seq_idx, &package_cnt, &callback](int64_t idx) { + bool finished = false; + std::string ret_token; + if (idx == model_->cfg.eos_token_id) { + finished = true; + ret_token = ""; + } else { + finished = false; + ret_token = preprocessor::wideString2Utf8String(tokenizer_->detokenize(idx)); + + // Update full_seq_idx to include the new token for Radix Tree to use. + full_seq_idx.push_back(idx); + } + + // Callback will send this json to the response pool for user to consume. + callback(ret_token, finished); + + package_cnt++; + }); + // Callback a finish token + callback("", true); + + // Post-process full_seq_idx and k_cache_addrs_/v_cache_addrs_. Only the non-thinking content should be inserted into the radix tree. + // + // NOTE: We will drop everything after the thinking_start_token_idx (inclusive). + // Suppose: Only one thinking segment in the sequence. + // + // e.g.: + // <|im_start|>user + // hello<|im_end|> + // <|im_start|>assistant + // + // + // + // hello! + // <|endoftext|> + // + // In the radix tree, we will only save: + // <|im_start|>user + // hello<|im_end|> + // <|im_start|>assistant + // + // Explanation: That is because Qwen3 and other CoT models remove the thinking content, which means the answer "hello"'s RoPE position is + // changed in the 2nd turn. + auto thinking_end_token_idx = findThinkStartToken(full_seq_idx); + full_seq_idx.resize(thinking_end_token_idx); + for (auto& k_vec : k_cache_addrs_) k_vec.resize(thinking_end_token_idx); + for (auto& v_vec : v_cache_addrs_) v_vec.resize(thinking_end_token_idx); + + // Insert generated tokens to the cache. 
+ cache_->promote(full_seq_idx, k_cache_addrs_, v_cache_addrs_); + + // Cleanup session Context + k_cache_addrs_ = {}; + v_cache_addrs_ = {}; + } + + void fromPreTrain(const std::string& model_path) override { + namespace fs = std::filesystem; + fs::path root = fs::path(model_path).lexically_normal(); + fs::path config_file = root / "config.json"; + fs::path model_file = root / "model.mllm"; + fs::path tokenizer_file = root / "tokenizer.json"; + if (!fs::exists(config_file)) throw std::runtime_error(config_file.string() + " not found"); + if (!fs::exists(model_file)) throw std::runtime_error(model_file.string() + " not found"); + if (!fs::exists(tokenizer_file)) throw std::runtime_error(tokenizer_file.string() + " not found"); + + auto cfg = Qwen3Config(config_file.string()); + model_ = std::make_shared(cfg); + model_->load(load(model_file.string(), ModelFileVersion::kV2)); + tokenizer_ = std::make_shared(tokenizer_file.string()); + + cache_ = std::make_shared(prefix_cache::CacheOptions{ + .radix_tree_options = {.enable_lru_eviction = false, + .eviction_threshold = 0.9f, + .enable_path_compression = false, + .min_compression_length = 2, + .transformer_blocks_num = cfg.num_hidden_layers}, + .allocator_options = {// Normal things. + .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), + .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + + // CUDA things. + .enable_cuda = false, + .cuda_mem_base = 0x100000, + + // CPU things. 
+ .enable_cpu_hierarchy_memory = true, + .zen_fs_options = { + .record = false, + .working_dir = ".", + .blob_bits_size = 20, + .page_bits = 7, + .lane_bits = 5, + .per_k_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), + .per_v_token_ele = static_cast(cfg.head_dim * cfg.num_key_value_heads), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous, + }}}); + } + + std::string ltrim(const std::string& s) { + size_t start = s.find_first_not_of(" \n\r\t\f\v"); + return (start == std::string::npos) ? "" : s.substr(start); + } + + std::string rtrim(const std::string& s) { + size_t end = s.find_last_not_of(" \n\r\t\f\v"); + return (end == std::string::npos) ? "" : s.substr(0, end + 1); + } + + std::string trim(const std::string& s) { return rtrim(ltrim(s)); } + + std::string applyChatTemplate(const json& messages, const std::vector& tools = {}, bool add_generation_prompt = true, + bool enable_thinking = true, const std::string& bos_token = "", + const std::string& eos_token = "<|im_end|>") { + std::ostringstream oss; + + if (!tools.empty()) { + oss << "<|im_start|>system\n"; + if (!messages.empty() && messages[0].value("role", "") == "system") { oss << messages[0].value("content", "") << "\n\n"; } + oss << "# Tools\n\nYou may call one or more functions to assist with the user query.\n"; + oss << "You are provided with function signatures within XML tags:\n"; + for (const auto& tool : tools) { oss << "\n" << tool.dump(); } + oss << "\n\n\nFor each function call, return a json object with function name and arguments within " + " XML tags:\n\n{\"name\": , \"arguments\": " + "}\n<|im_end|>\n"; + } else { + if (!messages.empty() && messages[0].value("role", "") == "system") { + oss << "<|im_start|>system\n" << messages[0].value("content", "") << "<|im_end|>\n"; + } + } + + size_t last_query_index = messages.empty() ? 
0 : messages.size() - 1; + bool found_last_query = false; + if (!messages.empty()) { + for (int i = messages.size() - 1; i >= 0; --i) { + const auto& msg = messages[i]; + if (msg.value("role", "") == "user" && msg.contains("content") && msg["content"].is_string()) { + std::string content_str = msg["content"].get(); + if (!(content_str.starts_with("") + && content_str.find("") == content_str.length() - std::string("").length())) { + last_query_index = i; + found_last_query = true; + break; + } + } + } + } + if (messages.empty()) { found_last_query = false; } + + for (size_t i = 0; i < messages.size(); ++i) { + const auto& message = messages[i]; + std::string role = message.value("role", ""); + std::string content; + if (message.contains("content") && message["content"].is_string()) { content = message["content"].get(); } + + if (role == "user" || (role == "system" && i > 0)) { + oss << "<|im_start|>" << role << "\n" << content << "<|im_end|>\n"; + } else if (role == "assistant") { + std::string reasoning_content; + if (message.contains("reasoning_content") && message["reasoning_content"].is_string()) { + reasoning_content = message["reasoning_content"].get(); + } else { + auto think_end_pos = content.find(""); + if (think_end_pos != std::string::npos) { + auto think_start_pos = content.rfind("", think_end_pos); + if (think_start_pos != std::string::npos) { + reasoning_content = content.substr(think_start_pos + 7, think_end_pos - (think_start_pos + 7)); + content = content.substr(think_end_pos + 8); + } + } + } + + oss << "<|im_start|>" << role << "\n"; + if (found_last_query && i > last_query_index) { + if ((i == messages.size() - 1) || !reasoning_content.empty()) { + oss << "\n" << trim(reasoning_content) << "\n\n\n" << ltrim(content); + } else { + oss << content; + } + } else { + oss << content; + } + + if (message.contains("tool_calls")) { + bool is_first_tool = true; + for (const auto& tool_call_item : message["tool_calls"]) { + if ((is_first_tool && 
!content.empty()) || !is_first_tool) { oss << "\n"; } + is_first_tool = false; + + const json* tool_call_ptr = &tool_call_item; + if (tool_call_item.contains("function")) { tool_call_ptr = &tool_call_item["function"]; } + const json& tool_call = *tool_call_ptr; + + oss << "\n{\"name\": \"" << tool_call.value("name", "") << R"(", "arguments": )"; + const auto& args = tool_call["arguments"]; + if (args.is_string()) { + oss << args.get(); + } else { + oss << args.dump(); + } + oss << "}\n"; + } + } + oss << "<|im_end|>\n"; + + } else if (role == "tool") { + if (i == 0 || messages[i - 1].value("role", "") != "tool") { oss << "<|im_start|>user"; } + oss << "\n\n" << content << "\n"; + if (i == messages.size() - 1 || messages[i + 1].value("role", "") != "tool") { oss << "<|im_end|>\n"; } + } + } + + if (add_generation_prompt) { + oss << "<|im_start|>assistant\n"; + if (!enable_thinking) { oss << "\n\n\n\n"; } + } + + return oss.str(); + } + + private: + // States + std::vector> k_cache_addrs_; + std::vector> v_cache_addrs_; + + // Owned data + std::shared_ptr model_; + std::shared_ptr tokenizer_; + std::shared_ptr cache_; +}; } // namespace mllm::models::qwen3 diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp index e4697d7ee..c6803af9c 100644 --- a/mllm/nn/Functional.cpp +++ b/mllm/nn/Functional.cpp @@ -7,6 +7,7 @@ #include "mllm/core/aops/FlashAttention2Op.hpp" #include "mllm/core/aops/MatMulOp.hpp" #include "mllm/core/aops/ReduceOps.hpp" +#include "mllm/core/aops/Scatter2ShardsOp.hpp" #include "mllm/core/aops/SoftmaxOp.hpp" #include "mllm/core/aops/ElewiseOps.hpp" #include "mllm/core/aops/SplitOp.hpp" @@ -112,4 +113,9 @@ Tensor silu_(const Tensor& x) { return Context::instance().buildOpAndSubmitTask(OpTypes::kSiLU, opt, {x})[0]; } +void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim) { + Context::instance().buildOpAndSubmitTask(OpTypes::kScatter2Shards, aops::Scatter2ShardsOpOptions{.dim = dim}, + {src, shards_pointer}); +} + } 
// namespace mllm::nn::functional diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp index bd88ab836..9efad0ab2 100644 --- a/mllm/nn/Functional.hpp +++ b/mllm/nn/Functional.hpp @@ -127,4 +127,6 @@ Tensor mean(const Tensor& x, int32_t dim = std::numeric_limits::max(), Tensor silu(const Tensor& x); Tensor silu_(const Tensor& x); +void scatter2Shards(const Tensor& src, const Tensor& shards_pointer, int32_t dim); + } // namespace mllm::nn::functional diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index fa9074355..7bab5ebaa 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -26,3 +26,4 @@ #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export +#include "mllm/nn/layers/RadixAttn.hpp" // IWYU pragma: export diff --git a/mllm/nn/layers/PagedAttn.cpp b/mllm/nn/layers/PagedAttn.cpp index 84484588c..5aa664135 100644 --- a/mllm/nn/layers/PagedAttn.cpp +++ b/mllm/nn/layers/PagedAttn.cpp @@ -1,7 +1,6 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. 
#include "mllm/core/aops/PagedAttnOp.hpp" -#include "mllm/utils/Common.hpp" #include "mllm/nn/layers/PagedAttn.hpp" namespace mllm::nn { @@ -18,14 +17,4 @@ PagedAttn::PagedAttn(int32_t head_repeat_times, bool high_precision_exp, bool fu .need_attn_weights = need_attn_weights, .impl_type = impl_type}) {} -PagedAttn::PagedAttn(void* ctx, bool high_precision_exp, bool fuse_rope, aops::PagedAttnImplType impl_type) - : Layer(OpTypes::kPagedAttn, aops::PagedAttnOpOptions{.head_repeat_times = -1, - .high_precision_exp = high_precision_exp, - .fuse_rope = fuse_rope, - .need_attn_weights = false, - .impl_type = impl_type, - .prefix_cache_ctx = ctx}) { - if (ctx == nullptr) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "prefix_cache_ctx is empty."); } -} - } // namespace mllm::nn diff --git a/mllm/nn/layers/PagedAttn.hpp b/mllm/nn/layers/PagedAttn.hpp index 4aab7396d..b4b1be4a8 100644 --- a/mllm/nn/layers/PagedAttn.hpp +++ b/mllm/nn/layers/PagedAttn.hpp @@ -17,10 +17,6 @@ class PagedAttn : public Layer { explicit PagedAttn(int32_t head_repeat_times, bool high_precision_exp = false, bool fuse_rope = false, bool need_attn_weights = false, aops::PagedAttnImplType impl_type = aops::PagedAttnImplType::kAllFp32); - // Prefixed Cache Paged Attn Constructor - explicit PagedAttn(void* ctx, bool high_precision_exp = false, bool fuse_rope = false, - aops::PagedAttnImplType impl_type = aops::PagedAttnImplType::kPrefixCache); - MLLM_LAYER_ANY_INPUTS_2_OUTPUTS_FORWARD }; diff --git a/mllm/nn/layers/RadixAttn.cpp b/mllm/nn/layers/RadixAttn.cpp new file mode 100644 index 000000000..cf0c891b8 --- /dev/null +++ b/mllm/nn/layers/RadixAttn.cpp @@ -0,0 +1,16 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/RadixAttnOp.hpp" +#include "mllm/nn/layers/RadixAttn.hpp" + +namespace mllm::nn { + +RadixAttn::RadixAttn() : Layer(OpTypes::kRadixAttn, aops::RadixAttnOpOptions{}) {} + +RadixAttn::RadixAttn(const aops::RadixAttnOpOptions& options) : Layer(OpTypes::kRadixAttn, options) {} + +RadixAttn::RadixAttn(int32_t H_Q, int32_t H_KV) + : Layer(OpTypes::kRadixAttn, aops::RadixAttnOpOptions{.H_Q = H_Q, .H_KV = H_KV}) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/RadixAttn.hpp b/mllm/nn/layers/RadixAttn.hpp new file mode 100644 index 000000000..a641ccb7d --- /dev/null +++ b/mllm/nn/layers/RadixAttn.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/RadixAttnOp.hpp" + +namespace mllm::nn { + +class RadixAttn : public Layer { + public: + RadixAttn(); + + explicit RadixAttn(const aops::RadixAttnOpOptions& options); + + RadixAttn(int32_t H_Q, int32_t H_KV); + + // Q, K, V in and one output out + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/RoPE.cpp b/mllm/nn/layers/RoPE.cpp index 8659aa017..3b44c529b 100644 --- a/mllm/nn/layers/RoPE.cpp +++ b/mllm/nn/layers/RoPE.cpp @@ -8,7 +8,9 @@ namespace mllm::nn { RoPE::RoPE() : Layer(OpTypes::kRoPE, aops::RoPEOpOptions{}) {} -RoPE::RoPE(float theta, int32_t max_position_embeddings) - : Layer(OpTypes::kRoPE, aops::RoPEOpOptions(theta, max_position_embeddings)) {} +RoPE::RoPE(float theta, int32_t max_position_embeddings, aops::RoPEOpOptionsInputType input_type) + : Layer(OpTypes::kRoPE, + aops::RoPEOpOptions{ + .rope_theta = theta, .max_position_embeddings = max_position_embeddings, .input_type = input_type}) {} -} // namespace mllm::nn \ No newline at end of file +} // namespace mllm::nn diff --git a/mllm/nn/layers/RoPE.hpp b/mllm/nn/layers/RoPE.hpp index abc38c680..2ca512a41 100644 --- a/mllm/nn/layers/RoPE.hpp +++ b/mllm/nn/layers/RoPE.hpp @@ 
-11,7 +11,9 @@ namespace mllm::nn { class RoPE : public Layer { public: RoPE(); - explicit RoPE(float theta, int32_t max_position_embeddings); + + RoPE(float theta, int32_t max_position_embeddings, + aops::RoPEOpOptionsInputType input_type = aops::RoPEOpOptionsInputType::kBHSD); MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD @@ -20,4 +22,4 @@ class RoPE : public Layer { int32_t max_position_embeddings_ = 128; }; -} // namespace mllm::nn \ No newline at end of file +} // namespace mllm::nn diff --git a/mllm/nn/lmcache/PrefixCache.hpp b/mllm/nn/lmcache/PrefixCache.hpp index b32f7234d..051654626 100644 --- a/mllm/nn/lmcache/PrefixCache.hpp +++ b/mllm/nn/lmcache/PrefixCache.hpp @@ -3,4 +3,4 @@ #pragma once -#include "mllm/engine/prefix_cache/Cache.hpp" +#include "mllm/engine/prefix_cache/Cache.hpp" // IWYU pragma: export diff --git a/pymllm/README.md b/pymllm/README.md index 5240a5fa4..e69de29bb 100644 --- a/pymllm/README.md +++ b/pymllm/README.md @@ -1,23 +0,0 @@ -This file is used to request a PyPI organization. Should not be packaged in production. - -**PyPI Staff** - -Hello ubios! Upon reviewing this organization application request, we were unable to determine your affiliation with the provided URL (https://ubiquitouslearning.github.io/mllm/). Please let us know if there is some other way for us to verify this, otherwise we will decline this request. - -**chenghuaWang** - -Hello, -Thank you for your response. - -I am the maintainer of the project hosted at https://ubiquitouslearning.github.io/mllm/ and a contributor/administrator of this GitHub repository. I am currently developing the next major version (v2) of mllm; you can see my commits at https://github.com/UbiquitousLearning/mllm/commits/v2/ under the username chenghuaWang. The “mllm” organization is tied to this project, and I would like to create it on PyPI to publish and manage the related packages. - -For verification, please find below: - -1. My GitHub profile: https://github.com/chenghuaWang -2. 
Repository link: https://github.com/ubiquitouslearning/mllm -3. I have added a temporary note in the repository to confirm this request: https://github.com/UbiquitousLearning/mllm/blob/v2/pymllm/README.md (our conversation is quoted there). - -Let me know if any further proof is needed. - -Best regards, -chenghua.Wang diff --git a/tests/cpu/CMakeLists.txt b/tests/cpu/CMakeLists.txt index 940544be4..478f867d9 100644 --- a/tests/cpu/CMakeLists.txt +++ b/tests/cpu/CMakeLists.txt @@ -1,4 +1,3 @@ - add_executable(Mllm-Test-CPUKernel KernelTest.cpp) target_link_libraries(Mllm-Test-CPUKernel PRIVATE gtest_main MllmRT MllmCPUBackend) target_include_directories(Mllm-Test-CPUKernel PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 06ddf432c..9a2d19803 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -253,42 +253,6 @@ TEST_F(ElementwiseKernelTest, DivFloat16) { true); } -TEST_F(ElementwiseKernelTest, DivInt8) { - EXPECT_EQ(DivInt8Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - -TEST_F(ElementwiseKernelTest, DivInt16) { - EXPECT_EQ(DivInt16Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - -TEST_F(ElementwiseKernelTest, DivInt32) { - EXPECT_EQ(DivInt32Test({ - {42}, - {5, 5}, - {16, 16}, - {16, 18}, - {32, 32}, - {128, 128, 128}, - }), - true); -} - //===----------------------------------------------------------------------===// // Element wise ADD Scalar. 
// @@ -789,9 +753,16 @@ TEST_F(ReduceKernelTest, SumFloat32) { // } #endif +//===----------------------------------------------------------------------===// +// Scatter 2 Shards Attn +//===----------------------------------------------------------------------===// +#include "Scatter2ShardsKernelTest.hpp" +TEST_F(Scatter2ShardsKernelTest, one) { EXPECT_EQ(testScatter2Shards(), true); } + //===----------------------------------------------------------------------===// // Paged Attn //===----------------------------------------------------------------------===// +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) #include "PagedAttnTest.hpp" TEST_F(PagedAttnTest, fwd_bshd) { EXPECT_EQ(manyCases({ @@ -803,6 +774,38 @@ TEST_F(PagedAttnTest, fwd_bshd) { }), true); } +#endif + +//===----------------------------------------------------------------------===// +// Radix Attn +//===----------------------------------------------------------------------===// +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#include "RadixAttnKernel.hpp" +TEST_F(RadixAttnKernelTest, fwd_bshd) { + EXPECT_EQ(testRadixAttn({{ + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 10}, + {"S_KV", 10}, + {"D", 128}, + }, + { + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 10}, + {"S_KV", 20}, + {"D", 128}, + }, + { + {"H_Q", 28}, + {"H_KV", 2}, + {"S_Q", 1}, + {"S_KV", 20}, + {"D", 128}, + }}), + true); +} +#endif int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/tests/cpu/RadixAttnKernel.hpp b/tests/cpu/RadixAttnKernel.hpp new file mode 100644 index 000000000..d547d43ff --- /dev/null +++ b/tests/cpu/RadixAttnKernel.hpp @@ -0,0 +1,186 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/PrefixCache.hpp" + +#include "KernelTestHelper.hpp" + +using namespace mllm; // NOLINT + +class RadixAttnModule : public nn::Module { + nn::RadixAttn attn_; + + public: + RadixAttnModule() = default; + + RadixAttnModule(int H_Q, int H_KV) : nn::Module() { attn_ = reg("attn", H_Q, H_KV); } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // inputs is Q, K_indices, V_indices + return {attn_(inputs[0], inputs[1], inputs[2])}; + } +}; + +class EagerModule : public nn::Module { + public: + EagerModule() : nn::Module() {} + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + // inputs is Q, K_indices, V_indices + // Q, K, V is [B, S, H, D] + auto Q = inputs[0]; + auto K = inputs[1]; + auto V = inputs[2]; + + auto h_q = Q.shape()[2]; + auto h_kv = K.shape()[2]; + auto head_dim = Q.shape()[3]; + + // Q, K, V is [B, H, S, D] + Q = Q.transpose(1, 2); + K = K.transpose(1, 2).repeat(h_q / h_kv, 1); + V = V.transpose(1, 2).repeat(h_q / h_kv, 1); + + // Attention Weight + // [B, H, S, S] + auto attn = nn::functional::matmul(Q, K, false, true) * (1.f / sqrtf(head_dim)); + + // Make mask + auto S_Q = Q.shape()[2]; + auto S_KV = K.shape()[2]; + auto mask = Tensor::zeros({1, 1, S_Q, S_KV}); + { + auto ptr = mask.ptr(); + int __delta = S_KV - S_Q; + for (int s_q_idx = 0; s_q_idx < S_Q; s_q_idx++) { + int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV); + for (int s_kv_idx = S_KV_BOUND; s_kv_idx < S_KV; s_kv_idx++) { + ptr[s_q_idx * S_KV + s_kv_idx] = -std::numeric_limits::infinity(); + } + } + } + + attn = nn::functional::softmax(attn + mask, -1); + // [B, H, S, D] + auto output = nn::functional::matmul(attn, V); + // [B, S, H, D] + output = output.transpose(1, 2); + + return {output}; + } +}; + +class RadixAttnKernelTest : public KernelTest { + public: + 
RadixAttnKernelTest() = default; + ~RadixAttnKernelTest() override = default; + + bool testRadixAttnOnce(const std::unordered_map& cfg) { + int B = 1; + int H_Q = cfg.at("H_Q"); + int H_KV = cfg.at("H_KV"); + int S_Q = cfg.at("S_Q"); + int S_KV = cfg.at("S_KV"); + int D = cfg.at("D"); + + mllm::prefix_cache::CacheOptions opt{ + .radix_tree_options = {.enable_lru_eviction = false, + .eviction_threshold = 0.9f, + .enable_path_compression = false, + .min_compression_length = 2, + .transformer_blocks_num = 1}, + .allocator_options = {// Normal things. + .per_k_token_ele = static_cast(H_KV * D), + .per_v_token_ele = static_cast(H_KV * D), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + + // CUDA things. + .enable_cuda = false, + .cuda_mem_base = 0x100000, + + // CPU things. + .enable_cpu_hierarchy_memory = true, + .zen_fs_options = { + .record = false, + .working_dir = ".", + .blob_bits_size = 20, + .page_bits = 7, + .lane_bits = 5, + .per_k_token_ele = static_cast(H_KV * D), + .per_v_token_ele = static_cast(H_KV * D), + .k_dtype = mllm::kFloat32, + .v_dtype = mllm::kFloat32, + .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous, + }}}; + EagerModule eager_attn; + RadixAttnModule radix_attn(H_Q, H_KV); + prefix_cache::Cache cache(opt); + + // Create Q, K, V + auto Q = Tensor::random({B, S_Q, H_Q, D}, -10.f, 10.f); + auto K = Tensor::random({B, S_KV, H_KV, D}, -10.f, 10.f); + auto V = Tensor::random({B, S_KV, H_KV, D}, -10.f, 10.f); + + // Insert K and V into Cache and Radix Tree + std::vector k_cache_addrs; + std::vector v_cache_addrs; + for (int i = 0; i < S_KV; i++) { + k_cache_addrs.push_back(cache.alloc(kCPU)); + v_cache_addrs.push_back(cache.alloc(kCPU)); + } + std::vector k_cache_ptrs; + std::vector v_cache_ptrs; + for (int i = 0; i < S_KV; i++) { + k_cache_ptrs.push_back(cache.physicalAddr(k_cache_addrs[i])); + v_cache_ptrs.push_back(cache.physicalAddr(v_cache_addrs[i])); + } + auto k_cache_indices = Tensor::refVectorData(k_cache_ptrs, 
{S_KV}, kInt64); + auto v_cache_indices = Tensor::refVectorData(v_cache_ptrs, {S_KV}, kInt64); + + nn::functional::scatter2Shards(K, k_cache_indices, 1); + nn::functional::scatter2Shards(V, v_cache_indices, 1); + + // loop check if shards is correct + for (int i = 0; i < S_KV; ++i) { + auto prd_ptr = (float*)k_cache_ptrs[i]; + auto gt_ptr = K.offsettedPtr({0, i, 0, 0}); + for (int j = 0; j < H_KV * D; ++j) { + if (prd_ptr[j] != gt_ptr[j]) { + print("Error at: ", i, j); + print("prd: ", prd_ptr[j], " gt: ", gt_ptr[j]); + return false; + } + } + } + + // Compute eager + Tensor gt = eager_attn(Q, K, V)[0]; + Tensor predict = radix_attn(Q, k_cache_indices, v_cache_indices)[0]; + + // Compare + // rtol and atol set to 1e-2f is because: + // 1. The eager softmax is approximate, but radix is not. + auto result = test::allClose(gt, predict, 1e-2f, 1e-2f); + if (!result) { + print(result); + print("S_Q and S_KV is", S_Q, S_KV); + print(predict); + return false; + } + return true; + } + + bool testRadixAttn(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testRadixAttnOnce(cfg)) { return false; } + } + return true; + } +}; diff --git a/tests/cpu/Scatter2ShardsKernelTest.hpp b/tests/cpu/Scatter2ShardsKernelTest.hpp new file mode 100644 index 000000000..d575dd529 --- /dev/null +++ b/tests/cpu/Scatter2ShardsKernelTest.hpp @@ -0,0 +1,28 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/mllm.hpp" + +#include "KernelTestHelper.hpp" + +class Scatter2ShardsKernelTest : public KernelTest { + public: + Scatter2ShardsKernelTest() = default; + ~Scatter2ShardsKernelTest() override = default; + + bool testScatter2Shards() { + // B, S, H, D + auto Q = mllm::Tensor::random({1, 100, 8, 64}, -10, 10); + auto P = mllm::Tensor::zeros({1, 100, 8, 64}); + + std::vector indices; + indices.reserve(100); + for (int i = 0; i < 100; ++i) { indices.push_back((char*)P.offsettedPtr({0, i, 0, 0})); } + + auto tensor_indices = mllm::Tensor::refVectorData(indices, {100}, mllm::kInt64); + + mllm::nn::functional::scatter2Shards(Q, tensor_indices, 1); + + return mllm::test::allClose(Q, P).is_close; + } +}; diff --git a/tests/engine/CMakeLists.txt b/tests/engine/CMakeLists.txt index e0d673314..9d9b48161 100644 --- a/tests/engine/CMakeLists.txt +++ b/tests/engine/CMakeLists.txt @@ -17,3 +17,7 @@ target_include_directories(Mllm-Test-Engine-PrefixCache PRIVATE ${MLLM_INCLUDE_D add_executable(Mllm-Test-Engine-Service ServiceTest.cpp) target_link_libraries(Mllm-Test-Engine-Service PRIVATE MllmRT MllmCPUBackend) target_include_directories(Mllm-Test-Engine-Service PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(Mllm-Test-Engine-RandomStates RandomStatesTest.cpp) +target_link_libraries(Mllm-Test-Engine-RandomStates PRIVATE MllmRT MllmCPUBackend) +target_include_directories(Mllm-Test-Engine-RandomStates PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/tests/engine/RandomStatesTest.cpp b/tests/engine/RandomStatesTest.cpp new file mode 100644 index 000000000..446432eed --- /dev/null +++ b/tests/engine/RandomStatesTest.cpp @@ -0,0 +1,12 @@ +#include + +#include "mllm/mllm.hpp" + +using namespace mllm; // NOLINT + +int main() { + mllm::initializeContext(); + mllm::setRandomSeed(42); + for (int i = 0; i < 10; i++) { print(mllm::getRandomState()); } + mllm::memoryReport(); +} diff --git a/mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp b/tools/mllm-llm-benchmark/README.md 
similarity index 100% rename from mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp rename to tools/mllm-llm-benchmark/README.md diff --git a/tools/mllm-vlm-benchmark/README.md b/tools/mllm-vlm-benchmark/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/tox.ini b/tox.ini new file mode 100644 index 000000000..e69de29bb