diff --git a/examples/qwen_npu/main.cpp b/examples/qwen_npu/main.cpp index 97c54b4ce..e64b8d194 100644 --- a/examples/qwen_npu/main.cpp +++ b/examples/qwen_npu/main.cpp @@ -4,11 +4,16 @@ #include #include "mllm/backends/qnn/passes/QNNGraphBuildPass.hpp" -#include "mllm/backends/qnn/passes/QNNGraphBuildPipeline.hpp" +#include "mllm/backends/qnn/passes/QNNGraphIOTensorPass.hpp" +#include "mllm/backends/qnn/passes/QNNOpNamingPass.hpp" +#include "mllm/backends/qnn/QNNAllocator.hpp" #include "mllm/compile/PassManager.hpp" #include "mllm/core/DataTypes.hpp" +#include "mllm/engine/Context.hpp" #include "mllm/models/qwen_npu/tokenization_qwen.hpp" #include "mllm/models/qwen_npu/modeling_qwen_npu.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Log.hpp" using mllm::Argparse; @@ -28,7 +33,7 @@ MLLM_MAIN({ auto param = mllm::load(model_path, file_version); model.load(param); - mllm::models::ARGenerationOutputPast inputs{{"sequence", mllm::Tensor::empty({1, 32}, mllm::kInt64, mllm::kCPU).alloc()}}; + mllm::models::ARGenerationOutputPast inputs{{"sequence", mllm::Tensor::empty({1, 128}, mllm::kInt64, mllm::kCPU).alloc()}}; auto irs = model.trace(inputs, {}); @@ -49,19 +54,157 @@ MLLM_MAIN({ // cache has been updated due to trace, clear cache model.model.clearKVCache(); - auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "How are you?"})["sequence"]; + auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "提示:海洋世界里,鲸鱼是地球上体型最为庞大的哺乳动物,它们拥有流线型的身躯,主要通过头顶的喷水孔进行呼吸。与终生生活在水下并利用鱼鳃从水中提取溶解氧的鱼类有着本质区别。鲸鱼无法在水下直接呼吸氧气,因此它们需要耗费大量的体力,定时浮出水面完成一次快速而彻底的换气过程。令人惊奇的是,当它们处于睡眠状态时,为了确保不会因为忘记呼吸而发生危险,它们只会关闭大脑的一半来进行休息,另一半大脑则始终保持清醒和警觉,以便及时引导身体浮上水面。这种独特的生存机制是它们在深海中延续生命的关键。问题:鲸鱼与鱼类在呼吸方式上的根本区别是什么?它们在睡觉时会采取什么特殊的措施来保证安全和生存?"})["sequence"]; + // auto raw_input_tokens = qwen_tokenizer.convertMessage({.prompt = "提示:海洋世界里,鲸鱼是体型庞大的哺乳动物,它们通过喷水孔呼吸。与鱼类不同,鲸鱼无法在水下直接呼吸氧气。它们会定时浮出水面进行换气,每次换气需要消耗大量的体力。当它们睡觉时,只会关闭大脑的一半,另一半则保持清醒,以确保不忘记浮出水面呼吸。问题:鲸鱼与鱼类在呼吸方式上的根本区别是什么?它们在睡觉时会采取什么特殊的措施来保证安全?"})["sequence"]; print(raw_input_tokens); - - // manually set input data as fill op is not supported in QNN - auto ptr = inputs["sequence"].ptr(); + MLLM_INFO("raw_input_tokens shape: {} {}", raw_input_tokens.shape()[0], raw_input_tokens.shape()[1]); + + const int chunk_size = 128; + const int eos_token_id = 151645; + int prompt_tokens = static_cast(raw_input_tokens.shape()[1]); + if (prompt_tokens <= 0) { + MLLM_ERROR_EXIT(mllm::ExitCode::kShapeError, "Prompt sequence length must be positive"); + } + + // Prepare reusable [1, chunk_size] CPU buffer for chunked prefill/decode + mllm::models::ARGenerationOutputPast chunk_inputs{ + {"sequence", mllm::Tensor::empty({1, chunk_size}, mllm::kInt64, mllm::kCPU).alloc()}}; + auto sequence_tensor = chunk_inputs["sequence"]; + auto sequence_ptr = sequence_tensor.ptr(); auto input_data = raw_input_tokens.ptr(); - for (int i = 0; i < raw_input_tokens.shape()[1]; ++i) { ptr[i] = input_data[i]; } - for (int i = raw_input_tokens.shape()[1]; i < 32; ++i) { ptr[i] = -1; } - - auto out = model.forward(inputs, {{"seq_len", mllm::AnyValue((int)raw_input_tokens.shape()[1])}})["sequence"]; - auto sampled = model.sampleGreedy(out); - std::wcout << "token: " << sampled << " " << qwen_tokenizer.detokenize(sampled) << "\n"; + const int prompt_chunks = (prompt_tokens + chunk_size - 1) / chunk_size; + bool reached_eos = false; + int total_decode_steps = 0; + + for (int chunk_index = 0; chunk_index < prompt_chunks && !reached_eos; ++chunk_index) { + const int chunk_start = chunk_index * chunk_size; + const int 
chunk_prompt_len = std::min(chunk_size, prompt_tokens - chunk_start); + const bool is_last_prompt_chunk = (chunk_index == prompt_chunks - 1); + + // Copy current chunk prompt tokens and pad remaining positions with -1 + for (int i = 0; i < chunk_prompt_len; ++i) { sequence_ptr[i] = input_data[chunk_start + i]; } + for (int i = chunk_prompt_len; i < chunk_size; ++i) { sequence_ptr[i] = -1; } + + // MLLM_INFO("=== Prefill Chunk {} ===", chunk_index); + // MLLM_INFO("Chunk start: {}, Chunk prompt length: {}", chunk_start, chunk_prompt_len); + + // Calculate absolute sequence length from the start of the entire sequence + const int absolute_seq_len = chunk_start + chunk_prompt_len; + // MLLM_INFO("Absolute sequence length: {}", absolute_seq_len); + + // Align KV cache so StaticCache writes start at the chunk's absolute offset + model.setKVCacheSeqCnt(chunk_start); + // MLLM_INFO("KV cache seq_cnt set to: {}", chunk_start); + + // Generate position_ids starting from chunk_start for multi-chunk scenarios + auto position_ids_tensor = mllm::Tensor::empty({1, chunk_size}, mllm::kInt64, mllm::kCPU).alloc(); + auto position_ids_ptr = position_ids_tensor.ptr(); + for (int i = 0; i < chunk_size; ++i) { + position_ids_ptr[i] = chunk_start + i; + } + + // Prepare input with correct position_ids + mllm::models::ARGenerationOutputPast prefill_inputs{ + {"sequence", sequence_tensor}, + {"position_ids", position_ids_tensor}}; + + // real_seq should be the effective length in the current input tensor (relative position) + // hidden_states shape is [1, chunk_size, hidden_size], we need to index it with chunk_prompt_len - 1 + auto chunk_output = + model.forward(prefill_inputs, {{"seq_len", mllm::AnyValue(mllm::any_copy_tag, chunk_prompt_len)}}); + auto& chunk_logits = chunk_output["sequence"]; + + // auto tmp_next_token = model.sampleGreedy(chunk_logits); + // std::wcout << qwen_tokenizer.detokenize(tmp_next_token) << "\n"; + // std::wcout << qwen_tokenizer.detokenize(sequence_ptr[chunk_start + chunk_prompt_len]) << "\n"; + + if (!is_last_prompt_chunk) { + // MLLM_INFO("Chunk {} processed as prompt only, moving to next chunk", chunk_index); + chunk_logits.delete_(); + chunk_output.clear(); + continue; + } + + if (chunk_prompt_len >= chunk_size) { + MLLM_WARN("Last chunk is fully occupied by prompt tokens; no padding for decode"); + chunk_logits.delete_(); + chunk_output.clear(); + break; + } + + // MLLM_INFO("=== Decode Phase (Chunk {}) ===", chunk_index); + + // Use the prefill logits as the first decode step + auto next_token = model.sampleGreedy(chunk_logits); + chunk_logits.delete_(); + + // Keep full-length position_ids tensor aligned with chunk buffer + auto position_ids = position_ids_tensor; + + chunk_output.clear(); + + auto emit_token = [&](int64_t token_id) { + std::wcout << qwen_tokenizer.detokenize(token_id) << std::flush; + if (token_id == eos_token_id) { + MLLM_INFO("EOS token detected, stopping decode"); + reached_eos = true; + } + }; + + int current_chunk_len = chunk_prompt_len; + emit_token(next_token); + if (reached_eos) { break; } + + sequence_ptr[current_chunk_len] = next_token; + current_chunk_len++; + + while (!reached_eos && current_chunk_len < chunk_size) { + total_decode_steps++; + + // Calculate absolute sequence length from the start of the entire sequence + const int absolute_seq_len = chunk_start + current_chunk_len; + + // MLLM_INFO("--- Chunk {} Decode Step {} ---", chunk_index, total_decode_steps); + // MLLM_INFO("Current chunk length: {} (relative), Absolute sequence length: {} 
(absolute)", current_chunk_len, absolute_seq_len); + + // Keep padding clean for the remaining area + for (int i = current_chunk_len; i < chunk_size; ++i) { sequence_ptr[i] = -1; } + + // Set KV cache to absolute sequence length (where the next token will be written) + // [Maybe Wrong] + model.setKVCacheSeqCnt(chunk_start); + // MLLM_INFO("KV cache seq_cnt set to: {} (relative position)", chunk_start); + + // Prepare decode input with position_ids from previous step + mllm::models::ARGenerationOutputPast decode_inputs{ + {"sequence", sequence_tensor}, + {"position_ids", position_ids}}; + + // real_seq should be the effective length in the current input tensor (relative position) + // hidden_states shape is [1, chunk_size, hidden_size], we need to index it with current_chunk_len - 1 + auto decode_output = model.forward( + decode_inputs, {{"seq_len", mllm::AnyValue(mllm::any_copy_tag, current_chunk_len)}}); + + auto& decode_logits = decode_output["sequence"]; + next_token = model.sampleGreedy(decode_logits); + decode_logits.delete_(); + decode_output.erase("sequence"); + decode_output.clear(); + + emit_token(next_token); + if (reached_eos) { break; } + + sequence_ptr[current_chunk_len] = next_token; + current_chunk_len++; + } + + // MLLM_INFO("=== Chunk {} Decode Complete ===", chunk_index); + // MLLM_INFO("Chunk final length: {}", current_chunk_len); + // MLLM_INFO("Remaining capacity: {}", chunk_size - current_chunk_len); + } + + std::wcout << L"\n"; return 0; -}) \ No newline at end of file +}) diff --git a/mllm/backends/cpu/ops/CausalMaskOp.cpp b/mllm/backends/cpu/ops/CausalMaskOp.cpp index b91dda5d6..074ca0199 100644 --- a/mllm/backends/cpu/ops/CausalMaskOp.cpp +++ b/mllm/backends/cpu/ops/CausalMaskOp.cpp @@ -53,8 +53,8 @@ void CPUCausalMaskOp::forward(const std::vector& inputs, std::vector copy_count ? (D - copy_count) : 0; + const size_t copy_count = D - S + r + 1; + const size_t fill_count = std::max(D - copy_count, (size_t)0); memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float)); @@ -68,8 +68,8 @@ void CPUCausalMaskOp::forward(const std::vector& inputs, std::vector copy_count ? (D - copy_count) : 0; + const size_t copy_count = D - S + r + 1; + const size_t fill_count = std::max(D - copy_count, (size_t)0); memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float)); diff --git a/mllm/backends/cpu/ops/KVCacheOp.cpp b/mllm/backends/cpu/ops/KVCacheOp.cpp index 7847a1fb7..ed6977687 100644 --- a/mllm/backends/cpu/ops/KVCacheOp.cpp +++ b/mllm/backends/cpu/ops/KVCacheOp.cpp @@ -44,4 +44,8 @@ void CPUKVCacheOp::forward(const std::vector& inputs, std::vector +#include #include namespace mllm::qnn { +namespace { +constexpr bool kVerboseQnnAllocatorLogs = false; +} // namespace + +#define QNN_ALLOCATOR_VERBOSE(...) 
\ + do { \ + if constexpr (kVerboseQnnAllocatorLogs) { MLLM_INFO(__VA_ARGS__); } \ + } while (0) + // specified in QNN doc #define RPCMEM_HEAP_ID_SYSTEM 25 #define RPCMEM_DEFAULT_FLAGS 1 @@ -34,8 +46,26 @@ QNNAllocator::QNNAllocator(QNN_INTERFACE_VER_TYPE qnnInterface, void* context) rpcmem_to_fd = (RpcMemToFdFn_t)dlsym(libCdspHandle, "rpcmem_to_fd"); } +QNNAllocator::~QNNAllocator() { + for (auto iter = ptrToFdAndMemHandleMap_.begin(); iter != ptrToFdAndMemHandleMap_.end();) { + Qnn_ErrorHandle_t deregisterRet = qnnInterface_.memDeRegister(&iter->second.second, 1); + if (QNN_SUCCESS != deregisterRet) { + MLLM_WARN("~QNNAllocator: memDeRegister failed during shutdown, status=0x{:x}", deregisterRet); + } + qnnMemPtrSet_.erase(iter->first); + rpcmem_free(iter->first); + iter = ptrToFdAndMemHandleMap_.erase(iter); + } + + for (void* ptr : qnnMemPtrSet_) { + rpcmem_free(ptr); + } + qnnMemPtrSet_.clear(); +} + bool QNNAllocator::alloc(Storage* storage) { - uint8_t* ptr = (uint8_t*)rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, allocSize(storage)); + size_t request_bytes = allocSize(storage); + uint8_t* ptr = (uint8_t*)rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, request_bytes); MLLM_RT_ASSERT(ptr != nullptr); @@ -45,25 +75,290 @@ bool QNNAllocator::alloc(Storage* storage) { return true; } +/** + * @brief Free a storage buffer and manage QNN memory handle lifecycle + * + * This function handles the complex lifecycle of QNN shared buffers: + * 1. Checks if the buffer is already freed or never allocated + * 2. Detects if multiple pointers share the same mem_handle (aliases) + * 3. Only de-registers mem_handle when it's the last reference + * 4. Updates tensor ID/name mappings to point to alternative pointers if needed + * + * Key design considerations: + * - QNN doesn't support re-registering a de-registered buffer (fd may be invalidated) + * - Multiple buffer pointers can share the same mem_handle (common in decode phase) + * - Tensor mappings must be updated when pointers are redirected to aliases + * + * @param storage Pointer to the storage object containing the buffer to free + */ void QNNAllocator::free(Storage* storage) { - if (ptrToFdAndMemHandleMap_.count(storage->ptr_)) { - MLLM_RT_ASSERT_EQ(QNN_SUCCESS, - qnnInterface_.memDeRegister(&(ptrToFdAndMemHandleMap_.find(storage->ptr_)->second.second), 1)); + auto ptr = storage->ptr_; + + // Early return if ptr is nullptr or not in qnnMemPtrSet_ (already freed or never allocated) + // This is common during decode phase when buffers are reused, so we silently ignore + if (ptr == nullptr) { + // too noisy during decode; silently ignore nullptr frees + return; + } + + if (qnnMemPtrSet_.count(ptr) == 0) { + QNN_ALLOCATOR_VERBOSE("QNNAllocator::free called for ptr={} that is not in qnnMemPtrSet_, ignoring", ptr); + return; } - rpcmem_free(storage->ptr_); + // Check if any other buffer pointer shares the same mem_handle (alias detection) + // This is important because in decode phase, multiple tensor wrappers may reference + // the same underlying buffer through different pointers + void* alternative_ptr = nullptr; // Another ptr using the same mem_handle, if any + + if (ptrToFdAndMemHandleMap_.count(ptr)) { + auto iter = ptrToFdAndMemHandleMap_.find(ptr); + auto mem_handle = iter->second.second; + + // Check if any other ptr is using the same mem_handle + // This handles the case where buffer reuse creates multiple pointers to the same mem_handle + for (const auto& [other_ptr, fd_and_handle] : ptrToFdAndMemHandleMap_) { + if 
(other_ptr != ptr && fd_and_handle.second == mem_handle) { + alternative_ptr = other_ptr; + break; + } + } + + // Only deRegister if this is the last ptr using this mem_handle + // If there are aliases, we must keep the mem_handle registered + if (alternative_ptr == nullptr) { + // No aliases found, safe to de-register the mem_handle + auto status = qnnInterface_.memDeRegister(&mem_handle, 1); + if (status != QNN_SUCCESS) { + MLLM_WARN("QNNAllocator::free memDeRegister failed, status=0x{:x}, ptr={}, fd={}", status, ptr, iter->second.first); + } + // Remove from ptrToFdAndMemHandleMap_ and ptrToSizeMap_ + // The actual buffer will be freed later in the function + ptrToFdAndMemHandleMap_.erase(iter); + ptrToSizeMap_.erase(ptr); + } else { + // Aliases exist, skip de-registration to avoid breaking other references + QNN_ALLOCATOR_VERBOSE("QNNAllocator::free skipping deRegister for ptr={} because other ptrs use the mem_handle", ptr); + ptrToFdAndMemHandleMap_.erase(iter); + ptrToSizeMap_.erase(ptr); + } + } else { + // ptr is in qnnMemPtrSet_ but not in ptrToFdAndMemHandleMap_ + // This means it was allocated but never registered (e.g., memRegister failed) + // Just free the buffer without deRegister + QNN_ALLOCATOR_VERBOSE("QNNAllocator::free freeing unregistered buffer ptr={}", ptr); + qnnMemPtrSet_.erase(ptr); + rpcmem_free(ptr); + eraseTensorMappingsForPtr(ptr, "free(unregistered buffer)"); + clearLastRegistrationIfMatches(ptr, "free(unregistered buffer)"); + return; + } + + // Update or keep tensor ID and name mappings + // If mem_handle is still in use (alternative_ptr exists), update mappings to point to alternative_ptr + // Otherwise, free the buffer and clear mappings + if (alternative_ptr != nullptr) { + // Update mappings to point to alternative_ptr instead of deleting them + // This ensures that future tensor lookups will find the correct buffer + for (auto& entry : tensorIdToPtrMap_) { + if (entry.second == ptr) { entry.second = alternative_ptr; } + } + for (auto& entry : tensorNameToPtrMap_) { + if (entry.second == ptr) { entry.second = alternative_ptr; } + } + // Don't free the buffer here since alternative_ptr is still using it + qnnMemPtrSet_.erase(ptr); + clearLastRegistrationIfMatches(ptr, "free(ptr) -> redirected to alias"); + } else { + // Since QNN doesn't support re-registering a deRegistered buffer (fd may be invalidated), + // we should free the buffer immediately even if there are mappings. + // The decode phase will allocate a new buffer when needed. + qnnMemPtrSet_.erase(ptr); + rpcmem_free(ptr); + eraseTensorMappingsForPtr(ptr, "free(ptr) -> mem_handle released"); + clearLastRegistrationIfMatches(ptr, "free(ptr) -> mem_handle released"); } + storage->ptr_ = nullptr; +} + +/** + * @brief Register a tensor's buffer to QNN shared memory + * + * This function implements a sophisticated buffer reuse mechanism to avoid duplicate registrations + * of the same tensor across prefill and decode phases. It uses a multi-level fallback strategy: + * + * 1. Check if the buffer is already registered (by ptr) + * 2. Check if a buffer exists for the same tensor ID (primary lookup) + * 3. Check if a buffer exists for the same tensor name (fallback lookup) + * 4. Check if we can reuse the last successfully registered buffer (last resort) + * 5. If all fallbacks fail, attempt new registration + * + * This is critical for decode phase where the same tensor (e.g., KV cache) is used repeatedly, + * and QNN HTP device has limited memory resources (~2.5GB typically). 
+ * + * @param storage Storage object containing the buffer to register + * @param qnn_tensor QNN tensor structure to update with mem_handle + * @return true if registration succeeded, false otherwise + */ +bool QNNAllocator::registerQnnTensorToSharedBuffer(Storage* storage, Qnn_Tensor_t& qnn_tensor) { + MLLM_RT_ASSERT(storage != nullptr); + void* ptr = storage->ptr_; -void QNNAllocator::registerQnnTensorToSharedBuffer(void* ptr, Qnn_Tensor_t& qnn_tensor) { // Make sure there has a memory that we can register to. + MLLM_RT_ASSERT(ptr != nullptr); MLLM_RT_ASSERT(qnnMemPtrSet_.count(ptr)); - // if already registered, just set the mem handle + // Save original tensor state in case we need to restore on failure + auto original_mem_type = QNN_TENSOR_GET_MEM_TYPE(qnn_tensor); + Qnn_MemHandle_t original_mem_handle = QNN_TENSOR_GET_MEM_HANDLE(qnn_tensor); + + // Extract tensor identification information + // Tensor ID is the primary identifier (more reliable than name) + uint32_t tensor_id = QNN_TENSOR_GET_ID(qnn_tensor); + const char* tensor_name_cstr = QNN_TENSOR_GET_NAME(qnn_tensor); + std::string tensor_name = tensor_name_cstr ? tensor_name_cstr : "unknown"; + + // Calculate buffer size from tensor dimensions and data type + uint32_t rank = QNN_TENSOR_GET_RANK(qnn_tensor); + uint32_t* dims_ptr = QNN_TENSOR_GET_DIMENSIONS(qnn_tensor); + Qnn_DataType_t data_type = QNN_TENSOR_GET_DATA_TYPE(qnn_tensor); + + size_t element_bytes = 0; + if (auto it = QNNDataTypeToSize.find(data_type); it != QNNDataTypeToSize.end()) { element_bytes = it->second; } + + size_t element_cnt = 1; + std::vector dims; + dims.reserve(rank); + for (uint32_t i = 0; i < rank; ++i) { + uint32_t dim = dims_ptr ? dims_ptr[i] : 0; + dims.push_back(dim); + element_cnt *= (dim == 0 ? 1 : dim); + } + size_t total_bytes = element_cnt * element_bytes; + + // Format shape string for error messages + std::string shape_str = "[]"; + if (!dims.empty()) { + shape_str = "["; + for (size_t i = 0; i < dims.size(); ++i) { + shape_str += std::to_string(dims[i]); + if (i + 1 < dims.size()) { shape_str += ", "; } + } + shape_str += "]"; + } + + QNN_ALLOCATOR_VERBOSE( + "registerQnnTensorToSharedBuffer: ptr={}, tensor_id={}, tensor_name={}, tensorIdToPtrMap_.size()={}", ptr, tensor_id, + tensor_name, tensorIdToPtrMap_.size()); + + /** + * @brief Update tensor ID/name mappings and size tracking + * + * This lambda updates the internal mappings that allow us to find existing buffers + * for the same tensor in future registration attempts. + */ + auto updateMappings = [&](void* mapped_ptr) { + tensorIdToPtrMap_[tensor_id] = mapped_ptr; + if (tensor_name != "unknown") { tensorNameToPtrMap_[tensor_name] = mapped_ptr; } + ptrToSizeMap_[mapped_ptr] = total_bytes; + }; + + /** + * @brief Reuse an existing registered buffer for this tensor + * + * This lambda implements the core buffer reuse logic: + * 1. Verifies the existing buffer is still registered + * 2. Copies data from new buffer to existing buffer if needed + * 3. Updates tensor to use existing mem_handle + * 4. Updates internal mappings + * 5. 
Frees the new buffer to avoid memory leak + * + * @param existing_ptr Pointer to the existing registered buffer + * @return true if reuse succeeded, false if buffer is no longer registered + */ + auto reuseExistingBuffer = [&](void* existing_ptr) -> bool { + auto fd_handle_iter = ptrToFdAndMemHandleMap_.find(existing_ptr); + if (fd_handle_iter == ptrToFdAndMemHandleMap_.end()) { return false; } + + Qnn_MemHandle_t existing_mem_handle = fd_handle_iter->second.second; + size_t existing_size = ptrToSizeMap_.count(existing_ptr) > 0 ? ptrToSizeMap_[existing_ptr] : 0; + + // If pointers differ, copy data from new buffer to existing buffer + // This handles the case where a new buffer was allocated but we want to reuse the old one + if (existing_ptr != ptr) { + size_t bytes_to_copy = total_bytes; + if (existing_size > 0) { bytes_to_copy = std::min(bytes_to_copy, existing_size); } + if (bytes_to_copy > 0) { std::memcpy(existing_ptr, ptr, bytes_to_copy); } + + // Free the new buffer since we're reusing the existing one + if (qnnMemPtrSet_.count(ptr) > 0) { + qnnMemPtrSet_.erase(ptr); + rpcmem_free(ptr); + } + storage->ptr_ = existing_ptr; + } + + // Update tensor to use existing mem_handle + QNN_TENSOR_SET_MEM_TYPE(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); + QNN_TENSOR_SET_MEM_HANDLE(qnn_tensor, existing_mem_handle); + updateMappings(existing_ptr); + rememberLastRegistration(tensor_id, tensor_name, existing_ptr, existing_mem_handle, total_bytes); + return true; + }; + + // Level 1: Check if this exact buffer pointer is already registered + // This is the fastest path and handles the common case in decode phase if (ptrToFdAndMemHandleMap_.count(ptr) > 0) { Qnn_MemHandle_t mem_handle = ptrToFdAndMemHandleMap_[ptr].second; QNN_TENSOR_SET_MEM_TYPE(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); QNN_TENSOR_SET_MEM_HANDLE(qnn_tensor, mem_handle); - return; + updateMappings(ptr); + rememberLastRegistration(tensor_id, tensor_name, ptr, mem_handle, total_bytes); + return true; + } + + // Level 2: Check if we can reuse an existing buffer for the same tensor ID + // Tensor ID is the primary identifier and is more reliable than name + // This handles decode phase where the same tensor is used repeatedly + if (tensorIdToPtrMap_.count(tensor_id) > 0) { + void* existing_ptr = tensorIdToPtrMap_[tensor_id]; + QNN_ALLOCATOR_VERBOSE("Found existing mapping for tensor_id={}: existing_ptr={}", tensor_id, existing_ptr); + + if (existing_ptr == nullptr) { + // Mapping exists but buffer was freed, clean up and register new buffer + QNN_ALLOCATOR_VERBOSE( + "Existing mapping for tensor_id={} has nullptr ptr (buffer was freed), will register new buffer", tensor_id); + tensorIdToPtrMap_.erase(tensor_id); + } else if (reuseExistingBuffer(existing_ptr)) { + return true; + } else { + // Buffer exists but is no longer registered, clean up mapping + MLLM_WARN("Existing ptr {} for tensor_id={} is no longer registered, removing from map", existing_ptr, tensor_id); + tensorIdToPtrMap_.erase(tensor_id); + } + } else { + QNN_ALLOCATOR_VERBOSE("No existing mapping found for tensor_id={}", tensor_id); + } + + // Level 3: Check by tensor name as fallback (in case ID changed or is 0) + // Some tensors may have ID=0, so name becomes the fallback identifier + if (tensor_name != "unknown" && tensorNameToPtrMap_.count(tensor_name) > 0) { + void* existing_ptr = tensorNameToPtrMap_[tensor_name]; + QNN_ALLOCATOR_VERBOSE("Found existing mapping for tensor_name={}: existing_ptr={}", tensor_name, existing_ptr); + + if (existing_ptr == nullptr) { + // 
Mapping exists but buffer was freed, clean up and register new buffer + QNN_ALLOCATOR_VERBOSE( + "Existing mapping for tensor_name={} has nullptr ptr (mem_handle was deRegistered), will register new buffer", + tensor_name); + tensorNameToPtrMap_.erase(tensor_name); + } else if (reuseExistingBuffer(existing_ptr)) { + return true; + } else { + // Buffer exists but is no longer registered, clean up mapping + MLLM_WARN("Existing ptr {} for tensor_name={} is no longer registered", existing_ptr, tensor_name); + tensorNameToPtrMap_.erase(tensor_name); + } } // Get the file id of this memory space. @@ -73,30 +368,230 @@ void QNNAllocator::registerQnnTensorToSharedBuffer(void* ptr, Qnn_Tensor_t& qnn_ // Make qnn memory descriptor. Set ION. Qnn_MemDescriptor_t mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; mem_descriptor.memShape = { - .numDim = QNN_TENSOR_GET_RANK(qnn_tensor), - .dimSize = QNN_TENSOR_GET_DIMENSIONS(qnn_tensor), + .numDim = rank, + .dimSize = dims_ptr, .shapeConfig = nullptr, }; - mem_descriptor.dataType = QNN_TENSOR_GET_DATA_TYPE(qnn_tensor); + mem_descriptor.dataType = data_type; mem_descriptor.memType = QNN_MEM_TYPE_ION; mem_descriptor.ionInfo.fd = mem_fd; QNN_TENSOR_SET_MEM_TYPE(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); // Register to QNN memory Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(qnn_tensor); - MLLM_RT_ASSERT_EQ(QNN_SUCCESS, qnnInterface_.memRegister(context_, &mem_descriptor, 1u, &mem_handle)); + auto status = qnnInterface_.memRegister(context_, &mem_descriptor, 1u, &mem_handle); + + // Attempt to register the buffer with QNN + // This can fail if: + // 1. QNN HTP device memory is exhausted (typically ~2.5GB limit) + // 2. FastRPC memory mapping fails + // 3. SMMU (System Memory Management Unit) mapping fails + if (status != QNN_SUCCESS) { + auto stats = getRegisteredBufferStats(); + MLLM_ERROR("QNNAllocator::registerQnnTensorToSharedBuffer memRegister failed, status=0x{:x}, ptr={}, fd={}, bytes={}, " + "shape={}, dtype={}, tensor_id={}, tensor_name={}", + status, ptr, mem_fd, total_bytes, shape_str, static_cast(mem_descriptor.dataType), tensor_id, tensor_name); + MLLM_ERROR("Current registered buffers: {} buffers, {} MB", stats.count, stats.total_bytes / (1024 * 1024)); + + // Multi-level fallback strategy when registration fails + // This is critical when QNN device memory is exhausted + bool fallback_success = false; + + // Fallback Level 1: Try to reuse buffer by tensor ID + if (tensorIdToPtrMap_.count(tensor_id) > 0) { + void* existing_ptr = tensorIdToPtrMap_[tensor_id]; + if (existing_ptr != nullptr) { + MLLM_WARN("Fallback: Reusing existing buffer by ID for tensor_id={}, tensor_name={}, old_ptr={}, new_ptr={}", + tensor_id, tensor_name, existing_ptr, ptr); + fallback_success = reuseExistingBuffer(existing_ptr); + } + } + + // Fallback Level 2: Try to reuse buffer by tensor name + if (!fallback_success && tensor_name != "unknown" && tensorNameToPtrMap_.count(tensor_name) > 0) { + void* existing_ptr = tensorNameToPtrMap_[tensor_name]; + if (existing_ptr != nullptr) { + MLLM_WARN("Fallback: Reusing existing buffer by name for tensor_id={}, tensor_name={}, old_ptr={}, new_ptr={}", + tensor_id, tensor_name, existing_ptr, ptr); + fallback_success = reuseExistingBuffer(existing_ptr); + } + } + + // Fallback Level 3: Try to reuse last successfully registered buffer + // This is a last resort when memory is exhausted and we can't find exact matches + if (!fallback_success && hasLastRegistrationInfo_) { + bool same_tensor_id = tensor_id != 0 && tensor_id == 
lastRegistrationInfo_.tensor_id; + bool same_tensor_name = tensor_name != "unknown" && !tensor_name.empty() + && tensor_name == lastRegistrationInfo_.tensor_name; + bool ptr_still_registered = lastRegistrationInfo_.ptr != nullptr + && ptrToFdAndMemHandleMap_.count(lastRegistrationInfo_.ptr) > 0; + if ((same_tensor_id || same_tensor_name) && ptr_still_registered) { + MLLM_WARN("Fallback: Reusing last successful buffer for tensor_id={}, tensor_name={}, old_ptr={}, new_ptr={}", + tensor_id, tensor_name, lastRegistrationInfo_.ptr, ptr); + fallback_success = reuseExistingBuffer(lastRegistrationInfo_.ptr); + } else { + MLLM_WARN("Fallback: Last registration info unusable for tensor_id={}, tensor_name={}, " + "same_tensor_id={}, same_tensor_name={}, ptr_registered={}", + tensor_id, tensor_name, same_tensor_id, same_tensor_name, ptr_still_registered); + } + } + + // If all fallbacks failed, we must free the buffer and return failure + // The caller should handle this gracefully (e.g., by retrying or using CPU fallback) + if (!fallback_success) { + MLLM_ERROR("QNNAllocator::registerQnnTensorToSharedBuffer: memRegister failed and fallback also failed. " + "Buffer ptr={} will be freed, tensor registration cannot proceed.", ptr); + + if (qnnMemPtrSet_.count(ptr) > 0) { + qnnMemPtrSet_.erase(ptr); + rpcmem_free(ptr); + storage->ptr_ = nullptr; + eraseTensorMappingsForPtr(ptr, "register failure -> freed ptr"); + clearLastRegistrationIfMatches(ptr, "register failure -> freed ptr"); + QNN_ALLOCATOR_VERBOSE("QNNAllocator::registerQnnTensorToSharedBuffer: Freed ptr={} ({} bytes) after failure", ptr, + total_bytes); + } + + // Restore original tensor state + QNN_TENSOR_SET_MEM_HANDLE(qnn_tensor, original_mem_handle); + QNN_TENSOR_SET_MEM_TYPE(qnn_tensor, original_mem_type); + return false; + } + return true; + } else { + // Registration succeeded, log verbose information + QNN_ALLOCATOR_VERBOSE("Register shared buffer ptr={}, fd={}, bytes={}, shape={}, dtype={}, tensor_id={}, tensor_name={}", + ptr, mem_fd, total_bytes, shape_str, static_cast(mem_descriptor.dataType), tensor_id, + tensor_name); + } QNN_TENSOR_SET_MEM_HANDLE(qnn_tensor, mem_handle); ptrToFdAndMemHandleMap_.insert({ptr, {mem_fd, mem_handle}}); + updateMappings(ptr); + rememberLastRegistration(tensor_id, tensor_name, ptr, mem_handle, total_bytes); + return true; } void QNNAllocator::deRegisterQnnTensorFromSharedBuffer(void* ptr) { - MLLM_RT_ASSERT_EQ(ptrToFdAndMemHandleMap_.count(ptr), 1); - MLLM_RT_ASSERT_EQ(QNN_SUCCESS, qnnInterface_.memDeRegister(&(ptrToFdAndMemHandleMap_[ptr].second), 1)); - ptrToFdAndMemHandleMap_.erase(ptr); + auto iter = ptrToFdAndMemHandleMap_.find(ptr); + if (iter == ptrToFdAndMemHandleMap_.end()) { return; } + + Qnn_ErrorHandle_t status = qnnInterface_.memDeRegister(&(iter->second.second), 1); + if (status != QNN_SUCCESS) { + MLLM_WARN("QNNAllocator::deRegisterQnnTensorFromSharedBuffer memDeRegister failed, status=0x{:x}, ptr={}, fd={}", status, + ptr, iter->second.first); + } + + ptrToFdAndMemHandleMap_.erase(iter); + ptrToSizeMap_.erase(ptr); + eraseTensorMappingsForPtr(ptr, "explicit deRegister"); + clearLastRegistrationIfMatches(ptr, "explicit deRegister"); } +QNNAllocator::BufferStats QNNAllocator::getRegisteredBufferStats() const { + BufferStats stats{}; + stats.count = ptrToFdAndMemHandleMap_.size(); + stats.total_bytes = 0; + + for (const auto& [ptr, size] : ptrToSizeMap_) { + stats.total_bytes += size; + } + + return stats; +} + +bool QNNAllocator::isRegistered(void* ptr) const { + return 
ptrToFdAndMemHandleMap_.count(ptr) > 0; +} + +size_t QNNAllocator::getRegisteredBufferSize(void* ptr) const { + auto it = ptrToSizeMap_.find(ptr); + if (it == ptrToSizeMap_.end()) { return 0; } + return it->second; +} + +/** + * @brief Erase all tensor ID and name mappings that point to a specific buffer pointer + * + * When a buffer is freed or de-registered, we need to clean up all mappings that reference it. + * This ensures that future lookups won't find stale pointers. + * + * @param ptr The buffer pointer to remove from mappings + * @param reason Reason for erasure (for debugging/logging purposes) + */ +void QNNAllocator::eraseTensorMappingsForPtr(void* ptr, std::string_view reason) { + if (ptr == nullptr) { return; } + + // Remove all tensor ID mappings that point to this ptr + for (auto it = tensorIdToPtrMap_.begin(); it != tensorIdToPtrMap_.end();) { + if (it->second == ptr) { + it = tensorIdToPtrMap_.erase(it); + } else { + ++it; + } + } + + // Remove all tensor name mappings that point to this ptr + for (auto it = tensorNameToPtrMap_.begin(); it != tensorNameToPtrMap_.end();) { + if (it->second == ptr) { + it = tensorNameToPtrMap_.erase(it); + } else { + ++it; + } + } +} + +/** + * @brief Remember the last successful buffer registration for fallback purposes + * + * This function stores information about the most recent successful registration. + * This information is used as a last-resort fallback when: + * 1. New registration fails (e.g., memory exhausted) + * 2. Exact tensor ID/name matches are not found + * 3. The last registered buffer is still valid and matches the tensor + * + * This is particularly useful in decode phase where memory pressure is high + * and we want to maximize buffer reuse. + * + * @param tensor_id Tensor ID of the registered tensor + * @param tensor_name Tensor name of the registered tensor + * @param ptr Buffer pointer that was successfully registered + * @param mem_handle QNN memory handle from successful registration + * @param total_bytes Size of the registered buffer in bytes + */ +void QNNAllocator::rememberLastRegistration(uint32_t tensor_id, const std::string& tensor_name, void* ptr, + Qnn_MemHandle_t mem_handle, size_t total_bytes) { + if (ptr == nullptr || mem_handle == nullptr) { return; } + lastRegistrationInfo_.tensor_id = tensor_id; + lastRegistrationInfo_.tensor_name = tensor_name; + lastRegistrationInfo_.ptr = ptr; + lastRegistrationInfo_.mem_handle = mem_handle; + lastRegistrationInfo_.bytes = total_bytes; + hasLastRegistrationInfo_ = true; + // Note: Remembered registration info is used as fallback mechanism, logging removed for performance +} + +/** + * @brief Clear the last registration info if it matches the given pointer + * + * When a buffer is freed or de-registered, we should clear the last registration + * info if it references that buffer. This prevents using stale registration info + * in future fallback attempts. 
+ * + * @param ptr The buffer pointer to check against + * @param reason Reason for clearing (for debugging/logging purposes) + */ +void QNNAllocator::clearLastRegistrationIfMatches(void* ptr, std::string_view reason) { + if (!hasLastRegistrationInfo_ || ptr == nullptr) { return; } + if (lastRegistrationInfo_.ptr == ptr) { + lastRegistrationInfo_ = {}; + hasLastRegistrationInfo_ = false; + } +} + +#undef QNN_ALLOCATOR_VERBOSE + std::shared_ptr createQNNAllocator() { return std::make_shared(); } } // namespace mllm::qnn diff --git a/mllm/backends/qnn/QNNAllocator.hpp b/mllm/backends/qnn/QNNAllocator.hpp index 7fc7335d4..c9fe7b399 100644 --- a/mllm/backends/qnn/QNNAllocator.hpp +++ b/mllm/backends/qnn/QNNAllocator.hpp @@ -3,8 +3,9 @@ #pragma once -#include #include +#include +#include #include "QnnCommon.h" #include "QnnInterface.h" #include "mllm/backends/base/Allocator.hpp" @@ -30,14 +31,7 @@ class QNNAllocator final : public Allocator { QNNAllocator(); // need to setQNNPointer afterward QNNAllocator(QNN_INTERFACE_VER_TYPE qnnInterface, void* context); - ~QNNAllocator() { - for (auto iter = ptrToFdAndMemHandleMap_.begin(); iter != ptrToFdAndMemHandleMap_.end();) { - Qnn_ErrorHandle_t deregisterRet = qnnInterface_.memDeRegister(&iter->second.second, 1); - if (QNN_SUCCESS != deregisterRet) { MLLM_ERROR("~QNNAllocator: qnnInterface_.memDeRegister failed"); } - rpcmem_free(iter->first); - iter = ptrToFdAndMemHandleMap_.erase(iter); - } - } + ~QNNAllocator(); void setQNNPointer(QNN_INTERFACE_VER_TYPE qnnInterface, void* context) { this->qnnInterface_ = qnnInterface; @@ -75,10 +69,21 @@ class QNNAllocator final : public Allocator { // Sharing access in between processing domains in QNN HTP backend. Using shared buffers can // eliminate data copy in between client code on the host CPU and HTP accelerator. 
- void registerQnnTensorToSharedBuffer(void* ptr, Qnn_Tensor_t& qnn_tensor); + bool registerQnnTensorToSharedBuffer(Storage* storage, Qnn_Tensor_t& qnn_tensor); void deRegisterQnnTensorFromSharedBuffer(void* ptr); + // Debug: Get statistics about registered buffers + struct BufferStats { + size_t count; + size_t total_bytes; + }; + [[nodiscard]] BufferStats getRegisteredBufferStats() const; + + // Debug: Check if a ptr is already registered + bool isRegistered(void* ptr) const; + [[nodiscard]] size_t getRegisteredBufferSize(void* ptr) const; + private: QNN_INTERFACE_VER_TYPE qnnInterface_; Qnn_ContextHandle_t context_ = nullptr; @@ -90,6 +95,64 @@ class QNNAllocator final : public Allocator { // to check if the ptr is allocted by rpcmem_alloc std::set qnnMemPtrSet_; std::map> ptrToFdAndMemHandleMap_; + // Track buffer sizes for statistics + std::map ptrToSizeMap_; + // Map tensor name to registered buffer ptr for reuse (fallback identifier) + // Used when tensor ID is 0 or unavailable + std::map tensorNameToPtrMap_; + + // Map tensor ID to registered buffer ptr for reuse (primary identifier) + // Tensor ID is more reliable than name and is used as the primary lookup key + // This enables buffer reuse across prefill and decode phases + std::map tensorIdToPtrMap_; + + /** + * @brief Information about the last successful buffer registration + * + * This structure stores metadata about the most recent successful registration, + * which is used as a last-resort fallback when: + * - New registration fails (e.g., memory exhausted) + * - Exact tensor ID/name matches are not found + * - The last registered buffer is still valid and matches the tensor + * + * This is particularly useful in decode phase where memory pressure is high. + */ + struct LastRegistrationInfo { + uint32_t tensor_id = 0; // Tensor ID of the registered tensor + std::string tensor_name; // Tensor name of the registered tensor + void* ptr = nullptr; // Buffer pointer that was successfully registered + Qnn_MemHandle_t mem_handle = nullptr; // QNN memory handle from successful registration + size_t bytes = 0; // Size of the registered buffer in bytes + }; + + LastRegistrationInfo lastRegistrationInfo_{}; // Last successful registration info + bool hasLastRegistrationInfo_ = false; // Whether last registration info is valid + + /** + * @brief Erase all tensor ID and name mappings that point to a specific buffer pointer + * @param ptr The buffer pointer to remove from mappings + * @param reason Reason for erasure (for debugging/logging purposes) + */ + void eraseTensorMappingsForPtr(void* ptr, std::string_view reason); + + /** + * @brief Remember the last successful buffer registration for fallback purposes + * @param tensor_id Tensor ID of the registered tensor + * @param tensor_name Tensor name of the registered tensor + * @param ptr Buffer pointer that was successfully registered + * @param mem_handle QNN memory handle from successful registration + * @param total_bytes Size of the registered buffer in bytes + */ + void rememberLastRegistration(uint32_t tensor_id, const std::string& tensor_name, void* ptr, + Qnn_MemHandle_t mem_handle, size_t total_bytes); + + /** + * @brief Clear the last registration info if it matches the given pointer + * @param ptr The buffer pointer to check against + * @param reason Reason for clearing (for debugging/logging purposes) + */ + void clearLastRegistrationIfMatches(void* ptr, std::string_view reason); + }; std::shared_ptr createQNNAllocator(); diff --git a/mllm/backends/qnn/QNNBackend.cpp 
b/mllm/backends/qnn/QNNBackend.cpp index efcb89a36..f2025698e 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -1,4 +1,5 @@ #include "QNNBackend.hpp" +#include #include #include #include @@ -534,13 +535,74 @@ void QNNBackend::graphExecute(const std::string& graphName, std::vector& return; } + // Prepare QNN input tensors by copying data from runtime inputs to graph input wrappers + // This handles the case where input tensor sizes may differ between prefill and decode phases std::vector qnn_inputs; std::vector qnn_outputs; for (int i = 0; i < model->getGraphInputTensorWrappers().size(); i++) { - // alloc and register qnn tensor - model->getGraphInputTensorWrappers()[i]->getDataContainer() = inputs[i]; // update data container - model->getGraphInputTensorWrappers()[i]->alloc(); // QNNAllocator will handle registered memory descriptor - qnn_inputs.push_back(*(model->getGraphInputTensorWrappers()[i]->getNativeTensor())); + auto wrapper = model->getGraphInputTensorWrappers()[i]; + auto& wrapper_tensor = wrapper->getDataContainer(); + const auto& runtime_input = inputs[i]; + + // Validate input tensors + if (runtime_input.isNil()) { + MLLM_ERROR("Input tensor {} is nil for graph '{}'", i, graphName); + return; + } + + if (wrapper_tensor.isNil()) { + MLLM_ERROR("Graph input wrapper {} for graph '{}' has no backing tensor", i, graphName); + return; + } + + // Check for size mismatches (can occur in decode phase where inputs may be smaller) + size_t dst_bytes = wrapper_tensor.bytes(); + size_t src_bytes = runtime_input.bytes(); + if (dst_bytes != src_bytes) { + MLLM_WARN("Graph '{}' input tensor {} byte-size mismatch: wrapper={} bytes, runtime input={} bytes. Copying " + "min(dst, src), but this may truncate data.", + graphName, i, dst_bytes, src_bytes); + } + + if (dst_bytes > 0) { + void* dst_ptr = wrapper_tensor.ptr(); + if (!dst_ptr) { + wrapper_tensor.alloc(); + dst_ptr = wrapper_tensor.ptr(); + } + + const void* src_ptr = runtime_input.ptr(); + size_t bytes_to_copy = std::min(dst_bytes, src_bytes); + if (!src_ptr) { + MLLM_ERROR("Runtime input tensor {} for graph '{}' has null data pointer", i, graphName); + return; + } + if (dst_ptr && src_ptr && dst_ptr != src_ptr) { + // Copy source data to destination buffer + // This ensures that the graph input wrapper has the correct data for execution + if (bytes_to_copy > 0) { + std::memcpy(dst_ptr, src_ptr, bytes_to_copy); + } + + // If source is smaller than destination, zero out the remaining bytes + // This is important for decode phase where input tensors may be smaller than prefill + // For example, decode phase may use [1, 1] input while wrapper expects [1, 128] + // Note: In current implementation with full [1, 128] tensor, this should not trigger + // but it's kept as a safety measure for future optimizations + if (src_bytes < dst_bytes) { + size_t remaining_bytes = dst_bytes - src_bytes; + std::memset(static_cast(dst_ptr) + bytes_to_copy, 0, remaining_bytes); + // Only log if zero-padding actually occurs (unexpected case) + MLLM_WARN("[QNN graphExecute] Graph '{}' input tensor {}: zero-padded {} bytes (src={} bytes, dst={} bytes)", + graphName, i, remaining_bytes, src_bytes, dst_bytes); + } + } + } + + // Allocate and register the wrapper tensor with QNN allocator + // QNNAllocator will handle registered memory descriptor when needed + wrapper->alloc(); + qnn_inputs.push_back(*(wrapper->getNativeTensor())); } // Prepare QNN outputs in QNN order diff --git a/mllm/backends/qnn/QNNUtils.cpp 
b/mllm/backends/qnn/QNNUtils.cpp index 03d752b08..13c577931 100644 --- a/mllm/backends/qnn/QNNUtils.cpp +++ b/mllm/backends/qnn/QNNUtils.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace mllm::qnn { @@ -360,6 +361,7 @@ std::shared_ptr QNNTensorWrapper::createStaticTensor(const std std::shared_ptr tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, tensor, quantize); tensorWrapper->isAlloc_ = true; + tensorWrapper->registeredPtr_ = tensor.ptr(); uint32_t numElement = tensor.bytes(); Qnn_ClientBuffer_t clientBuffer = {.data = tensor.ptr(), .dataSize = numElement}; @@ -369,22 +371,73 @@ std::shared_ptr QNNTensorWrapper::createStaticTensor(const std } void QNNTensorWrapper::alloc() { - if (isAlloc_) { - MLLM_WARN("Tensor {} has already been allocated.", name_); - return; - } MLLM_RT_ASSERT(dataContainer_.device() == kQNN); - // if storage is not allocated, allocate it - // or, register the existing storage to QNN(passing allocated input to QNN) - if (!dataContainer_.impl()->ptr()) { dataContainer_.alloc(); } + void* currentPtr = dataContainer_.impl()->ptr(); + if (!currentPtr) { + dataContainer_.alloc(); + currentPtr = dataContainer_.ptr(); + } + + auto allocator = std::static_pointer_cast(Context::instance().getBackend(kQNN)->allocator()); + + auto storage = dataContainer_.impl()->storage(); + MLLM_RT_ASSERT(storage != nullptr); + + size_t requiredBytes = dataContainer_.bytes(); + + // Check if we have a previously registered buffer pointer + // This handles the case where tensor dimensions change (e.g., in decode phase) + // and the existing registered buffer is too small + if (registeredPtr_) { + // Verify that the registered buffer is still valid + if (!allocator->isRegistered(registeredPtr_)) { + // Buffer was de-registered, clear the reference + registeredPtr_ = nullptr; + isAlloc_ = false; + } else { + // Check if the registered buffer is large enough for current requirements + // If not, we need to de-register it and allocate a new one + size_t registeredBytes = allocator->getRegisteredBufferSize(registeredPtr_); + if (registeredBytes > 0 && registeredBytes < requiredBytes) { + // Registered buffer is too small, de-register it + // A new buffer will be allocated and registered below + allocator->deRegisterQnnTensorFromSharedBuffer(registeredPtr_); + registeredPtr_ = nullptr; + isAlloc_ = false; + } + } + } + + if (registeredPtr_ && registeredPtr_ != storage->ptr_) { + if (!allocator->isRegistered(registeredPtr_)) { + registeredPtr_ = nullptr; + } else { + void* freshPtr = storage->ptr_; + size_t bytesToCopy = dataContainer_.bytes(); + if (freshPtr && bytesToCopy > 0) { std::memcpy(registeredPtr_, freshPtr, bytesToCopy); } + if (freshPtr) { allocator->free(storage.get()); } + storage->ptr_ = registeredPtr_; + currentPtr = registeredPtr_; + } + } - std::static_pointer_cast(Context::instance().getBackend(kQNN)->allocator()) - ->registerQnnTensorToSharedBuffer(dataContainer_.ptr(), qnnTensor_); + if (isAlloc_ && registeredPtr_ == currentPtr) { return; } + if (!allocator->registerQnnTensorToSharedBuffer(storage.get(), qnnTensor_)) { + MLLM_ERROR("QNNTensorWrapper::alloc failed to register shared buffer for tensor {}", name_); + // Fail fast: prevent executing graph with invalid mem handle + MLLM_RT_ASSERT(false); + } + + registeredPtr_ = storage->ptr_; isAlloc_ = true; } +void QNNTensorWrapper::resetAlloc() { + isAlloc_ = false; +} + void QNNTensorWrapper::initFromQnnTensor(Qnn_Tensor_t* qnnTensor) { if (qnnTensor == nullptr) { 
MLLM_ERROR("QNNTensorWrapper::setQnnTensor() received nullptr"); @@ -503,4 +556,4 @@ void propagateQuantScale(const Tensor& input, Tensor& output) { } } -} // namespace mllm::qnn \ No newline at end of file +} // namespace mllm::qnn diff --git a/mllm/backends/qnn/QNNUtils.hpp b/mllm/backends/qnn/QNNUtils.hpp index b5d12cb10..5c0483dfb 100644 --- a/mllm/backends/qnn/QNNUtils.hpp +++ b/mllm/backends/qnn/QNNUtils.hpp @@ -203,10 +203,12 @@ class QNNTensorWrapper { [[nodiscard]] const Qnn_Tensor_t* getNativeTensor() const { return &qnnTensor_; } // Get tensor name - const std::string& getName() const { return name_; } + [[nodiscard]] const std::string& getName() const { return name_; } // alloc graph input/output tensor memory in QNN shared buffer void alloc(); + // reset allocation flag when dataContainer is updated + void resetAlloc(); Tensor& getDataContainer() { return dataContainer_; } const std::vector* getDimension() { return &dimensions_; } @@ -216,6 +218,7 @@ class QNNTensorWrapper { Tensor dataContainer_; Qnn_Tensor_t qnnTensor_; bool isAlloc_ = false; + void* registeredPtr_ = nullptr; }; class QNNParamTensorWrapper { diff --git a/mllm/core/aops/KVCacheOp.hpp b/mllm/core/aops/KVCacheOp.hpp index 695b4142f..c4d172b1d 100644 --- a/mllm/core/aops/KVCacheOp.hpp +++ b/mllm/core/aops/KVCacheOp.hpp @@ -34,10 +34,18 @@ class KVCacheOp : public BaseOp { virtual void clearCache(); + // Set current valid sequence length for KV cache logic + // Default no-op; backends that maintain cache should override. + virtual void setCurrentSeqCnt(int32_t /*seq*/) {} + + // Get current valid sequence length for KV cache logic + // Default returns -1; backends that maintain cache should override. + virtual int32_t getCurrentSeqCnt() const { return -1; } + inline const KVCacheOpOptions& options() const { return options_; } protected: KVCacheOpOptions options_; }; -} // namespace mllm::aops \ No newline at end of file +} // namespace mllm::aops diff --git a/mllm/models/qwen_npu/modeling_qwen_npu.hpp b/mllm/models/qwen_npu/modeling_qwen_npu.hpp index 355137a34..84b7d3627 100644 --- a/mllm/models/qwen_npu/modeling_qwen_npu.hpp +++ b/mllm/models/qwen_npu/modeling_qwen_npu.hpp @@ -268,6 +268,7 @@ class QwenAttentionMatmul final : public nn::Module { } nn::KVCache& getKVCache() { return kv_cache_; } + [[nodiscard]] const nn::KVCache& getKVCache() const { return kv_cache_; } }; class QwenOutProjAndMLP final : public nn::Module { @@ -398,6 +399,7 @@ class QwenDecoder final : public nn::Module { } nn::KVCache& getKVCache() { return self_attn_matmul_.getKVCache(); } + [[nodiscard]] const nn::KVCache& getKVCache() const { return self_attn_matmul_.getKVCache(); } }; class QwenText final : public nn::Module { @@ -440,6 +442,17 @@ class QwenText final : public nn::Module { void clearKVCache() { for (auto& block : decode_blocks_.list()) { block.getKVCache().clearCache(); } } + + void setKVCacheSeqCnt(int32_t seq) { + for (auto& block : decode_blocks_.list()) { block.getKVCache().setCurrentSeqCnt(seq); } + } + + [[nodiscard]] int32_t getKVCacheSeqCnt(int32_t layer_idx = 0) const { + if (layer_idx < 0 || layer_idx >= static_cast(decode_blocks_.list().size())) { + return -1; + } + return decode_blocks_.list()[layer_idx].getKVCache().getCurrentSeqCnt(); + } }; class QwenForCausalLM : public nn::Module, public ARGeneration { @@ -453,6 +466,12 @@ class QwenForCausalLM : public nn::Module, public ARGeneration { tie_word_embeddings_ = cfg.tie_word_embeddings; } + // Set current valid sequence length for KV cache across all layers + 
diff --git a/mllm/models/qwen_npu/modeling_qwen_npu.hpp b/mllm/models/qwen_npu/modeling_qwen_npu.hpp
index 355137a34..84b7d3627 100644
--- a/mllm/models/qwen_npu/modeling_qwen_npu.hpp
+++ b/mllm/models/qwen_npu/modeling_qwen_npu.hpp
@@ -268,6 +268,7 @@ class QwenAttentionMatmul final : public nn::Module {
   }
 
   nn::KVCache& getKVCache() { return kv_cache_; }
+  [[nodiscard]] const nn::KVCache& getKVCache() const { return kv_cache_; }
 };
 
 class QwenOutProjAndMLP final : public nn::Module {
@@ -398,6 +399,7 @@ class QwenDecoder final : public nn::Module {
   }
 
   nn::KVCache& getKVCache() { return self_attn_matmul_.getKVCache(); }
+  [[nodiscard]] const nn::KVCache& getKVCache() const { return self_attn_matmul_.getKVCache(); }
 };
 
 class QwenText final : public nn::Module {
@@ -440,6 +442,17 @@ class QwenText final : public nn::Module {
   void clearKVCache() {
     for (auto& block : decode_blocks_.list()) { block.getKVCache().clearCache(); }
   }
+
+  void setKVCacheSeqCnt(int32_t seq) {
+    for (auto& block : decode_blocks_.list()) { block.getKVCache().setCurrentSeqCnt(seq); }
+  }
+
+  [[nodiscard]] int32_t getKVCacheSeqCnt(int32_t layer_idx = 0) const {
+    if (layer_idx < 0 || layer_idx >= static_cast<int32_t>(decode_blocks_.list().size())) {
+      return -1;
+    }
+    return decode_blocks_.list()[layer_idx].getKVCache().getCurrentSeqCnt();
+  }
 };
 
 class QwenForCausalLM : public nn::Module, public ARGeneration {
@@ -453,6 +466,12 @@ class QwenForCausalLM : public nn::Module, public ARGeneration {
     tie_word_embeddings_ = cfg.tie_word_embeddings;
   }
 
+  // Set current valid sequence length for KV cache across all layers
+  void setKVCacheSeqCnt(int32_t seq) { model.setKVCacheSeqCnt(seq); }
+
+  // Get current valid sequence length for KV cache from specified layer
+  [[nodiscard]] int32_t getKVCacheSeqCnt(int32_t layer_idx = 0) const { return model.getKVCacheSeqCnt(layer_idx); }
+
   ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override {
     auto sequence = input.at("sequence");
diff --git a/mllm/nn/Module.hpp b/mllm/nn/Module.hpp
index cc16c7f89..ff72e93ae 100644
--- a/mllm/nn/Module.hpp
+++ b/mllm/nn/Module.hpp
@@ -208,6 +208,7 @@ class ModuleList final : public Module {
   }
 
   std::vector& list() { return layers_; }
+  const std::vector& list() const { return layers_; }
 };
 
 template
diff --git a/mllm/nn/layers/KVCache.cpp b/mllm/nn/layers/KVCache.cpp
index 6ab0c2504..4b6bc70f9 100644
--- a/mllm/nn/layers/KVCache.cpp
+++ b/mllm/nn/layers/KVCache.cpp
@@ -23,4 +23,12 @@ void KVCache::setLayerIndex(int32_t layer_idx) {
 
 void KVCache::clearCache() { std::static_pointer_cast<aops::KVCacheOp>(impl()->getInstancedOp())->clearCache(); }
 
+void KVCache::setCurrentSeqCnt(int32_t seq) {
+  std::static_pointer_cast<aops::KVCacheOp>(impl()->getInstancedOp())->setCurrentSeqCnt(seq);
+}
+
+int32_t KVCache::getCurrentSeqCnt() const {
+  return std::static_pointer_cast<aops::KVCacheOp>(impl()->getInstancedOp())->getCurrentSeqCnt();
+}
+
 } // namespace mllm::nn
diff --git a/mllm/nn/layers/KVCache.hpp b/mllm/nn/layers/KVCache.hpp
index 1cea392bd..55f194544 100644
--- a/mllm/nn/layers/KVCache.hpp
+++ b/mllm/nn/layers/KVCache.hpp
@@ -20,6 +20,12 @@ class KVCache : public Layer {
 
   void clearCache();
 
+  // Update current valid sequence length in underlying KV cache op
+  void setCurrentSeqCnt(int32_t seq);
+
+  // Get current valid sequence length from underlying KV cache op
+  int32_t getCurrentSeqCnt() const;
+
   MLLM_LAYER_ANY_INPUTS_2_OUTPUTS_FORWARD
 };
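Taken together, the patch lets the example in examples/qwen_npu/main.cpp run a fixed 128-token QNN graph over an arbitrarily long prompt: each window copies its slice of prompt tokens, pads the rest with -1, feeds absolute position_ids, and rewinds the KV cache to the window's absolute start via setKVCacheSeqCnt(chunk_start) before every forward call. A small standalone sketch of just that bookkeeping (standard C++ only; the mllm model, tokenizer, and sampling calls are omitted, and prompt_tokens is a hypothetical value):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int chunk_size = 128;     // fixed graph input length, as in the example
  const int prompt_tokens = 300;  // hypothetical prompt length
  const int prompt_chunks = (prompt_tokens + chunk_size - 1) / chunk_size;

  for (int chunk_index = 0; chunk_index < prompt_chunks; ++chunk_index) {
    const int chunk_start = chunk_index * chunk_size;  // absolute offset of this window
    const int chunk_prompt_len = std::min(chunk_size, prompt_tokens - chunk_start);

    // Positions are absolute, so RoPE and the KV cache stay consistent across chunks.
    std::vector<int64_t> position_ids(chunk_size);
    for (int i = 0; i < chunk_size; ++i) { position_ids[i] = chunk_start + i; }

    // Unused slots in the fixed-size window are padded with -1, as in the example.
    std::vector<int64_t> sequence(chunk_size, -1);
    // ... copy prompt tokens [chunk_start, chunk_start + chunk_prompt_len) into `sequence` here ...

    // The KV cache write offset is reset to the window start before each forward call
    // (model.setKVCacheSeqCnt(chunk_start) in the example); "seq_len" is the in-window length.
    std::printf("chunk %d: start=%d, prompt_len=%d, absolute_end=%d\n",
                chunk_index, chunk_start, chunk_prompt_len, chunk_start + chunk_prompt_len);
  }
  return 0;
}

Keeping position_ids absolute while "seq_len" stays window-relative is what lets the cache writes line up across chunks, which is the invariant the setKVCacheSeqCnt(chunk_start) calls in the example maintain.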