169 changes: 156 additions & 13 deletions examples/qwen_npu/main.cpp
@@ -4,11 +4,16 @@
#include <mllm/utils/AnyValue.hpp>

#include "mllm/backends/qnn/passes/QNNGraphBuildPass.hpp"
#include "mllm/backends/qnn/passes/QNNGraphBuildPipeline.hpp"
#include "mllm/backends/qnn/passes/QNNGraphIOTensorPass.hpp"
#include "mllm/backends/qnn/passes/QNNOpNamingPass.hpp"
#include "mllm/backends/qnn/QNNAllocator.hpp"
#include "mllm/compile/PassManager.hpp"
#include "mllm/core/DataTypes.hpp"
#include "mllm/engine/Context.hpp"
#include "mllm/models/qwen_npu/tokenization_qwen.hpp"
#include "mllm/models/qwen_npu/modeling_qwen_npu.hpp"
#include "mllm/utils/Common.hpp"
#include "mllm/utils/Log.hpp"

using mllm::Argparse;

@@ -28,7 +33,7 @@ MLLM_MAIN({
auto param = mllm::load(model_path, file_version);
model.load(param);

-  mllm::models::ARGenerationOutputPast inputs{{"sequence", mllm::Tensor::empty({1, 32}, mllm::kInt64, mllm::kCPU).alloc()}};
+  mllm::models::ARGenerationOutputPast inputs{{"sequence", mllm::Tensor::empty({1, 128}, mllm::kInt64, mllm::kCPU).alloc()}};

auto irs = model.trace(inputs, {});

@@ -49,19 +54,157 @@ MLLM_MAIN({
// cache has been updated due to trace, clear cache
model.model.clearKVCache();

auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "How are you?"})["sequence"];
auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "提示:海洋世界里,鲸鱼是地球上体型最为庞大的哺乳动物,它们拥有流线型的身躯,主要通过头顶的喷水孔进行呼吸。与终生生活在水下并利用鱼鳃从水中提取溶解氧的鱼类有着本质区别。鲸鱼无法在水下直接呼吸氧气,因此它们需要耗费大量的体力,定时浮出水面完成一次快速而彻底的换气过程。令人惊奇的是,当它们处于睡眠状态时,为了确保不会因为忘记呼吸而发生危险,它们只会关闭大脑的一半来进行休息,另一半大脑则始终保持清醒和警觉,以便及时引导身体浮上水面。这种独特的生存机制是它们在深海中延续生命的关键。问题:鲸鱼与鱼类在呼吸方式上的根本区别是什么?它们在睡觉时会采取什么特殊的措施来保证安全和生存?"})["sequence"];
// auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "提示:海洋世界里,鲸鱼是体型庞大的哺乳动物,它们通过喷水孔呼吸。与鱼类不同,鲸鱼无法在水下直接呼吸氧气。它们会定时浮出水面进行换气,每次换气需要消耗大量的体力。当它们睡觉时,只会关闭大脑的一半,另一半则保持清醒,以确保不忘记浮出水面呼吸。问题:鲸鱼与鱼类在呼吸方式上的根本区别是什么?它们在睡觉时会采取什么特殊的措施来保证安全?"})["sequence"];
print(raw_input_tokens);

// manually set input data as fill op is not supported in QNN
auto ptr = inputs["sequence"].ptr<int64_t>();
MLLM_INFO("raw_input_tokens shape: {} {}", raw_input_tokens.shape()[0], raw_input_tokens.shape()[1]);

const int chunk_size = 128;
const int eos_token_id = 151645;
int prompt_tokens = static_cast<int>(raw_input_tokens.shape()[1]);
if (prompt_tokens <= 0) {
MLLM_ERROR_EXIT(mllm::ExitCode::kShapeError, "Prompt sequence length must be positive");
}

// Prepare reusable [1, chunk_size] CPU buffer for chunked prefill/decode
mllm::models::ARGenerationOutputPast chunk_inputs{
{"sequence", mllm::Tensor::empty({1, chunk_size}, mllm::kInt64, mllm::kCPU).alloc()}};
auto sequence_tensor = chunk_inputs["sequence"];
auto sequence_ptr = sequence_tensor.ptr<int64_t>();
auto input_data = raw_input_tokens.ptr<int64_t>();
-  for (int i = 0; i < raw_input_tokens.shape()[1]; ++i) { ptr[i] = input_data[i]; }
-  for (int i = raw_input_tokens.shape()[1]; i < 32; ++i) { ptr[i] = -1; }

-  auto out = model.forward(inputs, {{"seq_len", mllm::AnyValue((int)raw_input_tokens.shape()[1])}})["sequence"];

-  auto sampled = model.sampleGreedy(out);
-  std::wcout << "token: " << sampled << " " << qwen_tokenizer.detokenize(sampled) << "\n";
const int prompt_chunks = (prompt_tokens + chunk_size - 1) / chunk_size;
bool reached_eos = false;
int total_decode_steps = 0;

for (int chunk_index = 0; chunk_index < prompt_chunks && !reached_eos; ++chunk_index) {
const int chunk_start = chunk_index * chunk_size;
const int chunk_prompt_len = std::min(chunk_size, prompt_tokens - chunk_start);
const bool is_last_prompt_chunk = (chunk_index == prompt_chunks - 1);

// Copy current chunk prompt tokens and pad remaining positions with -1
for (int i = 0; i < chunk_prompt_len; ++i) { sequence_ptr[i] = input_data[chunk_start + i]; }
for (int i = chunk_prompt_len; i < chunk_size; ++i) { sequence_ptr[i] = -1; }

// MLLM_INFO("=== Prefill Chunk {} ===", chunk_index);
// MLLM_INFO("Chunk start: {}, Chunk prompt length: {}", chunk_start, chunk_prompt_len);

// Calculate absolute sequence length from the start of the entire sequence
const int absolute_seq_len = chunk_start + chunk_prompt_len;
// MLLM_INFO("Absolute sequence length: {}", absolute_seq_len);

// Align KV cache so StaticCache writes start at the chunk's absolute offset
model.setKVCacheSeqCnt(chunk_start);
// MLLM_INFO("KV cache seq_cnt set to: {}", chunk_start);

// Generate position_ids starting from chunk_start for multi-chunk scenarios
auto position_ids_tensor = mllm::Tensor::empty({1, chunk_size}, mllm::kInt64, mllm::kCPU).alloc();
auto position_ids_ptr = position_ids_tensor.ptr<int64_t>();
for (int i = 0; i < chunk_size; ++i) {
position_ids_ptr[i] = chunk_start + i;
}

// Prepare input with correct position_ids
mllm::models::ARGenerationOutputPast prefill_inputs{
{"sequence", sequence_tensor},
{"position_ids", position_ids_tensor}};

// real_seq should be the effective length in the current input tensor (relative position)
// hidden_states shape is [1, chunk_size, hidden_size], we need to index it with chunk_prompt_len - 1
auto chunk_output =
model.forward(prefill_inputs, {{"seq_len", mllm::AnyValue(mllm::any_copy_tag, chunk_prompt_len)}});
auto& chunk_logits = chunk_output["sequence"];

// auto tmp_next_token = model.sampleGreedy(chunk_logits);
// std::wcout << qwen_tokenizer.detokenize(tmp_next_token) << "\n";
// std::wcout << qwen_tokenizer.detokenize(sequence_ptr[chunk_start + chunk_prompt_len]) << "\n";

if (!is_last_prompt_chunk) {
// MLLM_INFO("Chunk {} processed as prompt only, moving to next chunk", chunk_index);
chunk_logits.delete_();
chunk_output.clear();
continue;
}

if (chunk_prompt_len >= chunk_size) {
MLLM_WARN("Last chunk is fully occupied by prompt tokens; no padding for decode");
chunk_logits.delete_();
chunk_output.clear();
break;
}

// MLLM_INFO("=== Decode Phase (Chunk {}) ===", chunk_index);

// Use the prefill logits as the first decode step
auto next_token = model.sampleGreedy(chunk_logits);
chunk_logits.delete_();

// Keep full-length position_ids tensor aligned with chunk buffer
auto position_ids = position_ids_tensor;

chunk_output.clear();

auto emit_token = [&](int64_t token_id) {
std::wcout << qwen_tokenizer.detokenize(token_id) << std::flush;
if (token_id == eos_token_id) {
MLLM_INFO("EOS token detected, stopping decode");
reached_eos = true;
}
};

int current_chunk_len = chunk_prompt_len;
emit_token(next_token);
if (reached_eos) { break; }

sequence_ptr[current_chunk_len] = next_token;
current_chunk_len++;

while (!reached_eos && current_chunk_len < chunk_size) {
total_decode_steps++;

// Calculate absolute sequence length from the start of the entire sequence
const int absolute_seq_len = chunk_start + current_chunk_len;

// MLLM_INFO("--- Chunk {} Decode Step {} ---", chunk_index, total_decode_steps);
// MLLM_INFO("Current chunk length: {} (relative), Absolute sequence length: {} (absolute)", current_chunk_len, absolute_seq_len);

// Keep padding clean for the remaining area
for (int i = current_chunk_len; i < chunk_size; ++i) { sequence_ptr[i] = -1; }

// Set KV cache to absolute sequence length (where the next token will be written)
// [Maybe Wrong]
model.setKVCacheSeqCnt(chunk_start);
// MLLM_INFO("KV cache seq_cnt set to: {} (relative position)", chunk_start);

// Prepare decode input with position_ids from previous step
mllm::models::ARGenerationOutputPast decode_inputs{
{"sequence", sequence_tensor},
{"position_ids", position_ids}};

// real_seq should be the effective length in the current input tensor (relative position)
// hidden_states shape is [1, chunk_size, hidden_size], we need to index it with current_chunk_len - 1
auto decode_output = model.forward(
decode_inputs, {{"seq_len", mllm::AnyValue(mllm::any_copy_tag, current_chunk_len)}});

auto& decode_logits = decode_output["sequence"];
next_token = model.sampleGreedy(decode_logits);
decode_logits.delete_();
decode_output.erase("sequence");
decode_output.clear();

emit_token(next_token);
if (reached_eos) { break; }

sequence_ptr[current_chunk_len] = next_token;
current_chunk_len++;
}
Comment on lines +162 to +200
Contributor
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

KV cache seq count is likely wrong in the decode loop (comment already says [Maybe Wrong])

In the per-chunk decode loop you recompute the absolute sequence length:

```cpp
const int absolute_seq_len = chunk_start + current_chunk_len;
...
// Set KV cache to absolute sequence length (where the next token will be written)
// [Maybe Wrong]
model.setKVCacheSeqCnt(chunk_start);
```

Using chunk_start here ignores already-consumed tokens in the current chunk and appears inconsistent with the comment and with the new KV cache APIs. This can misalign the KV cache during multi-step decode, especially after the first generated token within the last chunk.

A more consistent approach would be to base the sequence count on the absolute number of valid tokens processed so far, e.g.:

```diff
-      // Set KV cache to absolute sequence length (where the next token will be written)
-      // [Maybe Wrong]
-      model.setKVCacheSeqCnt(chunk_start);
+      // Set KV cache to the absolute sequence length so far (prefill + decoded tokens)
+      // The next token will be written after `absolute_seq_len - 1`.
+      model.setKVCacheSeqCnt(absolute_seq_len);
```

or, if setKVCacheSeqCnt is defined to take “index of next write” rather than “current valid length”, adjust accordingly (e.g. absolute_seq_len - 1).

Please double-check this against nn::KVCache::setCurrentSeqCnt semantics; the current implementation is very likely incorrect for multi-step decode.
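To make the two candidate semantics concrete, here is a minimal, self-contained sketch. This is hypothetical code, not the real mllm API; it only assumes that forward() writes seq_len cache entries starting at the current seq cnt. It prints where each decode step would begin writing under the rewind policy currently in the PR versus the advance policy proposed above:

```cpp
#include <cstdio>

int main() {
  // Illustrative numbers mirroring examples/qwen_npu/main.cpp.
  const int chunk_start = 128;      // absolute offset of the last prompt chunk
  const int chunk_prompt_len = 40;  // prompt tokens inside that chunk
  int current_chunk_len = chunk_prompt_len + 1;  // first sampled token appended

  for (int step = 1; step <= 3; ++step) {
    const int absolute_seq_len = chunk_start + current_chunk_len;
    const int rewind_policy = chunk_start;        // what the PR sets each step
    const int advance_policy = absolute_seq_len;  // what this review proposes
    std::printf("step %d: rewind=%d advance=%d (gap %d)\n",
                step, rewind_policy, advance_policy,
                advance_policy - rewind_policy);
    ++current_chunk_len;  // one more token decoded into the chunk
  }
  return 0;
}
```

The gap between the two policies grows by one token per step. Which one is correct depends on whether the graph re-writes the whole chunk's K/V on every decode step (rewind is then self-consistent, since the input tensor still carries all chunk tokens so far) or appends only the newest token (advance is then required).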



---

🏁 Script executed:

```shell
#!/bin/bash
# Find setKVCacheSeqCnt definition and usage
rg -n 'setKVCacheSeqCnt' --type=cpp -B3 -A3
```

Length of output: 2302


🏁 Script executed:

```shell
#!/bin/bash
# Examine the full context of main.cpp around the decode loop (lines 140-210)
sed -n '140,210p' examples/qwen_npu/main.cpp
```

Length of output: 2654


🏁 Script executed:

```shell
#!/bin/bash
# Search for prefill loop or chunked processing setup
rg -n 'chunk_start\|chunk_size\|prefill' examples/qwen_npu/main.cpp --type=cpp -B2 -A2
```

Length of output: 0


🏁 Script executed:

```shell
#!/bin/bash
# Find KVCache class definition
find . -name '*.hpp' -o -name '*.h' | xargs rg -l 'class KVCache\|struct KVCache' 2>/dev/null | head -5
```

Length of output: 0



---

🏁 Script executed:

```shell
#!/bin/bash
# Get the prefill loop context before the decode loop
sed -n '70,145p' examples/qwen_npu/main.cpp
```

Length of output: 3484


🏁 Script executed:

```shell
#!/bin/bash
# Search for setCurrentSeqCnt definition and usage
rg -n 'setCurrentSeqCnt' --type=cpp --type=hpp -B2 -A5
```

Length of output: 27


🏁 Script executed:

```shell
#!/bin/bash
# Look for any StaticCache or KVCache definition in the mllm directory
find mllm -name '*.hpp' -o -name '*.h' | xargs grep -l 'setCurrentSeqCnt' 2>/dev/null | head -3
```

Length of output: 128


🏁 Script executed:

```shell
#!/bin/bash
# Search for how prefill is called and initialized
rg -n 'forward.*prefill\|prefill.*forward' examples/qwen_npu/main.cpp --type=cpp -B5 -A5
```

Length of output: 0



---

🏁 Script executed:

```shell
#!/bin/bash
# Find setCurrentSeqCnt definition in KVCacheOp.hpp
cat mllm/backends/cpu/ops/KVCacheOp.hpp | grep -A10 'setCurrentSeqCnt'
```

Length of output: 326


🏁 Script executed:

```shell
#!/bin/bash
# Look for StaticCache or KVCache class definition
rg -n 'class.*KVCache\|class.*StaticCache' mllm/backends/cpu/ops/KVCacheOp.hpp -A15
```

Length of output: 0



---

🏁 Script executed:

```shell
#!/bin/bash
# Find nn::StaticCache definition to confirm setCurrentSeqCnt semantics
find mllm -name '*.hpp' -o -name '*.h' | xargs grep -l 'class StaticCache\|struct StaticCache' 2>/dev/null
```

Length of output: 31


🏁 Script executed:

```shell
#!/bin/bash
# Get broader context of StaticCache to understand seq_cnt semantics
rg -n 'setCurrentSeqCnt|getCurrentSeqCnt' mllm/backends/cpu/ -B3 -A8 | head -40
```

Length of output: 2657



---

🏁 Script executed:

```shell
#!/bin/bash
# Check StaticCache implementation
sed -n '1,100p' mllm/nn/lmcache/StaticCache.hpp
```

Length of output: 3312


🏁 Script executed:

```shell
#!/bin/bash
# Search for setCurrentSeqCnt in StaticCache
rg -n 'setCurrentSeqCnt' mllm/nn/lmcache/StaticCache.hpp -B2 -A8
```

Length of output: 943


KV cache seq count must track absolute sequence length, not just chunk offset

The bug is confirmed. In the decode loop at line 176, you reset the KV cache seq count to chunk_start on every iteration, even though current_chunk_len keeps incrementing. This causes the cache to lose track of tokens decoded within the chunk:

  • Iteration 1: current_chunk_len = chunk_prompt_len + 1, but the cache is set to chunk_start, off by chunk_prompt_len tokens
  • Iteration 2: current_chunk_len = chunk_prompt_len + 2, but the cache is still at chunk_start, off by chunk_prompt_len + 1 tokens

The computed absolute_seq_len variable (line 163) is never used, which points to an incomplete implementation. The fix is to use it:

```diff
-      model.setKVCacheSeqCnt(chunk_start);
+      model.setKVCacheSeqCnt(absolute_seq_len);
```

Additionally, the prefill at line 97 should similarly use chunk_start + chunk_prompt_len (the absolute valid length after prefill) instead of just chunk_start.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

```diff
     while (!reached_eos && current_chunk_len < chunk_size) {
       ...
       // Set KV cache to absolute sequence length (where the next token will be written)
       // [Maybe Wrong]
-      model.setKVCacheSeqCnt(chunk_start);
+      model.setKVCacheSeqCnt(absolute_seq_len);
       // MLLM_INFO("KV cache seq_cnt set to: {} (relative position)", chunk_start);
```
🤖 Prompt for AI Agents

In examples/qwen_npu/main.cpp around lines 162 to 200, the KV cache seq count is being set to chunk_start each decode iteration, which ignores tokens already decoded in the current chunk; replace the call to model.setKVCacheSeqCnt(chunk_start) with model.setKVCacheSeqCnt(absolute_seq_len) (where absolute_seq_len = chunk_start + current_chunk_len) so the cache uses the true absolute write position, and also update the prefill call near line 97 to set the seq count to chunk_start + chunk_prompt_len (the absolute length after prefill) instead of just chunk_start.


// MLLM_INFO("=== Chunk {} Decode Complete ===", chunk_index);
// MLLM_INFO("Chunk final length: {}", current_chunk_len);
// MLLM_INFO("Remaining capacity: {}", chunk_size - current_chunk_len);
}

std::wcout << L"\n";

return 0;
})
8 changes: 4 additions & 4 deletions mllm/backends/cpu/ops/CausalMaskOp.cpp
@@ -53,8 +53,8 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
const __m256 mask_val = _mm256_set1_ps(-1e10f);
for (size_t r = 0; r < S; ++r) {
-      const size_t copy_count = std::min(r + 1, (size_t)D);
-      const size_t fill_count = D > copy_count ? (D - copy_count) : 0;
+      const size_t copy_count = D - S + r + 1;
+      const size_t fill_count = std::max(D - copy_count, (size_t)0);

memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float));

Expand All @@ -68,8 +68,8 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)
const float32x4_t mask_val = vdupq_n_f32(-1e10f);
for (size_t r = 0; r < S; ++r) {
-      const size_t copy_count = std::min(r + 1, (size_t)D);
-      const size_t fill_count = D > copy_count ? (D - copy_count) : 0;
+      const size_t copy_count = D - S + r + 1;
+      const size_t fill_count = std::max(D - copy_count, (size_t)0);

memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float));

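For intuition about the new arithmetic: copy_count = D - S + r + 1 treats the S mask rows as the last S queries over a D-long key context (e.g. chunked prefill over a cached prefix), whereas the old min(r + 1, D) assumed the query block started at position 0. Below is a standalone sketch of the geometry, illustrative only and assuming D >= S; with size_t arithmetic, D < S would underflow copy_count, which the SIMD kernels above do not guard against either:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const size_t S = 3, D = 5;  // 3 new queries attending over a 5-token context
  for (size_t r = 0; r < S; ++r) {
    const size_t copy_count = D - S + r + 1;  // visible (unmasked) positions
    for (size_t c = 0; c < D; ++c) {
      std::fputs(c < copy_count ? " . " : " x ", stdout);  // x = masked to -1e10
    }
    std::printf("  <- row %zu keeps %zu scores\n", r, copy_count);
  }
  // Prints:
  //  .  .  .  x  x   <- row 0 keeps 3 scores
  //  .  .  .  .  x   <- row 1 keeps 4 scores
  //  .  .  .  .  .   <- row 2 keeps 5 scores
  return 0;
}
```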
4 changes: 4 additions & 0 deletions mllm/backends/cpu/ops/KVCacheOp.cpp
@@ -44,4 +44,8 @@ void CPUKVCacheOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor

void CPUKVCacheOp::clearCache() { cache_.clearCache(); }

void CPUKVCacheOp::setCurrentSeqCnt(int32_t seq) { cache_.setCurrentSeqCnt(seq); }

int32_t CPUKVCacheOp::getCurrentSeqCnt() const { return cache_.getCurrentSeqCnt(options_.layer_idx); }

} // namespace mllm::cpu
4 changes: 4 additions & 0 deletions mllm/backends/cpu/ops/KVCacheOp.hpp
@@ -19,6 +19,10 @@ class CPUKVCacheOp final : public aops::KVCacheOp {

void clearCache() override;

void setCurrentSeqCnt(int32_t seq) override;

int32_t getCurrentSeqCnt() const override;

private:
nn::StaticCache cache_;
};
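These overrides simply forward to nn::StaticCache. As a sketch of the contract a chunked driver relies on, here is a minimal mock; the classes are hypothetical stand-ins, not the real headers, and the real getCurrentSeqCnt additionally takes a layer_idx, as the .cpp diff above shows:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for nn::StaticCache seq-cnt bookkeeping.
struct MockStaticCache {
  int32_t seq_cnt = 0;
  void setCurrentSeqCnt(int32_t s) { seq_cnt = s; }
  int32_t getCurrentSeqCnt() const { return seq_cnt; }  // real API is per layer_idx
};

// Stand-in for CPUKVCacheOp: forwards to its cache, as in the diff above.
struct MockKVCacheOp {
  MockStaticCache cache_;
  void setCurrentSeqCnt(int32_t s) { cache_.setCurrentSeqCnt(s); }
  int32_t getCurrentSeqCnt() const { return cache_.getCurrentSeqCnt(); }
};

int main() {
  MockKVCacheOp op;
  const int32_t chunk_size = 128;
  for (int32_t chunk = 0; chunk < 3; ++chunk) {
    // Position the cache write head at the chunk's absolute offset, the same
    // move examples/qwen_npu/main.cpp makes via model.setKVCacheSeqCnt().
    op.setCurrentSeqCnt(chunk * chunk_size);
    std::printf("chunk %d writes from position %d\n", (int)chunk,
                (int)op.getCurrentSeqCnt());
  }
  return 0;
}
```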