1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(qwen2_5vl_tracer)
add_subdirectory(llama)
add_subdirectory(minicpm_o)
add_subdirectory(qwen3)
add_subdirectory(qwen3_service)
3 changes: 3 additions & 0 deletions examples/qwen3_service/CMakeLists.txt
@@ -0,0 +1,3 @@
add_executable(mllm-qwen3-service main.cpp)
target_link_libraries(mllm-qwen3-service PRIVATE MllmRT MllmCPUBackend)
target_include_directories(mllm-qwen3-service PRIVATE ${MLLM_INCLUDE_DIR})
91 changes: 91 additions & 0 deletions examples/qwen3_service/main.cpp
@@ -0,0 +1,91 @@
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

#include <fmt/color.h>
#include <fmt/core.h>
#include <nlohmann/json.hpp>

#include <mllm/mllm.hpp>
#include <mllm/engine/service/Service.hpp>
#include <mllm/models/qwen3/modeling_qwen3_service.hpp>

MLLM_MAIN({
mllm::setLogLevel(mllm::LogLevel::kError);
auto& model_path = mllm::Argparse::add<std::string>("-m|--model_path").help("Model path").required(true);
mllm::Argparse::parse(argc, argv);

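// Load the pretrained weights into a Qwen3 session and register it with the
// service layer under the model name that incoming requests will reference.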
auto qwen3_session = std::make_shared<mllm::models::qwen3::Qwen3Session>();
qwen3_session->fromPreTrain(model_path.get());
mllm::service::insertSession("mllmTeam/Qwen3-0.6B-w4a32kai", qwen3_session);
mllm::service::startService();

std::vector<nlohmann::json> history;
const std::string model_name = "mllmTeam/Qwen3-0.6B-w4a32kai";

std::cout << "Enter /exit or /quit to exit this program\n";

while (true) {
std::cout << "\nUser: ";
std::string user_input;
std::getline(std::cin, user_input);
if (user_input == "/exit" || user_input == "/quit") break;

nlohmann::json user_msg;
user_msg["role"] = "user";
user_msg["content"] = user_input;
history.push_back(user_msg);

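// Build an OpenAI-style chat completion request. The full message history is
// resent each turn, and the fixed request id correlates the streamed replies.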
nlohmann::json req;
req["model"] = model_name;
req["messages"] = history;
req["id"] = "chat-multi";
req["enable_thinking"] = true;
mllm::service::sendRequest(req.dump());
std::string assistant_content;

bool thinking_states = false;

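// Poll the service for streamed chunks until a finish_reason of "stop" arrives.
// Each chunk is a JSON object whose shape (inferred from the fields read below)
// looks roughly like: {"choices":[{"delta":{"content":"..."},"finish_reason":null}]}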
while (true) {
std::string resp = mllm::service::getResponse("chat-multi");
auto j = nlohmann::json::parse(resp);

if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("delta")
&& j["choices"][0]["delta"].contains("content")) {
std::string delta = j["choices"][0]["delta"]["content"].get<std::string>();

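// The model brackets its chain-of-thought in <think>/</think> sentinel tokens;
// render that span dimmed rather than treating it as answer text.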
if (delta == "<think>") {
thinking_states = true;
fmt::print(fmt::fg(fmt::color::gray) | fmt::emphasis::bold | fmt::emphasis::underline, "Thinking...:");
continue;
}
if (delta == "</think>") {
thinking_states = false;
fmt::print("\n");
continue;
}

if (thinking_states) {
fmt::print(fmt::fg(fmt::color::gray), "{}", delta);
std::fflush(stdout);
} else {
fmt::print("{}", delta);
std::fflush(stdout);
}

assistant_content += delta;
}

if (j.contains("choices") && j["choices"].size() > 0 && j["choices"][0].contains("finish_reason")
&& j["choices"][0]["finish_reason"].is_string() && j["choices"][0]["finish_reason"].get<std::string>() == "stop") {
break;
}
}

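// Append the finished assistant turn so the next request carries it as context.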
nlohmann::json assistant_msg;
assistant_msg["role"] = "assistant";
assistant_msg["content"] = assistant_content;
history.push_back(assistant_msg);
}

mllm::service::stopService();
return 0;
})
4 changes: 3 additions & 1 deletion mllm/backends/cpu/CPUBackend.cpp
@@ -17,6 +17,7 @@
#include "mllm/backends/cpu/ops/FillOp.hpp"
#include "mllm/backends/cpu/ops/FlashAttention2Op.hpp"
#include "mllm/backends/cpu/ops/GELUOp.hpp"
#include "mllm/backends/cpu/ops/RadixAttnOp.hpp"
#include "mllm/backends/cpu/ops/ReLUOp.hpp"
#include "mllm/backends/cpu/ops/GraphOps.hpp"
#include "mllm/backends/cpu/ops/ISTFTOp.hpp"
@@ -35,6 +36,7 @@
#include "mllm/backends/cpu/ops/RepeatOp.hpp"
#include "mllm/backends/cpu/ops/RoPEOp.hpp"
#include "mllm/backends/cpu/ops/STFTOp.hpp"
#include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp"
#include "mllm/backends/cpu/ops/SiLUOp.hpp"
#include "mllm/backends/cpu/ops/SliceOp.hpp"
#include "mllm/backends/cpu/ops/SoftmaxOp.hpp"
@@ -58,7 +60,7 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) {
CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory,
CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory,
CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory,
CPUKVCacheOpFactory, CPUPagedAttnOpFactory>();
CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory>();
}

std::shared_ptr<CPUBackend> createCPUBackend() { return std::make_shared<CPUBackend>(); }
100 changes: 0 additions & 100 deletions mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp

This file was deleted.

@@ -5,7 +5,7 @@
#include <cassert>
#include "mllm/utils/Common.hpp"

namespace mllm::cpu::paged_attn_x::details {
namespace mllm::cpu::radix_attn::details {

struct __AnyArchTag {};
using any_arch_tag = __AnyArchTag;
@@ -32,4 +32,9 @@ struct FMAConstArray {
static MLLM_FORCE_INLINE void run(T* __restrict__ acc_o, const U acc_s, const V* __restrict__ v_token, size_t len) {}
};

} // namespace mllm::cpu::paged_attn_x::details
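// Primary template: a no-op placeholder, presumably specialized with real SIMD
// implementations in the per-architecture impl headers.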
template<typename ArchTag, typename T>
struct FilledWithConst {
static MLLM_FORCE_INLINE void run(T* __restrict__ a, const T v, size_t len) {}
};

} // namespace mllm::cpu::radix_attn::details
@@ -9,63 +9,62 @@
#include <numbers>
#include "mllm/core/Parallel.hpp"
#include "mllm/utils/CPUArchHelper.hpp"
#include "mllm/engine/prefix_cache/TLB.hpp"
#include "mllm/backends/cpu/kernels/common/paged_attn_x/arch.hpp"
#include "mllm/backends/cpu/kernels/common/radix_attn/arch.hpp"

#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH)
#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-arm.hpp"
#include "mllm/backends/cpu/kernels/common/radix_attn/impl-arm.hpp"
#else
#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any-simd.hpp"
#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any-simd.hpp"
#endif
#include "mllm/backends/cpu/kernels/common/paged_attn_x/impl-any.hpp"
#include "mllm/backends/cpu/kernels/common/radix_attn/impl-any.hpp"

namespace mllm::cpu::paged_attn_x {
namespace mllm::cpu::radix_attn {

// BHSD
// K: [S_KV], address, not contiguous
// V: [S_KV], address, not contiguous
// Q: [B, H_Q, S_Q, D], contiguous
// Q: [B, S_Q, H_Q, D], contiguous
//
// After find KV Tokens, KV is [B, 1, H_KV, D]
//
// H_KV should be <= H_Q
template<typename __ArchTag, typename __QDType, typename __KDType, typename __VDType, typename __ODType, typename __AccDType,
bool high_precession_exp = true>
void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, int32_t D, const __QDType* __restrict__ __q,
const mllm::prefix_cache::vp_addr_t* __k, const mllm::prefix_cache::vp_addr_t* __v, __ODType* __restrict__ __out,
void* ctx, int32_t thread_count) {
__KDType** __k, __VDType** __v, __ODType* __restrict__ __out, int32_t thread_count) {
int32_t head_repeat_times = H_Q / H_KV;

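// Softmax scale 1/sqrt(D), pre-multiplied by log2(e) so the inner loop can use
// exp2 instead of exp.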
__AccDType scale = std::sqrt(1.0 / D) * (__AccDType)std::numbers::log2e;

// Loop on batch size.
for (int b_idx = 0; b_idx < B; ++b_idx) {
// Loop on head dim, should be made parallel
// FIXME: Loop on SEQUENCE may faster?
// seq_q [ head_q [ seq_kv ] ]

// Loop on HEAD dim, should be made parallel
MLLM_CONDITIONAL_PARALLEL_FOR(thread_count > 1, thread_count, h_q_idx, 0, H_Q, 1, {
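// Grouped-query attention: every head_repeat_times query heads share one KV head.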
int h_kv_id = h_q_idx / head_repeat_times;

// FA2's Loop
for (int s_q_idx = 0; s_q_idx < S_Q; ++s_q_idx) {
__QDType* q_token = __q + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
__ODType* acc_o = __out + b_idx * H_Q * S_Q * D + h_q_idx * S_Q * D + s_q_idx * D;
const __QDType* q_token = __q + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D;
__ODType* acc_o = __out + b_idx * H_Q * S_Q * D + s_q_idx * H_Q * D + h_q_idx * D;

// FIXME: Boost with SIMD
for (int d_idx = 0; d_idx < D; ++d_idx) { acc_o[d_idx] = 0; }
__AccDType scores_max = std::numeric_limits<__AccDType>::lowest();
__AccDType scores_max_prev = std::numeric_limits<__AccDType>::lowest();
details::FilledWithConst<__ArchTag, __ODType>::run(acc_o, 0, D);

__AccDType scores_max = -std::numeric_limits<__AccDType>::infinity();
__AccDType scores_max_prev = -std::numeric_limits<__AccDType>::infinity();
__AccDType logsum = 0;
__AccDType scores_scale = 0;
__AccDType scores_sum = 0;
__AccDType scores_scale = 0;

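// Causal bound: query position s_q_idx may attend to KV positions up to
// (S_KV - S_Q) + s_q_idx, i.e. all earlier tokens plus itself.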
int __delta = S_KV - S_Q;
int S_KV_BOUND = std::min(__delta + s_q_idx + 1, S_KV);

for (int s_kv_idx = 0; s_kv_idx < S_KV_BOUND; ++s_kv_idx) {
// TODO, prefetch next

// TODO using context.
// __KDType* k_token = (__KDType*)ctx->access(__k[s_kv_idx]);
// __VDType* v_token = (__VDType*)ctx->access(__v[s_kv_idx]);
__KDType* k_token = NULL;
__VDType* v_token = NULL;
// k_token and v_token shape is [B, 1, H, D]
__KDType* k_token = __k[s_kv_idx];
__VDType* v_token = __v[s_kv_idx];

// Offset to one head.
// k_token and v_token shape is [D]
@@ -78,27 +77,24 @@ void fwd_bhsd(int32_t B, int32_t H_Q, int32_t H_KV, int32_t S_Q, int32_t S_KV, i

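// Online softmax (FlashAttention-style): track the running max, rescale the
// accumulated output whenever the max grows, and keep a running logsum for the
// final normalization.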
// 2. Do softmax stuff.
scores_max_prev = scores_max;
scores_max = std::numeric_limits<__AccDType>::lowest();
scores_max = std::max(scores_max, acc_s);
scores_max = std::max(scores_max_prev, acc_s);
scores_scale = std::exp2(scores_max_prev * scale - scores_max * scale);
acc_s = std::exp2(acc_s * scale - scores_max * scale);
scores_sum += acc_s; // TODO This line may be error.
scores_sum = acc_s;
logsum = logsum * scores_scale + scores_sum;

// 3. Scale
MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, scores_scale, D);
details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D);

// 4. MMA1.
FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>(acc_o, acc_s, v_token, D);

// TODO, drop this mmap in the future.
details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(acc_o, acc_s, v_token, D);
}

// 5. Final Rescale.
MulFromConst<__ArchTag, __AccDType, __AccDType>(acc_o, (1.f / logsum), D);
details::MulFromConst<__ArchTag, __ODType, __AccDType>::run(acc_o, (1.f / logsum), D);
}
});
}
}

} // namespace mllm::cpu::paged_attn_x
} // namespace mllm::cpu::radix_attn