2 changes: 1 addition & 1 deletion examples/minicpm_o/config_chattts.json
@@ -31,5 +31,5 @@
"rope_theta": 10000.0,
"tie_word_embeddings": false,
"eos_token_id": 2,
"linear_impl_type": "BLAS"
"linear_impl_type": "default"
}
2 changes: 1 addition & 1 deletion examples/minicpm_o/main.cpp
@@ -58,7 +58,7 @@ MLLM_MAIN({
fmt::print("\n{:*^60}\n", " MiniCPM-o Interactive CLI ");
fmt::print("Enter 'exit' or 'quit' to end the session\n");

std::string image_path = "path/to/your/image.jpg";
std::string image_path = "/Users/luis/Desktop/plane.png";
std::string prompt_text = "描述图片中物体";  // "Describe the objects in the image"
mllm::models::minicpmo::MiniCPMOMessage message;
message.prompt = prompt_text;
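A note on the hardcoded path above: `/Users/luis/Desktop/plane.png` only exists on the author's machine. A minimal sketch of taking the path from the command line instead; `resolve_image_path` and the argv handling are illustrative, not part of this PR:

```cpp
#include <string>

// Illustrative only: not part of this PR. Sketches how the demo could take
// the image from argv instead of a machine-specific fixed path.
std::string resolve_image_path(int argc, char** argv) {
  std::string image_path = "path/to/your/image.jpg";  // placeholder fallback
  if (argc > 1) { image_path = argv[1]; }
  return image_path;
}
```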
12 changes: 5 additions & 7 deletions mllm/models/minicpm_o2_6/modeling_minicpmo.hpp
@@ -11,6 +11,7 @@
#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp"
#include "mllm/models/minicpm_o2_6/modeling_qwen2vl_for_minicpmo.hpp"
#include "mllm/models/ARGeneration.hpp"
#include "mllm/utils/Log.hpp"

namespace mllm::models::minicpmo {

@@ -124,7 +125,7 @@ class MiniCPMOForCausalLM : public models::ARGeneration {
} else if (inputs.count("sequence")) {
input_ids = inputs.at("sequence");
} else {
mllm::print("ERROR: No input_ids or sequence found!");
MLLM_ERROR("No input_ids or sequence found!");
return {};
}
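The swap from `mllm::print` to `MLLM_ERROR` routes the failure through the logging layer pulled in by the new `mllm/utils/Log.hpp` include. The real macro is not shown in this diff; a generic sketch of what a logging macro adds over a bare print, with a hypothetical `DEMO_ERROR` standing in:

```cpp
#include <cstdio>

// Illustrative only: mllm's real MLLM_ERROR (from mllm/utils/Log.hpp) is not
// shown in this diff. This macro just demonstrates the call-site tagging
// that a plain print cannot provide.
#define DEMO_ERROR(msg) \
  std::fprintf(stderr, "[ERROR] %s:%d: %s\n", __FILE__, __LINE__, msg)

int demo() {
  DEMO_ERROR("No input_ids or sequence found!");
  return -1;
}
```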

@@ -209,9 +210,6 @@ class MiniCPMOForCausalLM : public models::ARGeneration {
// }

Tensor merge_vision_text_embeddings(Tensor& text_embeddings, Tensor& vision_embeddings, Tensor& image_bounds) {
mllm::print(text_embeddings.shape());
mllm::print(vision_embeddings.shape());
mllm::print(image_bounds);
auto batch_size = text_embeddings.shape()[0]; // text_embeddings: [1, seq_len, embed_dim]
auto seq_len = text_embeddings.shape()[1];
auto embed_dim = text_embeddings.shape()[2];
@@ -227,9 +225,9 @@ class MiniCPMOForCausalLM : public models::ARGeneration {
auto end_pos = image_bounds.at<int32_t>({bound_idx, 1}) - 1;
// exactly replace <unk> tokens between <slice> and </slice>
for (int pos = start_pos; pos <= end_pos && vision_idx < vision_seq_len; ++pos, ++vision_idx) {
for (int d = 0; d < embed_dim; ++d) {
text_embeddings.at<float>({b, pos, d}) = vision_embeddings.at<float>({bound_idx, vision_idx, d});
}
float* dst_ptr = text_embeddings.offsettedPtr<float>({b, pos, 0});
const float* src_ptr = vision_embeddings.offsettedPtr<float>({bound_idx, vision_idx, 0});
std::memcpy(dst_ptr, src_ptr, embed_dim * sizeof(float));
}
}
}
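The change above replaces a per-element double loop with one `memcpy` per row, which is safe whenever the embedding dimension is contiguous in memory. A self-contained sketch of the same pattern on raw buffers (the row-major layout is an assumption about mllm's tensors, not something this diff states):

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Copy one embedding row of `dim` floats in a single memcpy. Valid only
// because the innermost (embedding) dimension is stored contiguously.
void replace_row(std::vector<float>& text, const std::vector<float>& vision,
                 std::size_t pos, std::size_t vision_idx, std::size_t dim) {
  std::memcpy(text.data() + pos * dim,
              vision.data() + vision_idx * dim,
              dim * sizeof(float));
}
```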
87 changes: 34 additions & 53 deletions mllm/models/minicpm_o2_6/modeling_resampler.hpp
@@ -3,6 +3,7 @@

#pragma once

#include "mllm/core/SlicePrimitives.hpp"
#include "mllm/mllm.hpp"
#include "mllm/models/minicpm_o2_6/modeling_vector_quantize.hpp"
#include "mllm/nn/Module.hpp"
@@ -83,45 +84,21 @@ class ResamplerAttention : public nn::Module {

// Perform packed in-projection: [query|key|value] = input @ in_proj_weight.T + in_proj_bias
// For cross-attention: q comes from query, k,v come from key_value
auto q_weight = in_proj_weight_.weight()[{{0, embed_dim_}, kAll}];
auto k_weight = in_proj_weight_.weight()[{{embed_dim_, 2 * embed_dim_}, kAll}];
auto v_weight = in_proj_weight_.weight()[{{2 * embed_dim_, 3 * embed_dim_}, kAll}];

auto q_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
auto k_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
auto v_weight = Tensor::empty({embed_dim_, embed_dim_}, kFloat32).alloc();
for (int i = 0; i < embed_dim_; i++) {
for (int j = 0; j < embed_dim_; j++) {
*q_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({i, j});
*k_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({embed_dim_ + i, j});
*v_weight.offsettedPtr<float>({i, j}) = in_proj_weight_.weight().at<float>({2 * embed_dim_ + i, j});
}
}

auto q_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
auto k_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
auto v_bias = Tensor::empty({embed_dim_}, kFloat32).alloc();
for (int i = 0; i < embed_dim_; i++) {
*q_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({i});
*k_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({embed_dim_ + i});
*v_bias.offsettedPtr<float>({i}) = in_proj_bias_.weight().at<float>({2 * embed_dim_ + i});
}
auto q_bias = in_proj_bias_.weight()[{{0, embed_dim_}}];
auto k_bias = in_proj_bias_.weight()[{{embed_dim_, 2 * embed_dim_}}];
auto v_bias = in_proj_bias_.weight()[{{2 * embed_dim_, 3 * embed_dim_}}];

auto q = nn::functional::matmul(query, q_weight, false, true);
auto k = nn::functional::matmul(key, k_weight, false, true);
auto v = nn::functional::matmul(value, v_weight, false, true);

for (int i = 0; i < num_queries; i++) {
for (int j = 0; j < embed_dim_; j++) { *q.offsettedPtr<float>({i, j}) += q_bias.at<float>({j}); }
}

for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < embed_dim_; j++) {
*k.offsettedPtr<float>({i, j}) += k_bias.at<float>({j});
*v.offsettedPtr<float>({i, j}) += v_bias.at<float>({j});
}
}

for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < embed_dim_; j++) { *k.offsettedPtr<float>({i, j}) += k_bias.at<float>({j}); }
}
q = q + q_bias;
k = k + k_bias;
v = v + v_bias;
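The rewrite in this hunk drops the allocate-and-copy loops in favor of views into the packed [3*embed_dim, embed_dim] in-projection weight, split as [query|key|value] per the comment above. A minimal sketch of that split on a flat buffer (plain pointers rather than the mllm slicing API):

```cpp
#include <cstddef>

// Packed in_proj layout (as in torch's nn.MultiheadAttention): rows [0, E)
// are the q projection, [E, 2E) are k, [2E, 3E) are v, each an [E, E]
// row-major block. Splitting is pointer arithmetic; nothing is copied.
struct QkvWeights {
  const float* q;
  const float* k;
  const float* v;
};

QkvWeights split_in_proj(const float* in_proj_weight, std::size_t embed_dim) {
  const std::size_t block = embed_dim * embed_dim;
  return {in_proj_weight, in_proj_weight + block, in_proj_weight + 2 * block};
}
```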

auto q_reshaped = Tensor::empty({num_heads_, num_queries, head_dim_}, kFloat32).alloc();
for (int nq = 0; nq < num_queries; nq++) {
@@ -154,6 +131,10 @@
}
v = v_reshaped;

// q = q.view({num_queries, num_heads_, head_dim_}).transpose(0, 1).contiguous(); // [num_heads, num_queries, head_dim]
// k = k.view({seq_len, num_heads_, head_dim_}).transpose(0, 1).contiguous(); // [num_heads, seq_len, head_dim]
// v = v.view({seq_len, num_heads_, head_dim_}).transpose(0, 1).contiguous(); // [num_heads, seq_len, head_dim]

auto scale = 1.0f / std::sqrt(static_cast<float>(head_dim_));
auto attn_weights = nn::functional::matmul(q, k, false, true) * scale; // [num_heads, num_queries, seq_len]

@@ -313,28 +294,26 @@ class Resampler : public nn::Module {
std::vector<Tensor> outputs;
for (int32_t b = 0; b < batch_size; ++b) {
// x for this batch
Tensor x_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc();
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < embed_dim_; j++) { x_b.at<float>({i, j}) = x.at<float>({b, i, j}); }
}
Tensor x_b = x[make_slice(b), kAll, kAll].view({seq_len, embed_dim_});

// pos_embed for this batch
Tensor pos_embed_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc();
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < embed_dim_; j++) {
if (i < max_patch_len) {
pos_embed_b.at<float>({i, j}) = pos_embed_padded.at<float>({b, i, j});
} else {
pos_embed_b.at<float>({i, j}) = 0.0f;
}
}
}
// Tensor pos_embed_b = Tensor::empty({seq_len, embed_dim_}, kFloat32).alloc();
// for (int i = 0; i < seq_len; i++) {
// for (int j = 0; j < embed_dim_; j++) {
// if (i < max_patch_len) {
// pos_embed_b.at<float>({i, j}) = pos_embed_padded.at<float>({b, i, j});
// } else {
// pos_embed_b.at<float>({i, j}) = 0.0f;
// }
// }
// }
// TODO: handle 'set 0'
Tensor pos_embed_b = pos_embed_padded[make_slice(b), kAll, kAll].view({seq_len, embed_dim_});

auto kv_input = x_b + pos_embed_b;

// key_padding_mask for this batch
Tensor key_padding_mask_b = Tensor::empty({max_patch_len}, kUInt8).alloc();
for (int i = 0; i < max_patch_len; i++) { key_padding_mask_b.at<uint8_t>({i}) = key_padding_mask.at<uint8_t>({b, i}); }
Tensor key_padding_mask_b = key_padding_mask[make_slice(b), kAll].view({max_patch_len});

bool has_padding = false;
for (int i = 0; i < seq_len; i++) {
@@ -350,11 +329,13 @@
}

auto out_tensor = Tensor::empty({batch_size, num_queries_, embed_dim_}, kFloat32).alloc();
for (int i = 0; i < batch_size; i++) {
// Optimize: Use memcpy for contiguous memory copy instead of nested loops
const int32_t query_embed_size = num_queries_ * embed_dim_;
for (int32_t i = 0; i < batch_size; i++) {
auto& out_i = outputs[i];
for (int j = 0; j < num_queries_; j++) {
for (int k = 0; k < embed_dim_; k++) { *out_tensor.offsettedPtr<float>({i, j, k}) = out_i.at<float>({j, k}); }
}
float* dst_ptr = out_tensor.offsettedPtr<float>({i, 0, 0});
const float* src_ptr = out_i.ptr<float>();
std::memcpy(dst_ptr, src_ptr, query_embed_size * sizeof(float));
}

out_tensor = ln_post_(out_tensor);
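The gather loop above now moves each per-batch [num_queries, embed_dim] block with a single `memcpy`. The same stacking pattern on plain buffers, as a sketch (the contiguity of each `outputs[i]` is assumed):

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Stack contiguous [rows, cols] blocks into one [batch, rows, cols] buffer
// with one memcpy per block. Each source block must hold exactly
// rows * cols contiguous floats.
std::vector<float> stack_batches(const std::vector<std::vector<float>>& blocks,
                                 std::size_t rows, std::size_t cols) {
  const std::size_t block_size = rows * cols;
  std::vector<float> out(blocks.size() * block_size);
  for (std::size_t i = 0; i < blocks.size(); ++i) {
    std::memcpy(out.data() + i * block_size, blocks[i].data(),
                block_size * sizeof(float));
  }
  return out;
}
```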
34 changes: 23 additions & 11 deletions mllm/models/minicpm_o2_6/modeling_siglip.hpp
@@ -44,10 +44,12 @@ class SiglipVisionEmbeddings final : public nn::Module {

auto batch_size = pixel_values.shape()[0];

// Patch embedding: [B, C, H, W] -> [B, embed_dim, 1, H*W]
auto patch_embeds = patch_embedding_(pixel_values);
// [B, embed_dim, 1, H*W] -> [B, H*W, embed_dim]
auto embeddings = patch_embeds.squeeze(2).transpose(1, 2);
// Conv expects 4D input
pixel_values = pixel_values.view({1, pixel_values.shape()[0], pixel_values.shape()[1], pixel_values.shape()[2]});
auto patch_embeds = patch_embedding_(pixel_values); // Patch embedding: [B, C, H, W] -> [B, embed_dim, 1, H*W]
pixel_values = pixel_values.view({pixel_values.shape()[1], pixel_values.shape()[2], pixel_values.shape()[3]});

auto embeddings = patch_embeds.squeeze(2).transpose(1, 2); // [B, embed_dim, 1, H*W] -> [B, H*W, embed_dim]

// Create position embeddings
if (!tgt_sizes.isNil() && !patch_attention_mask.isNil()) {
@@ -279,6 +281,10 @@ class SiglipEncoderLayer final : public nn::Module {
auto hidden_states = inputs[0];
auto attention_mask = inputs.size() > 1 ? inputs[1] : Tensor::nil();

// TODO: Perf Issue
// attention > 800ms (1k tokens)
// mlp > 600ms

// Self attention with residual connection
auto residual = hidden_states;
auto normed = layer_norm1_(hidden_states);
@@ -314,10 +320,7 @@ class SiglipVisionEncoder final : public nn::Module {
auto attention_mask = inputs.size() > 1 ? inputs[1] : Tensor::nil();

auto hidden_states = inputs_embeds;
for (auto& layer : layers_) {
hidden_states = layer(hidden_states, attention_mask)[0];
// break; // For testing, run only one layer
}
for (auto& layer : layers_) { hidden_states = layer(hidden_states, attention_mask)[0]; }
return {hidden_states};
}
};
@@ -377,6 +380,7 @@ class SiglipVisionModel final : public nn::Module {
patch_attention_mask = patch_attention_mask.squeeze(1); // [B, max_patches]

// Create attention mask for encoder (4D mask for multi-head attention)
// TODO: this will take about 100ms, optimize it
Tensor attention_mask = Tensor::nil();
if (!patch_attention_mask.isNil()) {
auto batch_size = patch_attention_mask.shape()[0];
@@ -401,12 +405,20 @@

// Create 4D attention mask: [B, 1, max_patches, max_patches]
attention_mask = Tensor::empty({batch_size, 1, max_patches, max_patches}, kFloat32).alloc();

// Optimize with cache-friendly access patterns and reduced redundant accesses
for (int b = 0; b < batch_size; b++) {
// Pre-fetch mask values for this batch to improve cache locality
std::vector<float> batch_mask(max_patches);
for (int p = 0; p < max_patches; p++) { batch_mask[p] = patch_mask_float.at<float>({b, p}); }

// Compute attention mask for this batch with optimized memory access
for (int i = 0; i < max_patches; i++) {
float mask_i = batch_mask[i];
// Process row in chunks for better cache utilization
for (int j = 0; j < max_patches; j++) {
float mask_i = patch_mask_float.at<float>({b, i});
float mask_j = patch_mask_float.at<float>({b, j});
// Both positions must be valid
float mask_j = batch_mask[j];
// Both positions must be valid (branchless computation)
float final_mask = (mask_i > 0.0f && mask_j > 0.0f) ? 0.0f : -1e9f;
attention_mask.at<float>({b, 0, i, j}) = final_mask;
}
6 changes: 3 additions & 3 deletions mllm/models/minicpm_o2_6/tokenization_minicpmo.hpp
@@ -238,7 +238,7 @@ class MiniCPMOTokenizer final : public mllm::preprocessor::AutoTokenizer {
special_tokens_trie_.add(L"<|tts_eos|>");
}

std::vector<std::wstring> _tokenize(const std::string& str) {
std::vector<std::wstring> _tokenize(const std::string& str) override {
std::vector<std::wstring> ret;
std::vector<std::wstring> splitted;
::mllm::models::minicpmo::miniCPMORegex(str, splitted);
@@ -255,7 +255,7 @@ class MiniCPMOTokenizer final : public mllm::preprocessor::AutoTokenizer {
return ret;
}

std::vector<std::wstring> tokenize(const std::string& str) {
std::vector<std::wstring> tokenize(const std::string& str) override {
auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str));
std::vector<std::wstring> all_tokens;
for (const auto& token : tokens) {
@@ -278,7 +278,7 @@ class MiniCPMOTokenizer final : public mllm::preprocessor::AutoTokenizer {
return {mllm::preprocessor::utf8string2WideString(utf_8_str)};
}

Tensor convert2Ids(const std::vector<std::wstring>& strs) {
Tensor convert2Ids(const std::vector<std::wstring>& strs) override {
std::vector<int64_t> ids;
ids.reserve(strs.size());
for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); }
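The only change to these three methods is the added `override` specifier, which is cheap insurance; a minimal self-contained illustration of what it catches:

```cpp
#include <string>
#include <vector>

// Toy stand-in for AutoTokenizer: `override` makes the compiler verify that
// the derived signature really matches a base virtual. A typo (e.g. a
// missing const) would otherwise silently declare a new, never-dispatched
// overload instead of failing the build.
struct BaseTokenizer {
  virtual std::vector<std::wstring> tokenize(const std::string& s) = 0;
  virtual ~BaseTokenizer() = default;
};

struct ToyTokenizer final : BaseTokenizer {
  std::vector<std::wstring> tokenize(const std::string& s) override {
    return {std::wstring(s.begin(), s.end())};  // ASCII-only toy conversion
  }
};
```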