diff --git a/.gitignore b/.gitignore index 97847b5ed..82d8c6e40 100644 --- a/.gitignore +++ b/.gitignore @@ -36,5 +36,4 @@ examples/demo_deepseek.cpp src/models/deepseek/* examples/demo.cpp -src/backends/qnn/sdk/* -*.mllm +src/backends/qnn/sdk/* \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index a85d86510..0035c3590 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,12 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES endif () if (ARM) - set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin-arm) +if(QNN) +set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin-arm-qnn) +else() +set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin-arm) +endif() + add_compile_definitions(__ARM_FEATURE_DOTPROD) # 检查是否使用的是 GCC 或 Clang 编译器 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") diff --git a/README.md b/README.md index 142d06958..59b5645e2 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,7 @@ mllm is a lightweight, fast, and easy-to-use (multimodal) on-device LLM inferenc ```bash git clone https://github.com/UbiquitousLearning/mllm cd mllm +git submodule update --init --recursive ``` ### Check prerequisites diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f82cef7cd..fdccfe492 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -81,9 +81,6 @@ func_llm_add_executable(demo_phonelm) func_llm_add_executable(demo_llama3) func_llm_add_executable(demo_minicpm_moe_mbm) func_llm_add_executable(demo_minicpm_moe_mbp) - - - func_vlm_add_executable(demo_llava) func_vlm_add_executable(demo_fuyu) func_vlm_add_executable(demo_vit) @@ -99,11 +96,12 @@ func_vlm_add_executable(demo_showui) if(QNN) func_llm_add_executable(demo_qwen_npu) - func_llm_add_executable(main_qwen_npu) + # func_llm_add_executable(main_qwen_npu) func_llm_add_executable(demo_phonelm_npu) - func_llm_add_executable(main_phonelm_npu) + # func_llm_add_executable(main_phonelm_npu) func_llm_add_executable(demo_qwen2.5_npu) - func_llm_add_executable(demo_qwen_pipeline) + # func_llm_add_executable(demo_qwen_pipeline) + func_vlm_add_executable(demo_qwen2_vl_npu) endif() diff --git a/examples/demo_phonelm_npu.cpp b/examples/demo_phonelm_npu.cpp index 7d269eb94..608ffab53 100644 --- a/examples/demo_phonelm_npu.cpp +++ b/examples/demo_phonelm_npu.cpp @@ -1,4 +1,5 @@ #include "Module.hpp" +#include "QNNBackend.hpp" #include "Types.hpp" #include #include "backends/cpu/CPUBackend.hpp" @@ -13,8 +14,10 @@ int main(int argc, char **argv) { cmdline::parser cmdParser; cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/phonelm_vocab.mllm"); cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/phonelm_merges.txt"); + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/phonelm-1.5b-instruct-int8.mllm"); cmdParser.add("decoding", 'd', "specify mllm decoding model path", false, "../models/phonelm-1.5b-instruct-q4_0_4_4.mllm"); + cmdParser.add("limits", 'l', "max KV cache size", false, 400); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.add("chunk", 'c', "chunk size", false, 64); @@ -28,6 +31,8 @@ int main(int argc, char **argv) { int chunk_size = cmdParser.get("chunk"); CPUBackend::cpu_threads = cmdParser.get("thread"); + Module::initBackend(MLLM_QNN); + auto tokenizer = SmolLMTokenizer(vocab_path, merge_path); PhoneLMConfig config(tokens_limit, "1.5B"); auto model = PhoneLMForCausalLM_NPU(config, chunk_size); @@ -57,8 +62,13 @@ int main(int argc, char **argv) { static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); // turn on the multi-chunk prefilling Module::isMultiChunkPrefilling = true; + // warmup END std::cout << "Warmup finished." << std::endl; + if (!std::filesystem::exists("qnn_context.bin")) { + static_cast(Backend::global_backends[MLLM_QNN])->saveQNNContext(); + } + vector in_strs = { "Give me a short introduction to large language model.", diff --git a/examples/demo_qwen.cpp b/examples/demo_qwen.cpp index 1c70d52ce..3ef6da59d 100644 --- a/examples/demo_qwen.cpp +++ b/examples/demo_qwen.cpp @@ -39,9 +39,7 @@ int main(int argc, char **argv) { model.load(model_path); vector in_strs = { - "Hello, who are you?", - "What can you do?", - "Please introduce Beijing University of Posts and Telecommunications.", + " Give me a short introduction to large language model.", }; for (int i = 0; i < in_strs.size(); ++i) { auto input_str = tokenizer.apply_chat_template(in_strs[i]); @@ -50,8 +48,8 @@ int main(int argc, char **argv) { std::cout << "[A] " << std::flush; LlmTextGeneratorOpts opt{ - .max_new_tokens = 100, - .do_sample = true, + .max_new_tokens = 1, + .do_sample = false, .temperature = 0.3F, .top_k = 50, .top_p = 0.F, diff --git a/examples/demo_qwen2.5_npu.cpp b/examples/demo_qwen2.5_npu.cpp index 761a34926..b86bef3b9 100644 --- a/examples/demo_qwen2.5_npu.cpp +++ b/examples/demo_qwen2.5_npu.cpp @@ -1,3 +1,5 @@ +#include "QNNBackend.hpp" +#include "Types.hpp" #include "backends/cpu/CPUBackend.hpp" #include "cmdline.h" #include "models/qwen/configuration_qwen.hpp" @@ -12,8 +14,8 @@ int main(int argc, char **argv) { cmdline::parser cmdParser; cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen2.5_vocab.mllm"); cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen2.5_merges.txt"); - cmdParser.add("model", 'm', "specify mllm model path", false, "../models/Qwen2.5-1.5B-Instruct.mllm"); - cmdParser.add("billion", 'b', "[0.5B | 1.8B | 1.5B]", false, "1.8B"); + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/qwen-2-int8-int32bias-0kproj-test.mllm"); + cmdParser.add("billion", 'b', "[0.5B | 1.8B | 1.5B]", false, "1.5B"); cmdParser.add("limits", 'l', "max KV cache size", false, 400); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.parse_check(argc, argv); @@ -26,11 +28,14 @@ int main(int argc, char **argv) { CPUBackend::cpu_threads = cmdParser.get("thread"); auto tokenizer = QWenTokenizer(vocab_path, merge_path); - QWenConfig config(tokens_limit, "1.5B", RoPEType::HFHUBROPE); - auto model = QWenForCausalLM_NPU(config, 64); + QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE); + auto model = v2::QWenForCausalLM_NPU(config, 32); + + Module::initBackend(MLLM_QNN); + model.load(model_path); auto decoding_model = QWenForCausalLM(config); - decoding_model.load("../models/qwen-2.5-1.5b-instruct-q4_0_4_4.mllm"); + decoding_model.load("../models/qwen-2.5-1.5b-instruct-q4_k.mllm"); vector in_strs = { " Give me a short introduction to large language model.", @@ -38,7 +43,7 @@ int main(int argc, char **argv) { for (int i = 0; i < in_strs.size(); ++i) { auto input_str = tokenizer.apply_chat_template(in_strs[i]); - auto [real_seq_length, input_tensor] = tokenizer.tokenizeWithPadding(input_str, 64, config.vocab_size); + auto [real_seq_length, input_tensor] = tokenizer.tokenizeWithPadding(input_str, 32, config.vocab_size); std::cout << "[Q] " << in_strs[i] << std::endl; std::cout << "[A] " << std::flush; @@ -48,9 +53,7 @@ int main(int argc, char **argv) { LlmTextGeneratorOpts opt{ .max_new_tokens = 1, .do_sample = false, - .temperature = 0.3f, - .top_k = 50, - .top_p = 0.f, + .is_padding = true, .seq_before_padding = real_seq_length, }; @@ -67,7 +70,7 @@ int main(int argc, char **argv) { static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); LlmTextGeneratorOpts decoding_opt{ - .max_new_tokens = 100, + .max_new_tokens = 50, .do_sample = false, .temperature = 0.3f, .top_k = 50, @@ -79,6 +82,7 @@ int main(int argc, char **argv) { // call only once of switchDecodeTag if (!isSwitched) { static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + isSwitched = true; } auto out_string = tokenizer.detokenize({out_token}); @@ -96,5 +100,9 @@ int main(int argc, char **argv) { static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); std::cout << "\n"; + + if (!std::filesystem::exists("qnn_context.bin")) { + static_cast(Backend::global_backends[MLLM_QNN])->saveQNNContext(); + } } -} \ No newline at end of file +} diff --git a/examples/demo_qwen2.5_pipeline.cpp b/examples/demo_qwen2.5_pipeline.cpp new file mode 100644 index 000000000..5bc99cfcc --- /dev/null +++ b/examples/demo_qwen2.5_pipeline.cpp @@ -0,0 +1,124 @@ +#include "Backend.hpp" +#include "Trace.hpp" +#include "Types.hpp" +#include "backends/cpu/CPUBackend.hpp" +#include "cmdline.h" +#include "models/qwen/configuration_qwen.hpp" +#include "models/qwen/modeling_qwen_npu.hpp" +#include "models/qwen/modeling_qwen.hpp" +#include "models/qwen/tokenization_qwen.hpp" +#include "processor/PostProcess.hpp" +#include "Parallel.hpp" + +using namespace mllm; + +int main(int argc, char **argv) { + cmdline::parser cmdParser; + cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen2.5_vocab.mllm"); + cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen2.5_merges.txt"); + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/qwen-2-int8-test.mllm"); + cmdParser.add("billion", 'b', "[0.5B | 1.8B | 1.5B]", false, "1.5B"); + cmdParser.add("limits", 'l', "max KV cache size", false, 400); + cmdParser.add("thread", 't', "num of threads", false, 4); + cmdParser.parse_check(argc, argv); + + string vocab_path = cmdParser.get("vocab"); + string merge_path = cmdParser.get("merge"); + string model_path = cmdParser.get("model"); + string model_billion = cmdParser.get("billion"); + int tokens_limit = cmdParser.get("limits"); + const int chunk_size = 128; + CPUBackend::cpu_threads = cmdParser.get("thread"); + + Module::initBackend(MLLM_QNN); + + auto tokenizer = QWenTokenizer(vocab_path, merge_path); + QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE); + auto model = v2::QWenForCausalLM_NPU(config, chunk_size); + model.load(model_path); + auto decoding_model = QWenForCausalLM(config); + decoding_model.load("../models/qwen-2.5-1.5b-instruct-q4_0_4_4.mllm"); + + string trace_string = " "; + auto [_, input_tensor] = tokenizer.tokenizePaddingByChunk(trace_string, chunk_size, config.vocab_size); + Tracer::trace(&model, {input_tensor}); + std::cout << "Trace and Warmup finished" << std::endl; + + vector in_strs = { + // " Give me a short introduction to large language model.", + "\"Large Language Models (LLMs) are advanced artificial intelligence systems designed to understand and generate human-like text. These models are trained on vast amounts of data, enabling them to perform a wide range of tasks, from answering questions and summarizing text to generating creative content and engaging in conversational dialogue. LLMs like GPT-3 and GPT-4, developed by OpenAI, have set new benchmarks in natural language processing by leveraging deep learning architectures, particularly transformer models, which excel at capturing context and relationships within text. The scalability and versatility of LLMs make them invaluable tools for applications in education, customer service, content creation, and more. However, their deployment also raises ethical considerations, including issues of bias, misinformation, and the potential for misuse. As the field continues to evolve, ongoing research and responsible deployment strategies are essential to harnessing the full potential of these powerful AI systems while mitigating their risks.\"\nGenerate a title based on the above text."}; + + for (int i = 0; i < in_strs.size(); ++i) { + auto input_str = tokenizer.apply_chat_template(in_strs[i]); + auto [real_seq_length, input_tensor] = tokenizer.tokenizePaddingByChunk(input_str, chunk_size, config.vocab_size); + + // set total seq length for HeadLinear execute, which can not get the real seq length from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length); + // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size); + + std::cout << "[Q] " << in_strs[i] << std::endl; + std::cout << "[A] " << std::flush; + std::cout << "real_seq_length: " << real_seq_length << std::endl; + + LlmTextGeneratorOpts opt{ + .max_new_tokens = 1, + .do_sample = false, + .is_padding = true, + .seq_before_padding = real_seq_length, + .chunk_size = chunk_size, + }; + + // tensor vectors to save the chunked tensors of the QNN prefilling input + bool isSwitched = false; + + ChunkPipeline pipeline(real_seq_length, chunk_size); + auto prefill_result = pipeline.run(input_tensor, opt, tokenizer, model, isSwitched); + + Module::isMultiChunkPrefilling = true; + Module::isFirstChunk = false; + + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + + LlmTextGeneratorOpts decoding_opt{ + .max_new_tokens = 100, + .do_sample = false, + .temperature = 0.3f, + .top_k = 50, + .top_p = 0.f, + .is_padding = false, + }; + isSwitched = false; + + Tensor decoding_input; + decoding_input.setBackend(Backend::global_backends[MLLM_CPU]); + decoding_input.setTtype(INPUT_TENSOR); + decoding_input.reshape(1, 1, 1, 1); + decoding_input.setName("input0"); + decoding_input.alloc(); + decoding_input.setDataAt(0, 0, 0, 0, prefill_result->dataAt(0, 0, 0, 0)); + decoding_model.generate(decoding_input, decoding_opt, [&](unsigned int out_token) -> bool { + // call only once of switchDecodeTag + if (!isSwitched) { + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + isSwitched = true; + } + auto out_string = tokenizer.detokenize({out_token}); + auto [isOk, print_string] = tokenizer.postprocess(out_string); + if (isOk) { + std::cout << print_string << std::flush; + } else { + return false; + } + return true; + }); + + // turn on switching, set sequence length and execution type + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + std::cout << "\n"; + } +} \ No newline at end of file diff --git a/examples/demo_qwen2_vl.cpp b/examples/demo_qwen2_vl.cpp index 3a23c982a..d28139e5c 100644 --- a/examples/demo_qwen2_vl.cpp +++ b/examples/demo_qwen2_vl.cpp @@ -2,6 +2,7 @@ #include "cmdline.h" #include "models/qwen2_vl/configuration_qwen2_vl.hpp" #include "models/qwen2_vl/modeling_qwen2_vl.hpp" +// #include "models/qwen2_vl/vtp/modeling_qwen2_vl.hpp" #include "models/qwen2_vl/processing_qwen2_vl.hpp" #include "processor/PostProcess.hpp" @@ -11,7 +12,7 @@ int main(int argc, char **argv) { cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen2vl_vocab.mllm"); cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen2vl_merges.txt"); cmdParser.add("model", 'm', "specify mllm model path", false, "../models/qwen-2-vl-2b-instruct-q4_k.mllm"); - cmdParser.add("limits", 'l', "max KV cache size", false, 2000); + cmdParser.add("limits", 'l', "max KV cache size", false, 800); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.parse_check(argc, argv); @@ -25,8 +26,7 @@ int main(int argc, char **argv) { ParamLoader param_loader(model_path); auto processor = Qwen2VLProcessor(vocab_path, merge_path); Qwen2VLConfig config(tokens_limit, "1.5b"); - auto model_config = Qwen2VLConfig(config); - auto model = Qwen2VLModel(model_config); + auto model = Qwen2VLModel(config); model.load(model_path); vector in_imgs = { diff --git a/examples/demo_qwen2_vl_npu.cpp b/examples/demo_qwen2_vl_npu.cpp new file mode 100644 index 000000000..15e7723b5 --- /dev/null +++ b/examples/demo_qwen2_vl_npu.cpp @@ -0,0 +1,185 @@ +#include +#include +#include +#include "QNNBackend.hpp" +#include "Timing.hpp" +#include "Types.hpp" +#include "cmdline.h" +#include "models/qwen2_vl/configuration_qwen2_vl.hpp" +#include "models/qwen2_vl/modeling_qwen2_vl_npu.hpp" +#include "models/qwen2_vl/processing_qwen2_vl.hpp" +#include "processor/PostProcess.hpp" +#include "memory/MemInspect.hpp" + +using namespace mllm; +int main(int argc, char **argv) { + cmdline::parser cmdParser; + cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/showui_vocab.mllm"); + cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/showui_merges.txt"); + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/showui-w8-fpbias-noshadow-xdl-test.mllm"); + cmdParser.add("limits", 'l', "max KV cache size", false, 1000); + cmdParser.add("thread", 't', "num of threads", false, 4); + cmdParser.parse_check(argc, argv); + + string vocab_path = cmdParser.get("vocab"); + string merge_path = cmdParser.get("merge"); + string model_path = cmdParser.get("model"); + const string cpu_model_path = "../models/showui-2B-rotated-q40.mllm"; + int tokens_limit = cmdParser.get("limits"); + int thread_num = cmdParser.get("thread"); + CPUBackend::cpu_threads = cmdParser.get("thread"); + + // TODO: add a function to calculate the chunk size + const int chunk_size = 256; + + Module::initBackend(MLLM_QNN); + + ParamLoader param_loader(model_path); + auto processor = Qwen2VLProcessor(vocab_path, merge_path); + Qwen2VLConfig config(tokens_limit, "1.5b-rotated"); + auto model_config = Qwen2VLConfig(config); + model_config.attn_implementation = "eager"; + + auto prefill_embedding = Qwen2VL_ImagePatchAndEmbedding(config); + auto prefill_body = Qwen2VL_PrefillBody(config, chunk_size); + prefill_embedding.load(cpu_model_path); + prefill_body.load(model_path); + + auto decoding_model = Qwen2VL_Decoding_Model(model_config); + decoding_model.load(cpu_model_path); + + vector in_imgs = { + "../assets/showui.png"}; + vector in_strs = { + "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element, which is a relative coordinate on the screenshot, scaled from 0 to 1.<|vision_start|><|image_pad|><|vision_end|>桌面", + }; + + auto &in_str = in_strs[0]; + in_str = processor.tokenizer->apply_chat_template(in_str); + auto input_tensors = processor.process(in_str, in_imgs[0]); + + const int real_seq_length = input_tensors[0].sequence(); + std::cout << "real seq length: " << real_seq_length << std::endl; + + const int num_iter = (real_seq_length + chunk_size - 1) / chunk_size; + std::cout << "num_iter" << num_iter << std::endl; + // padding the position_ids to total chunk length(example: 256*2) for CPUMultimodalRoPEPipeline + prefill_embedding.get_position_ids(input_tensors, chunk_size * num_iter); + + // warm up (still need a warm up as the setup stage is not omitted now) + auto merged_embd_warmup_tensor = Tensor(Backend::global_backends[MLLM_QNN]); + merged_embd_warmup_tensor.reshape(1, 1, chunk_size, 1536); + merged_embd_warmup_tensor.setTtype(INPUT_TENSOR); + merged_embd_warmup_tensor.alloc(); + + merged_embd_warmup_tensor.setTtype(INPUT_TENSOR); + input_tensors.back().setTtype(INPUT_TENSOR); + vector prefill_input = {merged_embd_warmup_tensor, input_tensors.back()}; + + auto warm_start = mllm_time_ms(); + prefill_body(prefill_input); + auto warm_end = mllm_time_ms(); + std::cout << "warm up " << warm_end - warm_start << " ms" << std::endl; + + Module::isFirstChunk = false; + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + + // set total seq length for HeadLinear execute, which can not get the real seq length from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length); + // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size); + + for (auto &t : input_tensors) { + t.setTtype(INPUT_TENSOR); + } + + // 1. get the vit embedding using CPU + auto vit_start = mllm_time_ms(); + auto merged_embd = prefill_embedding(input_tensors); + auto vit_end = mllm_time_ms(); + std::cout << "vit embedding: " << vit_end - vit_start << " ms" << std::endl; + + // free prefill embedding tensor, approximately free 1GB for 59ms + auto begin_free = mllm_time_ms(); + auto &embedding_act = prefill_embedding.activation_tensors; + // go through the activation tensors to get the merged_embd + for (auto iter = embedding_act.begin(); iter != embedding_act.end(); ++iter) { + // std::cout << iter->first << std::endl; + if (iter->first.find("input") != std::string::npos || iter->first.find("index_put") != std::string::npos) { + continue; + } + iter->second->free(); + } + auto end_free = mllm_time_ms(); + std::cout << "free time: " << end_free - begin_free << " ms" << std::endl; + + // 2. QNN LLM Prefill + unsigned int out_token = 0; + auto start_time = mllm_time_ms(); + for (auto i = 0; i < num_iter; ++i) { + // copy the data from merged_embd[0] to merged_embd_warmup_tensor + auto source = merged_embd[0].ptrAt(0, 0, chunk_size * i, 0); + auto dest = prefill_input[0].hostPtr(); + if (i == 0) { + memcpy(dest, source, prefill_input[0].cntSize()); + } + { + memcpy(dest, source, (merged_embd[0].sequence() % chunk_size) * merged_embd[0].dimension() * sizeof(float)); + } + + auto result = prefill_body(prefill_input); + + if (i == 0) { // turn off switching to avoid RoPE h_cnt_ reset to curSequenceLength in next chunk + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + } + + if (i == 1) { + auto end_time = mllm_time_ms(); + std::cout << "Prefill:" << end_time - start_time << " ms" << std::endl; + + auto outputs = processor.detokenize(result[0], real_seq_length % chunk_size); + auto out_string = outputs.first; + out_token = outputs.second; + auto [not_end, output_string] = processor.tokenizer->postprocess(out_string); + std::cout << output_string << std::flush; + } + } + + chatPostProcessing(out_token, input_tensors[0], {&input_tensors[1], &input_tensors[2]}); + + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + + // 3. CPU LLM Decoding + for (auto &t : input_tensors) { // set to INPUT_TENSOR to let decoding module update act + t.setTtype(INPUT_TENSOR); + } + + const int last_position_id = input_tensors[3].dataAt(0, 0, 0, real_seq_length - 1); + for (int step = 0; step < 100; step++) { + // use the last position id(no padding position) in decoding + prefill_embedding.get_position_ids(input_tensors, 0, last_position_id + 1 + step); + + auto result = decoding_model(input_tensors); + auto outputs = processor.detokenize(result[0]); + auto out_string = outputs.first; + auto out_token = outputs.second; + auto [not_end, output_string] = processor.tokenizer->postprocess(out_string); + if (!not_end) { break; } + std::cout << output_string << std::flush; + chatPostProcessing(out_token, input_tensors[0], {&input_tensors[1], &input_tensors[2]}); + + if (step == 0) static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + } + + std::cout << std::endl; + + if (!std::filesystem::exists("qnn_context.bin")) { + static_cast(Backend::global_backends[MLLM_QNN])->saveQNNContext(); + } + + return 0; +} \ No newline at end of file diff --git a/examples/demo_qwen_npu.cpp b/examples/demo_qwen_npu.cpp index 9e230f01c..4ac0ede23 100644 --- a/examples/demo_qwen_npu.cpp +++ b/examples/demo_qwen_npu.cpp @@ -1,8 +1,10 @@ -#include "backends/cpu/CPUBackend.hpp" +// #include "QNNBackend.hpp" +// #include "backends/cpu/CPUBackend.hpp" +#include "Backend.hpp" #include "cmdline.h" #include "models/qwen/configuration_qwen.hpp" -#include "models/qwen/modeling_qwen_npu.hpp" #include "models/qwen/modeling_qwen.hpp" +#include "models/qwen/modeling_qwen_npu.hpp" #include "models/qwen/tokenization_qwen.hpp" #include "processor/PostProcess.hpp" @@ -23,12 +25,14 @@ int main(int argc, char **argv) { string model_path = cmdParser.get("model"); string model_billion = cmdParser.get("billion"); int tokens_limit = cmdParser.get("limits"); - const int chunk_size = 128; + const int chunk_size = 32; CPUBackend::cpu_threads = cmdParser.get("thread"); + Module::initBackend(MLLM_QNN); + auto tokenizer = QWenTokenizer(vocab_path, merge_path); QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE); - auto model = QWenForCausalLM_NPU(config, chunk_size); + auto model = v2::QWenForCausalLM_NPU(config, chunk_size); model.load(model_path); auto decoding_model = QWenForCausalLM(config); decoding_model.load("../models/qwen-1.5-1.8b-chat-q4k.mllm"); @@ -58,9 +62,14 @@ int main(int argc, char **argv) { // warmup END std::cout << "Warmup finished." << std::endl; + // if (!std::filesystem::exists("qnn_context.bin")) { + // static_cast(Backend::global_backends[MLLM_QNN])->saveQNNContext(); + // } + vector in_strs = { - // " Give me a short introduction to large language model.", - "\"Large Language Models (LLMs) are advanced artificial intelligence systems designed to understand and generate human-like text. These models are trained on vast amounts of data, enabling them to perform a wide range of tasks, from answering questions and summarizing text to generating creative content and engaging in conversational dialogue. LLMs like GPT-3 and GPT-4, developed by OpenAI, have set new benchmarks in natural language processing by leveraging deep learning architectures, particularly transformer models, which excel at capturing context and relationships within text. The scalability and versatility of LLMs make them invaluable tools for applications in education, customer service, content creation, and more. However, their deployment also raises ethical considerations, including issues of bias, misinformation, and the potential for misuse. As the field continues to evolve, ongoing research and responsible deployment strategies are essential to harnessing the full potential of these powerful AI systems while mitigating their risks.\"\nGenerate a title based on the above text."}; + " Give me a short introduction to large language model.", + // "\"Large Language Models (LLMs) are advanced artificial intelligence systems designed to understand and generate human-like text. These models are trained on vast amounts of data, enabling them to perform a wide range of tasks, from answering questions and summarizing text to generating creative content and engaging in conversational dialogue. LLMs like GPT-3 and GPT-4, developed by OpenAI, have set new benchmarks in natural language processing by leveraging deep learning architectures, particularly transformer models, which excel at capturing context and relationships within text. The scalability and versatility of LLMs make them invaluable tools for applications in education, customer service, content creation, and more. However, their deployment also raises ethical considerations, including issues of bias, misinformation, and the potential for misuse. As the field continues to evolve, ongoing research and responsible deployment strategies are essential to harnessing the full potential of these powerful AI systems while mitigating their risks.\"\nGenerate a title based on the above text." + }; for (int i = 0; i < in_strs.size(); ++i) { auto input_str = tokenizer.apply_chat_template(in_strs[i]); @@ -114,6 +123,8 @@ int main(int argc, char **argv) { static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + exit(0); + LlmTextGeneratorOpts decoding_opt{ .max_new_tokens = 100, .do_sample = false, diff --git a/examples/demo_qwen_pipeline.cpp b/examples/demo_qwen_pipeline.cpp index f2f8bb8d0..1d632d7df 100644 --- a/examples/demo_qwen_pipeline.cpp +++ b/examples/demo_qwen_pipeline.cpp @@ -30,9 +30,11 @@ int main(int argc, char **argv) { const int chunk_size = 128; CPUBackend::cpu_threads = cmdParser.get("thread"); + Module::initBackend(MLLM_QNN); + auto tokenizer = QWenTokenizer(vocab_path, merge_path); QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE); - auto model = QWenForCausalLM_NPU(config, chunk_size); + auto model = v2::QWenForCausalLM_NPU(config, chunk_size); model.load(model_path); auto decoding_model = QWenForCausalLM(config); decoding_model.load("../models/qwen-1.5-1.8b-chat-q4k.mllm"); diff --git a/examples/demo_showui.cpp b/examples/demo_showui.cpp index 4349f9872..54bfd809a 100644 --- a/examples/demo_showui.cpp +++ b/examples/demo_showui.cpp @@ -29,8 +29,7 @@ int main(int argc, char **argv) { int max_pixels = 1344 * 28 * 28; auto processor = Qwen2VLProcessor(vocab_path, merge_path, min_pixels, max_pixels); Qwen2VLConfig config(tokens_limit, "1.5b"); - auto model_config = Qwen2VLConfig(config); - auto model = Qwen2VLModel(model_config); + auto model = Qwen2VLModel(config); model.load(model_path); vector in_imgs = { diff --git a/include/OpDefined.hpp b/include/OpDefined.hpp index ed8806350..9704cd0c6 100644 --- a/include/OpDefined.hpp +++ b/include/OpDefined.hpp @@ -38,6 +38,7 @@ enum OpType { CONVOLUTION2D, CONVOLUTION3D, VISIONROPE, + MULTIMODALROPEPIP, MULTIMODALROPE, AVGPOOL2D, MAXPOOL2D, @@ -110,6 +111,7 @@ static const vector OpNames = { "Convolution2D", "Convolution3D", "VisonRoPE", + "MultimodalRoPEPipeline", "MultimodalRoPE", "AvgPool2D", "MaxPool2D", @@ -181,6 +183,7 @@ enum TensorFuncType { FUNC_LIKE, FUNC_SCATTERREDUCE, FUNC_APPLY_VISIOROPE, + FUNC_FA2, // models use only FUNC_FUYU_GATHER_EMBD, FUNC_PHI3V_HD_MERGE, diff --git a/include/Types.hpp b/include/Types.hpp index 5fd850ce6..6d3bb4429 100644 --- a/include/Types.hpp +++ b/include/Types.hpp @@ -155,6 +155,7 @@ enum RoPEType { PERSIMMONROPE = 3, HFHUBROPE = 4, MLAROPE = 5, + NTKROPE = 6, }; enum RoPEThetaType { diff --git a/scripts/run_phonelm_qnn.sh b/scripts/run_phonelm_qnn.sh index 818fb2ec2..942c72d9a 100755 --- a/scripts/run_phonelm_qnn.sh +++ b/scripts/run_phonelm_qnn.sh @@ -46,5 +46,5 @@ if [ $? -ne 0 ]; then exit 1 fi -adb push ../bin-arm/demo_phonelm_npu /data/local/tmp/mllm/bin/ +adb push ../bin-arm-qnn/demo_phonelm_npu /data/local/tmp/mllm/bin/ adb shell "cd /data/local/tmp/mllm/bin && export LD_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && export ADSP_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && ./demo_phonelm_npu" \ No newline at end of file diff --git a/scripts/run_qwen2_vl_qnn.sh b/scripts/run_qwen2_vl_qnn.sh new file mode 100755 index 000000000..0af243172 --- /dev/null +++ b/scripts/run_qwen2_vl_qnn.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +adb shell mkdir -p /data/local/tmp/mllm/vocab +adb shell mkdir -p /data/local/tmp/mllm/qnn-lib + +adb push ../vocab/qwen_vocab.mllm /data/local/tmp/mllm/vocab/ + + +if ! adb shell [ -f "/data/local/tmp/mllm/models/showui-w8-fpbias-noshadow-xdl-test.mllm" ]; then + adb push ../models/showui-w8-fpbias-noshadow-xdl-test.mllm "/data/local/tmp/mllm/models/showui-w8-fpbias-noshadow-xdl-test.mllm" +else + echo "showui-w8-fpbias-noshadow-xdl-test file already exists" +fi + + +if ! adb shell [ -f "/data/local/tmp/mllm/models/showui-2B-rotated-q40.mllm" ]; then + adb push ../models/showui-2B-rotated-q40.mllm "/data/local/tmp/mllm/models/showui-2B-rotated-q40.mllm" +else + echo "showui-2B-rotated-q40.mllm file already exists" +fi + +if [ -z "$QNN_SDK_ROOT" ]; then + export QNN_SDK_ROOT=/root/research/dev/mllm/src/backends/qnn/sdk + # export HEXAGON_SDK_ROOT=/root/research/dev/mllm/src/backends/qnn/HexagonSDK/5.4.0 + echo "QNN_SDK_ROOT is set to $QNN_SDK_ROOT" + # exit 1 +else + echo "QNN_SDK_ROOT is set to $QNN_SDK_ROOT" +fi + +ANDR_LIB=$QNN_SDK_ROOT/lib/aarch64-android +OP_PATH=../src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/build +DEST=/data/local/tmp/mllm/qnn-lib + +adb push $ANDR_LIB/libQnnHtp.so $DEST +adb push $ANDR_LIB/libQnnHtpV75Stub.so $DEST +adb push $ANDR_LIB/libQnnHtpPrepare.so $DEST +adb push $ANDR_LIB/libQnnHtpProfilingReader.so $DEST +adb push $ANDR_LIB/libQnnHtpOptraceProfilingReader.so $DEST +adb push $ANDR_LIB/libQnnHtpV75CalculatorStub.so $DEST +adb push $QNN_SDK_ROOT/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so $DEST +adb push $OP_PATH/aarch64-android/libQnnLLaMAPackage.so $DEST/libQnnLLaMAPackage_CPU.so +adb push $OP_PATH/hexagon-v75/libQnnLLaMAPackage.so $DEST/libQnnLLaMAPackage_HTP.so + + +if [ $? -ne 0 ]; then + echo "adb push failed" + exit 1 +fi +# adb shell "rm /data/local/tmp/mllm/bin/qnn_context.bin" +adb push ../bin-arm-qnn/demo_qwen2_vl_npu /data/local/tmp/mllm/bin/ +adb shell "cd /data/local/tmp/mllm/bin && export LD_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && export ADSP_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && ./demo_qwen2_vl_npu" \ No newline at end of file diff --git a/scripts/run_qwen_qnn.sh b/scripts/run_qwen_qnn.sh index 81a572b2f..5c18b320c 100755 --- a/scripts/run_qwen_qnn.sh +++ b/scripts/run_qwen_qnn.sh @@ -47,5 +47,6 @@ if [ $? -ne 0 ]; then exit 1 fi -adb push ../bin-arm/demo_qwen_npu /data/local/tmp/mllm/bin/ +# adb shell "rm /data/local/tmp/mllm/bin/qnn_context.bin" +adb push ../bin-arm-qnn/demo_qwen_npu /data/local/tmp/mllm/bin/ adb shell "cd /data/local/tmp/mllm/bin && export LD_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && export ADSP_LIBRARY_PATH=/data/local/tmp/mllm/qnn-lib && ./demo_qwen_npu" \ No newline at end of file diff --git a/src/Layer.hpp b/src/Layer.hpp index 20b2b9ccf..f49eb1315 100644 --- a/src/Layer.hpp +++ b/src/Layer.hpp @@ -465,6 +465,15 @@ class KVCache final : public Layer { param_["for_xnn"] = false; init(std::move(name), OpType::KVCACHE); } + explicit KVCache(int head, int hidden, int n_rep, int cache_max, bool fa2, std::string name) { + param_["head"] = head; + param_["hidden"] = hidden; + param_["n_rep"] = n_rep; + param_["cache_max"] = cache_max; + param_["for_xnn"] = false; + param_["fa2"] = fa2; + init(std::move(name), OpType::KVCACHE); + } explicit KVCache(int cache_max, std::string name) { param_["n_rep"] = 1; @@ -494,11 +503,26 @@ class KVCache final : public Layer { init(std::move(name), OpType::KVCACHE); } } + explicit KVCache(int head, int hidden, int n_rep, int cache_max, std::string name, bool npuEnbaled) { + param_["head"] = head; + param_["hidden"] = hidden; + param_["n_rep"] = n_rep; + param_["cache_max"] = cache_max; + param_["for_xnn"] = false; + if (npuEnbaled) { + init(std::move(name), OpType::KVCACHENPU); + } else { + init(std::move(name), OpType::KVCACHE); + } + } Tensor operator()(Tensor input) { auto ts = run({input}, 1); return ts[0]; } int getCacheSeqLen() { + if (!op_) { + return -1; + } return op_->getCacheSeqLen(); } void clearCache() { @@ -616,7 +640,11 @@ class MultimodalRoPE final : public Layer { for (int i = 0; i < mrope_section.size(); i++) { param_["mrope_section_" + std::to_string(i)] = (float)mrope_section[i]; } - init(std::move(name), OpType::MULTIMODALROPE); + if (Backend::global_backends.size() == 2 && Backend::global_backends.find(MLLM_QNN) != Backend::global_backends.end()) { + init(std::move(name), OpType::MULTIMODALROPEPIP); + } else { + init(std::move(name), OpType::MULTIMODALROPE); + } } Tensor operator()(Tensor input, Tensor &position_ids) { auto ts = run({input, position_ids}, 1); @@ -658,8 +686,9 @@ class Position final : public Layer { class Quantize final : public Layer { public: - explicit Quantize(bool isNSHD, std::string name) { + explicit Quantize(bool isNSHD, std::string name, DataType type = MLLM_TYPE_I8) { param_["isNSHD"] = (float)isNSHD; + param_["dtype"] = (float)type; init(std::move(name), OpType::QUANTIZE); } Tensor operator()(Tensor input) { @@ -685,9 +714,10 @@ class Direct final : public Layer { class Dequantize final : public Layer { public: - explicit Dequantize(bool isNSHD, std::string name, bool isFP32 = true) { + explicit Dequantize(bool isNSHD, std::string name, bool isFP32 = true, DataType inType = MLLM_TYPE_I8) { param_["isNSHD"] = (float)isNSHD; param_["isFP32"] = (float)isFP32; + param_["inType"] = (float)inType; init(std::move(name), OpType::DEQUANTIZE); } Tensor operator()(Tensor input) { @@ -783,18 +813,6 @@ class View final : public Layer { } }; -class SubgraphStart final : public Layer { -public: - explicit SubgraphStart(const std::string &name) { - init(name, OpType::SUBGRAPHSTART); - } - - Tensor operator()(Tensor input) { - auto ts = run({input}, 1); - return ts[0]; - } -}; - class Transpose final : public Layer { public: explicit Transpose(std::vector perm, std::string name) { @@ -810,14 +828,30 @@ class Transpose final : public Layer { } }; +class SubgraphStart final : public Layer { +public: + SubgraphStart() = default; + explicit SubgraphStart(const std::string &name) { + init(name, OpType::SUBGRAPHSTART); + } + + Tensor operator()(vector inputs) { + Module::tmp_device = MLLM_QNN; + auto ts = run(inputs, 1); + return ts[0]; + } +}; + class SubgraphFinalize final : public Layer { public: + SubgraphFinalize() = default; explicit SubgraphFinalize(const std::string &name) { init(name, OpType::SUBGRAPHFINALIZE); } - Tensor operator()(Tensor input) { - auto ts = run({input}, 1); + Tensor operator()(vector inputs) { + auto ts = run(inputs, 1); + Module::tmp_device = MLLM_CPU; return ts[0]; } }; @@ -889,7 +923,6 @@ class NTKRoPE final : public Layer { return op_->clearCache(); } }; -// Only for QNN END } // namespace mllm diff --git a/src/Module.hpp b/src/Module.hpp index 93c1c8dab..d7d9fb9a3 100644 --- a/src/Module.hpp +++ b/src/Module.hpp @@ -130,6 +130,8 @@ class Module { } } } + + // TODO: Deprecated, the module is not backend specific, the backend should be set in the SubGraphStart and SubGraphFinalize void to(BackendType type) { initBackend(type); device_ = type; diff --git a/src/Parallel.hpp b/src/Parallel.hpp index 9603a6107..ea014b88b 100644 --- a/src/Parallel.hpp +++ b/src/Parallel.hpp @@ -22,7 +22,7 @@ class ChunkPipeline { chunk_num = seq_length_padding / chunk_size; } - shared_ptr run(Tensor &input_tensor, LlmTextGeneratorOpts &opt, Tokenizer &tokenizer, Module &model, bool &isSwitched) { + shared_ptr run(Tensor &input_tensor, LlmTextGeneratorOpts &opt, Tokenizer &tokenizer, Module &model, bool &isSwitched, const vector &clean_tensors = {}) { const int num_graph = Tracer::model_.size(); Tensor::tensor_status = TENSOR_STATIC_READY; std::cout << "num_graph: " << num_graph << std::endl; @@ -42,7 +42,7 @@ class ChunkPipeline { return; } // only the last chunk need to execute the last graph - if(i == num_graph - 1 && chunk_id != chunk_num - 1) { + if (i == num_graph - 1 && chunk_id != chunk_num - 1) { return; } // before the first graph, need to refresh the input tensor @@ -70,7 +70,7 @@ class ChunkPipeline { } } auto end_t = mllm_time_us(); - std::cout << "time: " << (end_t - start_t) / 1000.0F << "ms" << std::endl; + std::cout << "prefill time: " << (end_t - start_t) / 1000.0F << "ms" << std::endl; auto postProcessing = [&](shared_ptr result, shared_ptr &out_result, int real_seq_length) -> unsigned int { assert(result->batch() == 1); @@ -95,6 +95,12 @@ class ChunkPipeline { auto token_idx = postProcessing(result[0], chunked_tensors.back(), real_seq_length); auto out_string = tokenizer.detokenize({token_idx}); std::cout << out_string << std::flush; + + for (auto tensor : clean_tensors) { + tensor->reshape(0, 0, 0, 0); + tensor->alloc(); + } + return chunked_tensors.back(); } }; diff --git a/src/Tensor.cpp b/src/Tensor.cpp index d4104528c..40955d557 100644 --- a/src/Tensor.cpp +++ b/src/Tensor.cpp @@ -165,6 +165,9 @@ Tensor &Tensor::to(BackendType backend_type) { // realloc the tensor if (backend_type == MLLM_QNN && device() == MLLM_CPU) { this->free(); + module()->activation_tensors[name()]->setBackend(Backend::global_backends[backend_type]); + this->setBackend(Backend::global_backends[backend_type]); + return *this; } if (backend_type == MLLM_CPU && device() == MLLM_XNNPACK) { module()->activation_tensors[name()]->setBackend(Backend::global_backends[backend_type]); @@ -176,8 +179,7 @@ Tensor &Tensor::to(BackendType backend_type) { this->setBackend(Backend::global_backends[backend_type]); return *this; } - module()->activation_tensors[name()]->setBackend(Backend::global_backends[backend_type]); - this->alloc(); + return *this; }; @@ -194,236 +196,6 @@ std::vector Tensor::runFunc(std::vector out_names, return backend->runFunc(out_names, type, float_args, input_tensors, in_place); } -/* -Tensor &Tensor::getFunc(const std::string &suffix, const TensorFuncType type, - vector float_args, vector other_tensors) { - assert(module() != nullptr); - auto &module_tensors = module()->activation_tensors; - auto &activation_tensors_num = module()->activation_tensors_num; - const std::string next_name = impl_->name_ + "-" + suffix; - // if (module_tensors.find(name_) == module_tensors.end()) { - // module_tensors[name_] = std::shared_ptr(this, [](Tensor *) {}); - // } - if (module_tensors.find(next_name) == module_tensors.end()) { - module_tensors[next_name] = std::make_shared(impl_->backend_); - module_tensors[next_name]->setName(next_name); - module_tensors[next_name]->setModule(module()); - activation_tensors_num[next_name] = 0; - } - if (module()->doLoad) { return *module_tensors[next_name]; } - TensorFunction *func = impl_->backend_->funcCreate(type); - std::vector tensorPtrs = {module_tensors[impl_->name_].get()}; - for (auto &other_tensor : other_tensors) { tensorPtrs.push_back(other_tensor); } -#ifdef DEBUGOPTIME - auto start_t = mllm_time_us(); -#endif - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - func->setup({module_tensors[next_name].get()}, tensorPtrs, float_args); - break; - } - case TENSOR_STATIC_READY: { - func->execute({module_tensors[next_name].get()}, tensorPtrs, float_args); - break; - } - case TENSOR_STATIC_TRACE: { - if (impl_->backend_->type() == BackendType::MLLM_CPU) { - Tracer::addTensorFunction(func, tensorPtrs, {module_tensors[next_name].get()}, float_args); - } - break; - } - default: { - } - } - if (Backend::global_backends.size() == 1) { - for (auto input_tensor : tensorPtrs) { - if (activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end()) { - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - activation_tensors_num[input_tensor->name()] += 1; - break; - } - case TENSOR_STATIC_READY: { - activation_tensors_num[input_tensor->name()] -= 1; - break; - } - default: { - } - } - if (activation_tensors_num[input_tensor->name()] == 0 && module_tensors[input_tensor->name()]->sequence() > 1 - && module_tensors[input_tensor->name()]->ttype() != GRAPH_OUTPUT) { - module_tensors[input_tensor->name()]->free(); - // std::cout << input_tensor->name() << " |F" << std::endl; - } - } - } - } -#ifdef DEBUGOPTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY) { - auto end_t = mllm_time_us(); - std::cout << next_name << " | " << Tensor::tensor_status - << " time: " << (end_t - start_t) / 1000.0F << "ms" << std::endl; - } -#endif -#ifdef DEBUGSAVETENSOR - module_tensors[next_name]->saveNData(); -#endif - return *module_tensors[next_name]; -} - -void Tensor::getFunc(const TensorFuncType type, - vector float_args, vector other_tensors) { - assert(module() != nullptr); - auto &module_tensors = module()->activation_tensors; - auto &activation_tensors_num = module()->activation_tensors_num; - if (module()->doLoad) { return; } - TensorFunction *func = impl_->backend_->funcCreate(type); - std::vector tensorPtrs = {module_tensors[impl_->name_].get()}; - for (auto &other_tensor : other_tensors) { tensorPtrs.push_back(other_tensor); } -#ifdef DEBUGOPTIME - auto start_t = mllm_time_us(); -#endif - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - func->setup({}, tensorPtrs, float_args); - break; - } - case TENSOR_STATIC_READY: { - func->execute({}, tensorPtrs, float_args); - break; - } - default: { - } - } - if (Backend::global_backends.size() == 1) { - for (auto input_tensor : tensorPtrs) { - if (activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end() - // && input_tensor->dimension() * input_tensor->sequence() > 0 - ) { - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - activation_tensors_num[input_tensor->name()] += 1; - break; - } - case TENSOR_STATIC_READY: { - activation_tensors_num[input_tensor->name()] -= 1; - break; - } - default: { - } - } - if (activation_tensors_num[input_tensor->name()] == 0 && module_tensors[input_tensor->name()]->sequence() > 1 - && module_tensors[input_tensor->name()]->ttype() != GRAPH_OUTPUT) { - module_tensors[input_tensor->name()]->free(); - // std::cout << input_tensor->name() << " |F" << std::endl; - } - } - } - } -#ifdef DEBUGOPTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY) { - auto end_t = mllm_time_us(); - std::cout << " | " << Tensor::tensor_status - << " time: " << (end_t - start_t) / 1000.0F << "ms" << std::endl; - } -#endif -} - -std::vector> Tensor::getStaticFunc(vector out_names, - const TensorFuncType type, - vector float_args, - vector input_tensors) { - Module *module; - if (!input_tensors.empty()) { - module = input_tensors[0]->module(); - } else { - module = Module::llm_model_ptr; - } - assert(module != nullptr); - auto &module_tensors = module->activation_tensors; - auto &activation_tensors_num = module->activation_tensors_num; - auto *backend_h = Backend::global_backends[MLLM_CPU]; - if (!input_tensors.empty() && input_tensors[0]->impl_->backend_ != nullptr) { - backend_h = input_tensors[0]->backend(); - } - for (auto out_name : out_names) { - if (module_tensors.find(out_name) == module_tensors.end()) { - module_tensors[out_name] = std::make_shared(backend_h); - module_tensors[out_name]->setName(out_name); - module_tensors[out_name]->setModule(module); - activation_tensors_num[out_name] = 0; - } - } - if (module->doLoad) { - std::vector> results; - for (auto out_name : out_names) { results.push_back(*module_tensors[out_name]); } - return results; - } - TensorFunction *func = backend_h->funcCreate(type); - // std::vector tensorPtrs; - // for (auto input_tensor : input_tensors){ tensorPtrs.push_back(module_tensors[input_tensor->name()].get()); } - std::vector outPtrs; - for (auto out_name : out_names) { outPtrs.push_back(module_tensors[out_name].get()); } -#ifdef DEBUGOPTIME - auto start_t = mllm_time_us(); -#endif - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - func->setup(outPtrs, input_tensors, float_args); - break; - } - case TENSOR_STATIC_READY: { - func->execute(outPtrs, input_tensors, float_args); - break; - } - case TENSOR_STATIC_TRACE: { - if (backend_h->type() == BackendType::MLLM_CPU) { - Tracer::addTensorFunction(func, input_tensors, outPtrs, float_args); - } - break; - } - default: { - } - } - if (Backend::global_backends.size() == 1) { - for (auto input_tensor : input_tensors) { - if (activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end()) { - switch (Tensor::tensor_status) { - case TENSOR_STATIC_INIT: { - activation_tensors_num[input_tensor->name()] += 1; - break; - } - case TENSOR_STATIC_READY: { - activation_tensors_num[input_tensor->name()] -= 1; - break; - } - default: { - } - } - if (activation_tensors_num[input_tensor->name()] == 0 && module_tensors[input_tensor->name()]->sequence() > 1 - && module_tensors[input_tensor->name()]->ttype() != GRAPH_OUTPUT) { - module_tensors[input_tensor->name()]->free(); - // std::cout << input_tensor->name() << " |S "<< std::endl;// << out_names[0] << std::endl; - } - } - } - } -#ifdef DEBUGOPTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY) { - auto end_t = mllm_time_us(); - std::cout << out_names[0] << " | " << Tensor::tensor_status - << " time: " << (end_t - start_t) / 1000.0F << "ms" << std::endl; - } -#endif -#ifdef DEBUGSAVETENSOR - for (auto out_name : out_names) { module_tensors[out_name]->saveNData(); } -#endif - std::vector> results; - for (auto out_name : out_names) { results.push_back(*module_tensors[out_name]); } - return results; -} -*/ - Tensor Tensor::operator+(float data) { return runFunc({name() + "-add"}, FUNC_ADD, {data}, {std::shared_ptr(this, [](Tensor *) {})})[0]; @@ -537,6 +309,17 @@ Tensor Tensor::clip(Chl keep_axis, vector b, vector h, vector s, {std::shared_ptr(this, [](Tensor *) {})})[0]; } +Tensor Tensor::clip(vector index, Chl dim) { + Tensor index_tensor(1, 1, 1, index.size(), impl_->backend_, false); + index_tensor.alloc(); + for (size_t i = 0; i < index.size(); ++i) { + index_tensor.setDataAt(0, 0, 0, i, static_cast(index[i])); + } + index_tensor.setName(name() + "-cliptensor-index"); + return runFunc({name() + "-cliptensor"}, FUNC_CLIPTENSOR, {(float)dim}, + {std::shared_ptr(this, [](Tensor *) {}), + std::shared_ptr(&index_tensor, [](Tensor *) {})})[0]; +} Tensor Tensor::clip(Tensor index, Chl dim) { return runFunc({name() + "-cliptensor"}, FUNC_CLIPTENSOR, {(float)dim}, {std::shared_ptr(this, [](Tensor *) {}), @@ -638,6 +421,13 @@ Tensor Tensor::zero_like(Tensor input) { return runFunc({input.name() + "-zero_like"}, FUNC_LIKE, {0.0}, {std::shared_ptr(&input, [](Tensor *) {})})[0]; } +Tensor Tensor::flash_attention2_forward(Tensor q, Tensor k, Tensor v, bool causal_mask) { + Module *module = q.module(); + return runFunc({q.name() + "-" + k.name() + "-fa2"}, FUNC_FA2, {causal_mask ? 1.0f : 0.0f}, + {std::shared_ptr(&q, [](Tensor *) {}), + std::shared_ptr(&k, [](Tensor *) {}), + std::shared_ptr(&v, [](Tensor *) {})})[0]; +}; Tensor Tensor::apply_rotary_pos_emb_vision(Tensor input, Tensor rotary_pos_emb) { Module *module = input.module(); return runFunc({input.name() + "-apply_rotary_pos_emb"}, FUNC_APPLY_VISIOROPE, diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 429fedd7f..c21518d59 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -769,6 +769,7 @@ class Tensor { Tensor transpose(vector> axiss); Tensor clip(vector b, vector h, vector s, vector d); Tensor clip(Chl keep_axis, vector b, vector h, vector s, vector d); + Tensor clip(vector index, Chl dim); Tensor clip(Tensor index, Chl dim); Tensor expand(int b, int h, int s, int d); static Tensor cat(vector input_tensors, Chl dims); @@ -788,6 +789,7 @@ class Tensor { Tensor bincount(); Tensor repeat(Chl dim, int dim_size); static Tensor zero_like(Tensor input); + static Tensor flash_attention2_forward(Tensor q, Tensor k, Tensor v, bool is_causal = true); static Tensor apply_rotary_pos_emb_vision(Tensor input, Tensor rotary_pos_emb); // models use only @@ -1587,6 +1589,89 @@ class Tensor { outFile.close(); } + template + void saveIntData(string ex = "") { + if (Tensor::tensor_status != TENSOR_STATIC_READY) return; + if (ctype() == BTHWC || ctype() == BCTHW) { + save5Data(ex); + return; + } + // std::filesystem::create_directory("save_out"); + string directory = "save_out"; + struct stat info; +#ifdef _WIN32 + _mkdir(directory.c_str()); +#else + if (stat(directory.c_str(), &info) != 0) { + if (stat(directory.c_str(), &info) != 0) { + mkdir(directory.c_str(), 0777); // notice that 0777 is different than usual + } else if (!(info.st_mode & S_IFDIR)) { + // if the path exists but it is not a directory, also create it + mkdir(directory.c_str(), 0777); // notice that 0777 is different than usual + } + } +#endif + std::ofstream outFile(directory + "/" + name() + ex + ".log"); + outFile << "----------------------------------------" << std::endl; + if (impl_->ctype_ == BSHD) { + outFile << name() << ": [BSHD]shape:[" << batch() << " " << sequence() << " " << head() << " " << dimension() << "] " << dtype() << " " << ctype() << std::endl; + } else { + outFile << name() << ": shape:[" << batch() << " " << head() << " " << sequence() << " " << dimension() << "] " << dtype() << " " << ctype() << std::endl; + } + + int N = batch(); + int C = head(); + int H = sequence(); + int W = dimension(); + if (impl_->ctype_ == BSHD) { + for (int n = 0; n < batch(); ++n) { + for (int h = 0; h < sequence(); ++h) { + for (int c = 0; c < head(); ++c) { + for (int w = 0; w < dimension(); ++w) { + outFile << (int)dataAt(n, c, h, w) << " "; + } + outFile << std::endl; + } + outFile << std::endl; + } + outFile << std::endl; + } + outFile.close(); + return; + } + if (N == 1 && C == 1) { + for (int h = 0; h < H; ++h) { + for (int c = 0; c < W; ++c) { + outFile << std::fixed << std::setprecision(6) << dataAt(0, 0, h, c) << " "; + } + outFile << std::endl; + outFile << "---------" << std::endl; + } + } else if (N == 1 && W == 1) { + for (int h = 0; h < H; ++h) { + for (int c = 0; c < C; ++c) { + outFile << std::fixed << std::setprecision(6) << dataAt(0, c, h, 0) << " "; + } + outFile << std::endl; + } + } else { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + for (int c = 0; c < C; ++c) { + for (int w = 0; w < W; ++w) { + outFile << std::fixed << std::setprecision(6) << dataAt(n, c, h, w) << " "; + } + outFile << std::endl; + } + outFile << std::endl; + } + outFile << std::endl; + } + } + + outFile.close(); + } + template void saveNData(string new_name = "", string ex = "") { if (Tensor::tensor_status == TENSOR_STATIC_READY && !shape().empty()) { diff --git a/src/Trace.cpp b/src/Trace.cpp index 5e1f51ac3..18b57f833 100644 --- a/src/Trace.cpp +++ b/src/Trace.cpp @@ -41,10 +41,12 @@ void Tracer::addTensorFunction(TensorFunction *func, } void Tracer::trace(Module *model, vector inputs) { - inputs[0].setTtype(TensorType::NORMAL_TENSOR); - model->activation_tensors[inputs[0].name()] = std::shared_ptr(&inputs[0], [](Tensor *) {}); - model->activation_tensors[inputs[0].name()]->setName(inputs[0].name()); - model->activation_tensors[inputs[0].name()]->setModule(model); + for(auto& input : inputs) { + input.setTtype(TensorType::NORMAL_TENSOR); + model->activation_tensors[input.name()] = std::shared_ptr(&input, [](Tensor *) {}); + model->activation_tensors[input.name()]->setName(input.name()); + model->activation_tensors[input.name()]->setModule(model); + } Module::llm_model_ptr = model; diff --git a/src/backends/cpu/CPUBackend.cpp b/src/backends/cpu/CPUBackend.cpp index a1b9bdafc..511fbef6e 100644 --- a/src/backends/cpu/CPUBackend.cpp +++ b/src/backends/cpu/CPUBackend.cpp @@ -11,6 +11,7 @@ #include "op/CPUHeadLinear.hpp" #include "op/CPULinearInt8.hpp" +#include "op/CPUMultimodalRoPEPipeline.hpp" #include "op/CPUNTKRoPE.hpp" #include "op/CPUPoEmbedding.hpp" #include "op/CPUSplitInput.hpp" @@ -86,7 +87,8 @@ #include "function/CPURepeatFunc.hpp" #include "function/CPULikeFunc.hpp" #include "function/CPUScatterReduceFunc.hpp" -#include "function/CPUApplyVisionRoPE.hpp" +#include "function/CPUVisionRoPEFunc.hpp" +#include "function/CPUFlashAttention2Func.hpp" #include "function/CPUFuyuGatherEmbdFunc.hpp" #include "function/CPUPhi3VhdmergeFunc.hpp" @@ -95,16 +97,17 @@ namespace mllm { class CPUBackendCreator : public BackendCreator { Backend *create(BackendConfig config) { shared_ptr mm = nullptr; - switch (config.memory) { - case BackendConfig::Memory_High: - // mm = std::make_shared(); - mm = std::make_shared(); // todomm - break; - default: - // mm = std::make_shared(); - mm = std::make_shared(); // todomm - break; - } + mm = std::make_shared(); // todomm + // switch (config.memory) { + // case BackendConfig::Memory_High: + // mm = std::make_shared(); + // // mm = std::make_shared(); // todomm + // break; + // default: + // mm = std::make_shared(); + // // mm = std::make_shared(); // todomm + // break; + // } return new CPUBackend(mm); }; }; @@ -160,8 +163,8 @@ void CPUBackend::registerOps() { addCreator(MAXPOOL2D, (CPUBackend::Creator *)(new CPUMaxPoolCreator())); addCreator(CONVOLUTION3D, (CPUBackend::Creator *)(new CPUConvolution3DCreator())); addCreator(VISIONROPE, (CPUBackend::Creator *)(new CPUVisionRoPECreator())); + addCreator(MULTIMODALROPEPIP, (CPUBackend::Creator *)(new CPUMultimodalRoPEPipelineCreator())); addCreator(MULTIMODALROPE, (CPUBackend::Creator *)(new CPUMultimodalRoPECreator())); - // addCreator(CAT, (CPUBackend::Creator *)(new CPUCatCreator())); addCreator(TRANSPOSE, (CPUBackend::Creator *)(new CPUTransposeCreator())); addCreator(SUBDIM, (CPUBackend::Creator *)(new CPUSubDimCreator())); addCreator(DIVISION, (CPUBackend::Creator *)(new CPUDivisionCreator())); @@ -226,7 +229,8 @@ void CPUBackend::registerFuncs() { map_function_[TensorFuncType::FUNC_REPEAT] = new CPUrepeatFunction(); map_function_[TensorFuncType::FUNC_LIKE] = new CPUlikeFunction(); map_function_[TensorFuncType::FUNC_SCATTERREDUCE] = new CPUScatterReduceFunction(); - map_function_[TensorFuncType::FUNC_APPLY_VISIOROPE] = new CPUApplyVisionRoPEFunction(); + map_function_[TensorFuncType::FUNC_APPLY_VISIOROPE] = new CPUVisionRoPEFuncFunction(); + map_function_[TensorFuncType::FUNC_FA2] = new CPUFlashAttention2Func(); // models use only map_function_[TensorFuncType::FUNC_FUYU_GATHER_EMBD] = new CPUFuyuGatherEmbdFunc(); map_function_[TensorFuncType::FUNC_PHI3V_HD_MERGE] = new CPUPhi3VhdmergeFunction(); diff --git a/src/backends/cpu/compute/FlashAttention2.hpp b/src/backends/cpu/compute/FlashAttention2.hpp new file mode 100644 index 000000000..ee10fc3ff --- /dev/null +++ b/src/backends/cpu/compute/FlashAttention2.hpp @@ -0,0 +1,3077 @@ + +#ifndef MLLM_FA2_CAL_HPP +#define MLLM_FA2_CAL_HPP + +// #include +// #include +// #include +#ifdef __AVX2__ +#include +#include +#include +#include +#include +#include +#include +#include "Types.hpp" +#include "VecDot.hpp" + +namespace mobi_attn { + +// ======================================== +// 数学函数和工具 +// ======================================== +#define NEG_INF std::numeric_limits::lowest() +// Horizontal max of a __m256 vector +inline float _mm256_hmax_ps(__m256 x) { + __m128 lo = _mm256_castps256_ps128(x); + __m128 hi = _mm256_extractf128_ps(x, 1); + __m128 max_val = _mm_max_ps(lo, hi); + max_val = _mm_max_ps(max_val, _mm_shuffle_ps(max_val, max_val, _MM_SHUFFLE(0, 0, 2, 2))); + max_val = _mm_max_ps(max_val, _mm_shuffle_ps(max_val, max_val, _MM_SHUFFLE(0, 0, 0, 1))); + return _mm_cvtss_f32(max_val); +} + +// Horizontal sum of a __m256 vector +inline float _mm256_hadd_ps(__m256 x) { + __m128 lo = _mm256_castps256_ps128(x); + __m128 hi = _mm256_extractf128_ps(x, 1); + __m128 sum = _mm_add_ps(lo, hi); + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, sum); + return _mm_cvtss_f32(sum); +} + +// ======================================== +// 内存对齐分配函数 +// ======================================== +// 使用 posix_memalign 进行分配 +void x86_align_alloc(void **ptr, size_t required_bytes, size_t align) { + // posix_memalign 要求 alignment 必须是 void* 大小的整数倍,并且是 2 的幂 + if (align % sizeof(void *) != 0 || (align & (align - 1)) != 0) { + *ptr = nullptr; + return; + } + + // posix_memalign 返回 0 表示成功,否则返回错误码 + if (posix_memalign(ptr, align, required_bytes) != 0) { + *ptr = nullptr; + } +} + +// 直接使用标准 free 进行释放 +void x86_align_free(void *ptr) { + free(ptr); +} + +// ======================================== +// FlashAttention2 核心实现 (FP32版本) +// ======================================== +struct AVX_FA_2_GQA_QKV_FP32_BSHD_O_FP32_BSHD_ACC_FP32_IMPL { + using dtype_q_in_t = float; + using dtype_kv_in_t = dtype_q_in_t; + using dtype_out_t = dtype_q_in_t; + using dtype_t = dtype_out_t; + using acc_dtype_t = float; + // 添加配置参数作为成员变量 + int32_t Br; + int32_t Bc; + int32_t Q_Head; + int32_t KV_Head; + int32_t threads; + bool high_precision; + // 配置参数初始化 + void configure(int32_t Br_, int32_t Bc_, int32_t Q_Head_, int32_t KV_Head_, int32_t threads_, bool high_precision_) { + Br = Br_; + Bc = Bc_; + Q_Head = Q_Head_; + KV_Head = KV_Head_; + threads = threads_; + high_precision = high_precision_; + } + + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + acc_o_ = acc_o; + acc_s_ = acc_s; + logsum_ = logsum; + scoremax_ = scoremax; + scoremax_prev_ = scoremax_prev; + score_scale_ = score_scale; + score_sum_ = score_sum; + } + + // 核心计算函数 + inline void __fa2_prefill_append(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, // head_size 就是 Q_Head + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = seq_size_q / Br; + const int32_t Tr_left = seq_size_q % Br; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + + // 【关键修改 1】计算Q头与KV头的分组对应关系 + const int32_t kv_group_size = Q_Head / KV_Head; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { // h_idx 是当前Q头的索引 + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + + // 【关键修改 1】计算当前Q头 (h_idx) 对应的KV头索引 + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + // --- 主循环 (Tr) --- + for (int t_r_idx = 0; t_r_idx < Tr; ++t_r_idx) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, + acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + // Q 的指针计算保持不变,因为它有 Q_Head 个头 + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + + // 【关键修改 2】K 和 V 的指针计算,必须使用 KV_Head 和映射后的 this_thread_kv_head + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + // 【关键修改 3】为 mma0 传入Q和K各自的正确步长 + mma0(tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + // 【关键修改 3】为 mma1 传入 KV_Head 作为V的头数量,用于计算其内部步长 + mma1(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Br, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Br, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Br, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Br, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + if (Tr_left) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Bc, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Bc, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Bc, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Bc, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_pa_n_fixed(Tr_left, acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size, Tr, head_size, dim_size); + } + } + } + } + + inline void __fa2_decode(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = 1; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + // FIX: Calculate the ratio of Q_Head to KV_Head to handle GQA/MHA correctly. + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + // FIX: Map the current query head 'h_idx' to its corresponding KV head. + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + for (int t_r_idx = 0; t_r_idx < Tr; ++t_r_idx) { + init_temp_d(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + // FIX: Corrected pointer arithmetic for K and V. + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + mma0_d(tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_d(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_d(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_d(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_d_n_fixed(Tc_left, tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_d_n_fixed(Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_d_n_fixed(Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_d_n_fixed(Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_d(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + } + } + } + + void fa2(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, const int32_t batch_size, + const int32_t head_size, const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + assert(Br == Bc); + // assert(Br % 4 == 0); + // FIX: Assert that Q_Head is a multiple of KV_Head for valid GQA/MHA. + assert(Q_Head % KV_Head == 0); + assert(head_size % threads == 0); + assert(dim_size % 8 == 0); // AVX processes 8 floats at a time + + if (seq_size_q != 1) { + __fa2_prefill_append(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, + causal_mask); + } else { + __fa2_decode(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, + causal_mask); + } + } + +private: + // inline void init_temp(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, + // const int32_t dim_size) { + // __m256 zero_vec = _mm256_set1_ps(0.0f); + // __m256 neg_inf_vec = _mm256_set1_ps(NEG_INF); + + // for (int i = 0; i < Br; i += 8) { _mm256_storeu_ps(logsum + i, zero_vec); } + // for (int i = 0; i < Br; i += 8) { _mm256_storeu_ps(scoremax + i, neg_inf_vec); } + // for (int i = 0; i < Br * dim_size; i += 8) { _mm256_storeu_ps(acc_o + i, zero_vec); } + // } + + inline void init_temp(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + __m256 zero_vec = _mm256_set1_ps(0.0f); + __m256 neg_inf_vec = _mm256_set1_ps(NEG_INF); + + // 【最终修正】使用安全的循环来确保完全初始化,不再依赖Br是8的倍数 + int i = 0; + for (; i <= Br - 8; i += 8) { + _mm256_storeu_ps(logsum + i, zero_vec); + _mm256_storeu_ps(scoremax + i, neg_inf_vec); + } + // 处理剩余的元素(如果Br不是8的倍数) + for (; i < Br; ++i) { + logsum[i] = 0.0f; + scoremax[i] = NEG_INF; + } + + // acc_o 的初始化是安全的,因为调用者保证了 dim_size % 8 == 0 + for (int j = 0; j < Br * dim_size; j += 8) { + _mm256_storeu_ps(acc_o + j, zero_vec); + } + } + + // 【关键修改】函数签名增加 kv_stride_size 参数,用于区分Q和K的步长 + inline void mma0(const dtype_t *__restrict__ q_block, const dtype_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + +#pragma unroll + for (int32_t b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + // 【关键修改】使用传入的 q_stride_size + const dtype_t *q_block_line = q_block + b_r_idx * q_stride_size; +#pragma unroll + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + // 【关键修改】使用传入的 kv_stride_size + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + __builtin_prefetch(q_block_line + i + 64); + __builtin_prefetch(k_block_line + i + 64); + __m256 q_vec = _mm256_loadu_ps(q_block_line + i); + __m256 k_vec = _mm256_loadu_ps(k_block_line + i); + sum_vec = _mm256_fmadd_ps(q_vec, k_vec, sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + + if (causal_mask && (global_r_end == (t_c_idx * Bc + Bc) - delta_pos)) { + for (int i = 0; i < Br; ++i) { + for (int j = 0; j < Bc; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + inline void softmax(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + memcpy(scoremax_prev, scoremax, Br * sizeof(acc_dtype_t)); + for (int br = 0; br < Br; ++br) { + __m256 max_vec = _mm256_set1_ps(scoremax[br]); + acc_dtype_t *row = acc_s + br * Bc; + int bc = 0; + for (; bc <= Bc - 8; bc += 8) { max_vec = _mm256_max_ps(max_vec, _mm256_loadu_ps(row + bc)); } + float max_val = _mm256_hmax_ps(max_vec); + for (; bc < Bc; ++bc) { max_val = fmaxf(max_val, row[bc]); } + scoremax[br] = max_val; + } + for (int br = 0; br < Br; ++br) { score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); } + for (int br = 0; br < Br; ++br) { + const float sm = scoremax[br]; + acc_dtype_t *row = acc_s + br * Bc; + float sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((row[bc] - sm) * scale); + row[bc] = val; + sum += val; + } + score_sum[br] = sum; + } + for (int br = 0; br < Br; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + inline void rescale(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + +#pragma unroll + for (int i = 0; i < Br; ++i) { + __m256 scale_v = _mm256_set1_ps(score_scale[i]); + float *row_ptr = acc_o + i * dim_size; + for (int j = 0; j < dim_size; j += 8) { + __m256 acc = _mm256_loadu_ps(row_ptr + j); + acc = _mm256_mul_ps(acc, scale_v); + _mm256_storeu_ps(row_ptr + j, acc); + } + } + } + + // 【关键修改】函数签名增加 kv_head_size 参数,用于计算V的内部步长 + inline void mma1(const acc_dtype_t *__restrict__ w_block, const dtype_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + + // 【关键修改】使用传入的 kv_head_size 来计算 V 的步长 + const int32_t v_stride_size = kv_head_size * dim_size; + +#pragma unroll + for (int b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + b_r_idx * dim_size + d_base); +#pragma unroll + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_r_idx * Bc + b_c_idx]); + // 【关键修改】使用计算出的 v_stride_size + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + __m256 v_vec = _mm256_loadu_ps(v_ptr); + acc = _mm256_fmadd_ps(w_vec, v_vec, acc); + } + _mm256_storeu_ps(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + inline void scale_and_store(const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { +#pragma unroll + for (int i = 0; i < Br; ++i) { + dtype_t *o_block_line = o_block + i * head_size * dim_size; + __m256 reciprocal_logsum_vec = _mm256_set1_ps(1.0f / logsum[i]); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + __m256 vec_acc_o = _mm256_loadu_ps(acc_o + i * dim_size + j); + __m256 result_vec = _mm256_mul_ps(vec_acc_o, reciprocal_logsum_vec); + _mm256_storeu_ps(o_block_line + j, result_vec); + } + float reciprocal_logsum = 1.0f / logsum[i]; + for (; j < dim_size; ++j) { + o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; + } + } + } + + // N-fixed functions for handling leftovers + // FIX: Modified mma0_pa_n_fixed to accept separate strides. + inline void mma0_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const dtype_t *__restrict__ q_block, + const dtype_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br_n_fixed; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + + for (int32_t b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + const dtype_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + sum_vec = _mm256_fmadd_ps(_mm256_loadu_ps(q_block_line + i), _mm256_loadu_ps(k_block_line + i), sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + + if (causal_mask && (global_r_end == (global_c_start + Bc_n_fixed) - delta_pos)) { + for (int i = 0; i < Br_n_fixed; ++i) { + for (int j = 0; j < Bc_n_fixed; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + inline void softmax_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + memcpy(scoremax_prev, scoremax, Br_n_fixed * sizeof(acc_dtype_t)); + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, row[bc]); + scoremax[br] = fmaxf(max_val, scoremax[br]); + } + for (int br = 0; br < Br_n_fixed; ++br) { + score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); + } + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((row[bc] - scoremax[br]) * scale); + row[bc] = val; + current_sum += val; + } + score_sum[br] = current_sum; + } + for (int br = 0; br < Br_n_fixed; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + inline void rescale_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + + for (int i = 0; i < Br_n_fixed; ++i) { + float *row_ptr = acc_o + i * dim_size; + __m256 scale_v = _mm256_set1_ps(score_scale[i]); + for (int j = 0; j < dim_size; j += 8) { + _mm256_storeu_ps(row_ptr + j, _mm256_mul_ps(_mm256_loadu_ps(row_ptr + j), scale_v)); + } + } + } + + // FIX: Modified mma1_pa_n_fixed to accept kv_head_size. + inline void mma1_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const acc_dtype_t *__restrict__ w_block, + const dtype_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + + const int32_t v_stride_size = kv_head_size * dim_size; + + for (int b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_r_idx * Bc + b_c_idx]); + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = _mm256_fmadd_ps(w_vec, _mm256_loadu_ps(v_ptr), acc); + } + _mm256_storeu_ps(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + inline void scale_and_store_pa_n_fixed(const int32_t Br_n_fixed, + const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br_n_fixed; ++i) { + dtype_t *o_block_line = o_block + i * head_size * dim_size; + float reciprocal_logsum = 1.0f / logsum[i]; + __m256 reciprocal_logsum_vec = _mm256_set1_ps(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + __m256 vec_acc_o = _mm256_loadu_ps(acc_o + i * dim_size + j); + _mm256_storeu_ps(o_block_line + j, _mm256_mul_ps(vec_acc_o, reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { + o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; + } + } + } + + // Decode mode functions + inline void init_temp_d(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, + const int32_t dim_size) { + logsum[0] = 0.0f; + scoremax[0] = NEG_INF; + __m256 zero_vec = _mm256_setzero_ps(); + for (int i = 0; i < 1 * dim_size; i += 8) { _mm256_storeu_ps(acc_o + i, zero_vec); } + } + + // FIX: Modified mma0_d to accept kv_stride_size. + inline void mma0_d(const dtype_t *__restrict__ q_block, const dtype_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t kv_stride_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const dtype_t *q_block_line = q_block; +#pragma unroll + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + sum_vec = _mm256_fmadd_ps(_mm256_loadu_ps(q_block_line + i), _mm256_loadu_ps(k_block_line + i), sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, + const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = fmaxf(max_val, scoremax[0]); + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + inline void rescale_d(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + __m256 scale_v = _mm256_set1_ps(score_scale[0]); + for (int j = 0; j < dim_size; j += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + j); + acc = _mm256_mul_ps(acc, scale_v); + _mm256_storeu_ps(acc_o + j, acc); + } + } + + // FIX: Modified mma1_d to accept kv_head_size. + inline void mma1_d(const acc_dtype_t *__restrict__ w_block, const dtype_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_c_idx]); + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = _mm256_fmadd_ps(w_vec, _mm256_loadu_ps(v_ptr), acc); + } + _mm256_storeu_ps(acc_o + d_base, acc); + } + } + + inline void scale_and_store_d(const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + float reciprocal_logsum = 1.0f / logsum[0]; + __m256 reciprocal_logsum_vec = _mm256_set1_ps(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + _mm256_storeu_ps(o_block + j, _mm256_mul_ps(_mm256_loadu_ps(acc_o + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { + o_block[j] = acc_o[j] * reciprocal_logsum; + } + } + + // Decode n-fixed functions + // FIX: Modified mma0_d_n_fixed to accept kv_stride_size. + inline void mma0_d_n_fixed(const int32_t Bc_n_fixed, const dtype_t *__restrict__ q_block, + const dtype_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const dtype_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float total = 0.0f; + for (int i = 0; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_s, + acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, + acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = fmaxf(max_val, scoremax[0]); + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + inline void rescale_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + float scale = score_scale[0]; + for (int j = 0; j < dim_size; ++j) { + acc_o[j] *= scale; + } + } + + // FIX: Modified mma1_d_n_fixed to accept kv_head_size. + inline void mma1_d_n_fixed(const int32_t Bc_n_fixed, const acc_dtype_t *__restrict__ w_block, + const dtype_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, const int32_t t_r_idx, + const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; ++d_base) { + float acc = acc_o[d_base]; + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + acc += w_block[b_c_idx] * v_block[b_c_idx * v_stride_size + d_base]; + } + acc_o[d_base] = acc; + } + } + +private: + acc_dtype_t *acc_o_; + acc_dtype_t *acc_s_; + acc_dtype_t *logsum_; + acc_dtype_t *scoremax_; + acc_dtype_t *scoremax_prev_; + acc_dtype_t *score_scale_; + acc_dtype_t *score_sum_; +}; + +// ======================================== +// FlashAttention2 核心实现 ( Q FP32/KV FP16 输入,FP32 输出版本) +// ======================================== +struct AVX_FA_2_GQA_Q_FP32_KV_FP16_BSHD_O_FP32_BSHD_ACC_FP32_IMPL { + // 【修改】定义多种输入类型 + using dtype_q_in_t = float; + using dtype_kv_in_t = mllm_fp16_t; + using dtype_out_t = float; + using acc_dtype_t = float; + + int32_t Br, Bc, Q_Head, KV_Head, threads; + bool high_precision; + + void configure(int32_t Br_, int32_t Bc_, int32_t Q_Head_, int32_t KV_Head_, int32_t threads_, bool high_precision_) { + Br = Br_; + Bc = Bc_; + Q_Head = Q_Head_; + KV_Head = KV_Head_; + threads = threads_; + high_precision = high_precision_; + } + + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + acc_o_ = acc_o; + acc_s_ = acc_s; + logsum_ = logsum; + scoremax_ = scoremax; + scoremax_prev_ = scoremax_prev; + score_scale_ = score_scale; + score_sum_ = score_sum; + } + + void fa2(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, const int32_t batch_size, + const int32_t head_size, const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + assert(Br == Bc); + // assert(Br % 4 == 0); + assert(head_size % threads == 0); + assert(dim_size % 8 == 0); + + if (seq_size_q != 1) { + __fa2_prefill_append(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } else { + __fa2_decode(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + } + +private: + inline void __fa2_prefill_append(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = seq_size_q / Br; + const int32_t Tr_left = seq_size_q % Br; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Tc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +// Note: OpenMP is not applied to the head_size loop in the prefill reference +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + for (int t_r_idx = 0; t_r_idx < Tr; ++t_r_idx) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, + acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + mma0(tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Br, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Br, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Br, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Br, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + if (Tr_left) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Bc, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Bc, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Bc, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Bc, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_pa_n_fixed(Tr_left, acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size, Tr, head_size, dim_size); + } + } + } + } + + inline void __fa2_decode(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = 1; // In decode, seq_size_q is always 1 + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + // In decode mode, t_r_idx is always 0 as we process one token + const int t_r_idx = 0; + init_temp_d(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + mma0_d(tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_d(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_d(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_d(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_d_n_fixed(Tc_left, tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_d_n_fixed(Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_d_n_fixed(Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_d_n_fixed(Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_d(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + } + } + + inline void init_temp(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + __m256 zero_vec = _mm256_set1_ps(0.0f); + __m256 neg_inf_vec = _mm256_set1_ps(NEG_INF); + + // 【最终修正】使用安全的循环来确保完全初始化,不再依赖Br是8的倍数 + int i = 0; + for (; i <= Br - 8; i += 8) { + _mm256_storeu_ps(logsum + i, zero_vec); + _mm256_storeu_ps(scoremax + i, neg_inf_vec); + } + // 处理剩余的元素(如果Br不是8的倍数) + for (; i < Br; ++i) { + logsum[i] = 0.0f; + scoremax[i] = NEG_INF; + } + + // acc_o 的初始化是安全的,因为调用者保证了 dim_size % 8 == 0 + for (int j = 0; j < Br * dim_size; j += 8) { + _mm256_storeu_ps(acc_o + j, zero_vec); + } + } + + inline void mma0(const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br, global_r_end = global_r_start + Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + for (int32_t b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + const dtype_q_in_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + __m256 q_vec = _mm256_loadu_ps(q_block_line + i); + __m256 k_vec = MLLM_F32Cx8_LOAD(k_block_line + i); + sum_vec = _mm256_fmadd_ps(q_vec, k_vec, sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + if (causal_mask && (global_r_end == (t_c_idx * Bc + Bc) - delta_pos)) { + for (int i = 0; i < Br; ++i) { + for (int j = 0; j < Bc; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + inline void softmax(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + memcpy(scoremax_prev, scoremax, Br * sizeof(acc_dtype_t)); + for (int br = 0; br < Br; ++br) { + __m256 max_vec = _mm256_set1_ps(scoremax[br]); + acc_dtype_t *row = acc_s + br * Bc; + int bc = 0; + for (; bc <= Bc - 8; bc += 8) { max_vec = _mm256_max_ps(max_vec, _mm256_loadu_ps(row + bc)); } + float max_val = _mm256_hmax_ps(max_vec); + for (; bc < Bc; ++bc) { max_val = fmaxf(max_val, row[bc]); } + scoremax[br] = max_val; + } + for (int br = 0; br < Br; ++br) { score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); } + for (int br = 0; br < Br; ++br) { + const float sm = scoremax[br]; + acc_dtype_t *row = acc_s + br * Bc; + float sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((row[bc] - sm) * scale); + row[bc] = val; + sum += val; + } + score_sum[br] = sum; + } + for (int br = 0; br < Br; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + inline void rescale(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + // (无变化) + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + for (int i = 0; i < Br; ++i) { + __m256 scale_v = _mm256_set1_ps(score_scale[i]); + float *row_ptr = acc_o + i * dim_size; + for (int j = 0; j < dim_size; j += 8) { + __m256 acc = _mm256_loadu_ps(row_ptr + j); + acc = _mm256_mul_ps(acc, scale_v); + _mm256_storeu_ps(row_ptr + j, acc); + } + } + } + + inline void mma1(const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + const int32_t v_stride_size = kv_head_size * dim_size; + for (int b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_r_idx * Bc + b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + __m256 v_vec = MLLM_F32Cx8_LOAD(v_ptr); + acc = _mm256_fmadd_ps(w_vec, v_vec, acc); + } + _mm256_storeu_ps(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + inline void scale_and_store(const acc_dtype_t *__restrict__ acc_o, const acc_dtype_t *__restrict__ logsum, + dtype_out_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + // (无变化,输出O是fp32) + for (int i = 0; i < Br; ++i) { + // 【修正】这里的 o_block_line 计算是错误的,它没有正确处理 BSHD 布局下的行步进 + // 正确的行步进已经由外层循环的 o_block 指针计算好了 + // 我们只需要在此基础上按行写入即可 + dtype_out_t *o_block_line = o_block + i * head_size * dim_size; // << 保持 BSHD 的行步长 + + __m256 reciprocal_logsum_vec = _mm256_set1_ps(1.0f / logsum[i]); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + __m256 vec_acc_o = _mm256_loadu_ps(acc_o + i * dim_size + j); + __m256 result_vec = _mm256_mul_ps(vec_acc_o, reciprocal_logsum_vec); + _mm256_storeu_ps(o_block_line + j, result_vec); + } + float reciprocal_logsum = 1.0f / logsum[i]; + for (; j < dim_size; ++j) { o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; } + } + } + + inline void mma0_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br_n_fixed; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + for (int32_t b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + const dtype_q_in_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + sum_vec = _mm256_fmadd_ps(_mm256_loadu_ps(q_block_line + i), MLLM_F32Cx8_LOAD(k_block_line + i), sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + if (causal_mask && (global_r_end == (global_c_start + Bc_n_fixed) - delta_pos)) { + for (int i = 0; i < Br_n_fixed; ++i) { + for (int j = 0; j < Bc_n_fixed; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + inline void softmax_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + memcpy(scoremax_prev, scoremax, Br_n_fixed * sizeof(acc_dtype_t)); + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, row[bc]); + scoremax[br] = fmaxf(max_val, scoremax[br]); + } + for (int br = 0; br < Br_n_fixed; ++br) { score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); } + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((row[bc] - scoremax[br]) * scale); + row[bc] = val; + current_sum += val; + } + score_sum[br] = current_sum; + } + for (int br = 0; br < Br_n_fixed; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + // (无变化) + inline void rescale_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + for (int i = 0; i < Br_n_fixed; ++i) { + float *row_ptr = acc_o + i * dim_size; + __m256 scale_v = _mm256_set1_ps(score_scale[i]); + for (int j = 0; j < dim_size; j += 8) { + _mm256_storeu_ps(row_ptr + j, _mm256_mul_ps(_mm256_loadu_ps(row_ptr + j), scale_v)); + } + } + } + + inline void mma1_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + const int32_t v_stride_size = kv_head_size * dim_size; + for (int b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_r_idx * Bc + b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = _mm256_fmadd_ps(w_vec, MLLM_F32Cx8_LOAD(v_ptr), acc); + } + _mm256_storeu_ps(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + inline void scale_and_store_pa_n_fixed(const int32_t Br_n_fixed, const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, dtype_out_t *__restrict__ o_block, + const int32_t t_r_idx, const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br_n_fixed; ++i) { + // 【修正】同上,这里的行步进计算也是错误的 + dtype_out_t *o_block_line = o_block + i * head_size * dim_size; // << 保持 BSHD 的行步长 + + float reciprocal_logsum = 1.0f / logsum[i]; + __m256 reciprocal_logsum_vec = _mm256_set1_ps(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + __m256 vec_acc_o = _mm256_loadu_ps(acc_o + i * dim_size + j); + _mm256_storeu_ps(o_block_line + j, _mm256_mul_ps(vec_acc_o, reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; } + } + } + + // (此函数无变化,但为完整性一并提供) + inline void init_temp_d(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + logsum[0] = 0.0f; + scoremax[0] = NEG_INF; + __m256 zero_vec = _mm256_setzero_ps(); + for (int i = 0; i < 1 * dim_size; i += 8) { _mm256_storeu_ps(acc_o + i, zero_vec); } + } + + inline void mma0_d(const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t kv_stride_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const dtype_q_in_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + __m256 sum_vec = _mm256_setzero_ps(); + int i = 0; + for (; i <= dim_size - 8; i += 8) { + sum_vec = _mm256_fmadd_ps(_mm256_loadu_ps(q_block_line + i), MLLM_F32Cx8_LOAD(k_block_line + i), sum_vec); + } + acc_dtype_t total = _mm256_hadd_ps(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d(acc_dtype_t *__restrict__ acc_s, + acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = fmaxf(max_val, scoremax[0]); + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + // (此函数无变化,但为完整性一并提供) + inline void rescale_d(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + __m256 scale_v = _mm256_set1_ps(score_scale[0]); + for (int j = 0; j < dim_size; j += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + j); + acc = _mm256_mul_ps(acc, scale_v); + _mm256_storeu_ps(acc_o + j, acc); + } + } + + inline void mma1_d(const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; d_base += 8) { + __m256 acc = _mm256_loadu_ps(acc_o + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + __m256 w_vec = _mm256_set1_ps(w_block[b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + __m256 v_vec = MLLM_F32Cx8_LOAD(v_ptr); + acc = _mm256_fmadd_ps(w_vec, v_vec, acc); + } + _mm256_storeu_ps(acc_o + d_base, acc); + } + } + // (此函数无变化,但为完整性一并提供) + inline void scale_and_store_d(const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_out_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + float reciprocal_logsum = 1.0f / logsum[0]; + __m256 reciprocal_logsum_vec = _mm256_set1_ps(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 8; j += 8) { + _mm256_storeu_ps(o_block + j, _mm256_mul_ps(_mm256_loadu_ps(acc_o + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { + o_block[j] = acc_o[j] * reciprocal_logsum; + } + } + + inline void mma0_d_n_fixed(const int32_t Bc_n_fixed, const dtype_q_in_t *__restrict__ q_block, + const dtype_kv_in_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const dtype_q_in_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float total = 0.0f; + for (int i = 0; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_s, + acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, + acc_dtype_t *logsum, const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = fmaxf(max_val, scoremax[0]); + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + // (此函数无变化,但为完整性一并提供) + inline void rescale_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + float scale = score_scale[0]; + for (int j = 0; j < dim_size; ++j) { acc_o[j] *= scale; } + } + + inline void mma1_d_n_fixed(const int32_t Bc_n_fixed, const acc_dtype_t *__restrict__ w_block, + const dtype_kv_in_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, const int32_t t_r_idx, + const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; ++d_base) { + float acc = acc_o[d_base]; + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + // Scalar fallback for leftover dimensions + acc += w_block[b_c_idx] * MLLM_FP16_TO_FP32(v_block[b_c_idx * v_stride_size + d_base]); + } + acc_o[d_base] = acc; + } + } + +private: + // float scale_; + acc_dtype_t *acc_o_; + acc_dtype_t *acc_s_; + acc_dtype_t *logsum_; + acc_dtype_t *scoremax_; + acc_dtype_t *scoremax_prev_; + acc_dtype_t *score_scale_; + acc_dtype_t *score_sum_; +}; + +// ======================================== +// 【修改】统一的FlashAttention2接口,改为模板以支持不同实现 +template +struct FlashAttn2T { +public: + using dtype_q_in_t = typename Impl::dtype_q_in_t; + using dtype_kv_in_t = typename Impl::dtype_kv_in_t; + using dtype_out_t = typename Impl::dtype_out_t; + using acc_dtype_t = typename Impl::acc_dtype_t; + + void configure(int32_t Br, int32_t Bc, int32_t Q_Head, int32_t KV_Head, int32_t threads, bool high_precision) { + impl_.configure(Br, Bc, Q_Head, KV_Head, threads, high_precision); + } + + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + // Note: workspace pointers are always float, acc_s_cast is removed + impl_.init_workspace(acc_o, acc_s, logsum, scoremax, scoremax_prev, score_scale, score_sum); + } + + void operator()(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + impl_.fa2(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + +private: + Impl impl_; +}; + +} // namespace mobi_attn + +// ======================================== +// 用户接口函数 +// ======================================== + +// 【修改】用户接口函数,增加FP16输入分支 +void flash_attention_2_forward( + const void *Q, const void *K, const void *V, void *O, + int32_t batch_size, int32_t head_size, int32_t seq_size_q, int32_t seq_size_k, int32_t dim_size, + bool causal_mask, bool use_fp32, int32_t threads, int32_t br, int32_t bc, + int32_t q_head, int32_t kv_head, bool high_precision_exp) { + // 工作空间大小与输入数据类型无关,因为内部累加器总是float32 + // acc_s_cast is no longer needed as the intermediate softmax result is kept in float32 + const size_t acc_o_size = threads * br * dim_size * sizeof(float); + const size_t acc_s_size = threads * br * bc * sizeof(float); + const size_t logsum_size = threads * br * sizeof(float); + const size_t scoremax_size = threads * br * sizeof(float); + const size_t scoremax_prev_size = threads * br * sizeof(float); + const size_t score_scale_size = threads * br * sizeof(float); + const size_t score_sum_size = threads * br * sizeof(float); + + // 分配对齐的工作空间 + void *workspace[7]; + mobi_attn::x86_align_alloc(&workspace[0], acc_o_size, 32); + mobi_attn::x86_align_alloc(&workspace[1], acc_s_size, 32); + mobi_attn::x86_align_alloc(&workspace[2], logsum_size, 32); + mobi_attn::x86_align_alloc(&workspace[3], scoremax_size, 32); + mobi_attn::x86_align_alloc(&workspace[4], scoremax_prev_size, 32); + mobi_attn::x86_align_alloc(&workspace[5], score_scale_size, 32); + mobi_attn::x86_align_alloc(&workspace[6], score_sum_size, 32); + + if (use_fp32) { + // 使用纯FP32实现 + mobi_attn::FlashAttn2T op; + op.configure(br, bc, q_head, kv_head, threads, high_precision_exp); + + op.init_workspace( + static_cast(workspace[0]), static_cast(workspace[1]), + static_cast(workspace[2]), static_cast(workspace[3]), + static_cast(workspace[4]), static_cast(workspace[5]), + static_cast(workspace[6])); + + op(static_cast(Q), static_cast(K), static_cast(V), + static_cast(O), + batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } else { + // 使用FP16输入,FP32输出的实现 + mobi_attn::FlashAttn2T op; + op.configure(br, bc, q_head, kv_head, threads, high_precision_exp); + + op.init_workspace( + static_cast(workspace[0]), static_cast(workspace[1]), + static_cast(workspace[2]), static_cast(workspace[3]), + static_cast(workspace[4]), static_cast(workspace[5]), + static_cast(workspace[6])); + + op(static_cast(Q), static_cast(K), static_cast(V), + static_cast(O), + batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + + // 释放工作空间 + for (void *ptr : workspace) { + if (ptr) mobi_attn::x86_align_free(ptr); + } +} +#elif __ARM_NEON +#include +// 核心修改:将 x86 的 immintrin.h 替换为 ARM 的 arm_neon.h +#include +#include +#include +#include +#include +#include +#include "Types.hpp" +#include "VecDot.hpp" + +namespace mobi_attn { + +// ======================================== +// 数学函数和工具 (NEON版本) +// ======================================== +#define NEG_INF std::numeric_limits::lowest() + +// NEON版本:水平最大值 (Horizontal max of a float32x4_t vector) +// float32x4_t 中包含4个float, vmaxvq_f32可以直接找到这4个float中的最大值 +inline float _vmaxvq_f32_hmax(float32x4_t x) { + return vmaxvq_f32(x); +} + +// NEON版本:水平求和 (Horizontal sum of a float32x4_t vector) +// float32x4_t 中包含4个float, vaddvq_f32可以直接将这4个float相加 +inline float _vaddvq_f32_hadd(float32x4_t x) { + return vaddvq_f32(x); +} + +// ======================================== +// 高性能NEON数学函数 (新增) +// ======================================== + +// 基于多项式逼近的快速exp实现 (NEON版本) +inline float32x4_t vexpq_fast_f32(float32x4_t x) { + // 定义常量 + const float32x4_t c0 = vdupq_n_f32(1.0f); + const float32x4_t c1 = vdupq_n_f32(0.0416598990559578f); + const float32x4_t c2 = vdupq_n_f32(0.166664719581604f); + const float32x4_t c3 = vdupq_n_f32(0.5000005960464478f); + const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f); + const float32x4_t ln2_hi = vdupq_n_f32(0.693145751953125f); + const float32x4_t ln2_lo = vdupq_n_f32(1.428606765330187e-06f); + const int32x4_t M_126 = vdupq_n_s32(126); + + // 计算 y = x * log2(e) + float32x4_t y = vmulq_f32(x, log2e); + + // 对y取整, n = round(y) + int32x4_t n = vcvtaq_s32_f32(y); + + // 计算 z = x - n * ln2 + float32x4_t n_f = vcvtq_f32_s32(n); + float32x4_t z = vmlsq_f32(x, n_f, ln2_hi); + z = vmlsq_f32(z, n_f, ln2_lo); + + // 多项式逼近 exp(z) ~= 1 + z + z^2/2! + ... + float32x4_t poly = c1; + poly = vmlaq_f32(c2, poly, z); + poly = vmlaq_f32(c3, poly, z); + poly = vmlaq_f32(c0, poly, z); + poly = vmlaq_f32(poly, vmulq_f32(z, z), poly); + + // 组合结果: poly * 2^n + int32x4_t m = vaddq_s32(n, M_126); + m = vshlq_n_s32(m, 23); + + return vmulq_f32(poly, vreinterpretq_f32_s32(m)); +} + +// ======================================== +// 内存对齐分配函数 (重命名以去除x86特定性) +// ======================================== +// 使用 posix_memalign 进行分配,此函数在支持POSIX的系统(如Linux)上是通用的 +void aligned_alloc(void **ptr, size_t required_bytes, size_t align) { + // posix_memalign 要求 alignment 必须是 void* 大小的整数倍,并且是 2 的幂 + if (align % sizeof(void *) != 0 || (align & (align - 1)) != 0) { + *ptr = nullptr; + return; + } + + // posix_memalign 返回 0 表示成功,否则返回错误码 + if (posix_memalign(ptr, align, required_bytes) != 0) { + *ptr = nullptr; + } +} + +// 直接使用标准 free 进行释放 +void aligned_free(void *ptr) { + free(ptr); +} + +// ======================================== +// FlashAttention2 核心实现 (FP32版本, NEON) +// ======================================== +struct NEON_FA_2_GQA_QKV_FP32_BSHD_O_FP32_BSHD_ACC_FP32_IMPL { + using dtype_q_in_t = float; + using dtype_kv_in_t = dtype_q_in_t; + using dtype_out_t = dtype_q_in_t; + using dtype_t = dtype_out_t; + using acc_dtype_t = float; + + // 添加配置参数作为成员变量 + int32_t Br; + int32_t Bc; + int32_t Q_Head; + int32_t KV_Head; + int32_t threads; + bool high_precision; + + // 配置参数初始化 + void configure(int32_t Br_, int32_t Bc_, int32_t Q_Head_, int32_t KV_Head_, int32_t threads_, bool high_precision_) { + Br = Br_; + Bc = Bc_; + Q_Head = Q_Head_; + KV_Head = KV_Head_; + threads = threads_; + high_precision = high_precision_; + } + + // 初始化工作空间指针 + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + acc_o_ = acc_o; + acc_s_ = acc_s; + logsum_ = logsum; + scoremax_ = scoremax; + scoremax_prev_ = scoremax_prev; + score_scale_ = score_scale; + score_sum_ = score_sum; + } + + // fa2 主函数,根据Q的序列长度分发到 prefill/append 或 decode 模式 + void fa2(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, const int32_t batch_size, + const int32_t head_size, const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + assert(Br == Bc); + // NEON 一次处理 4 个 float + assert(dim_size % 4 == 0); + // 确保Q头是KV头的整数倍,这是GQA/MHA的有效性要求 + assert(Q_Head % KV_Head == 0); + assert(head_size % threads == 0); + + if (seq_size_q != 1) { + __fa2_prefill_append(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, + causal_mask); + } else { + __fa2_decode(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, + causal_mask); + } + } + + // ========================================================================================= + // 以下是 NEON_FA_2_GQA_QKV_FP32_BSHD_O_FP32_BSHD_ACC_FP32_IMPL 结构体内部的私有函数实现 + // 承接第一部分的代码 + // ========================================================================================= + +private: + // 核心计算函数 (prefill/append 模式, 适用于 seq_size_q > 1) + inline void __fa2_prefill_append(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, // head_size 就是 Q_Head + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + // Tr, Tc 分别是 Q 和 K/V 在序列长度维度上被切分成的块数 + const int32_t Tr = seq_size_q / Br; + const int32_t Tr_left = seq_size_q % Br; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + + // 计算Q头与KV头的分组对应关系 (GQA) + const int32_t kv_group_size = Q_Head / KV_Head; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { // h_idx 是当前Q头的索引 + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + + // 计算当前Q头 (h_idx) 对应的KV头索引 + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + // --- 主循环 (处理完整的块) --- + for (int t_r_idx = 0; t_r_idx < Tr; ++t_r_idx) { + // 初始化该线程的临时工作空间 + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, + acc_o_ + thread_id * Br * dim_size, dim_size); + + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + // Q 的指针计算,它有 Q_Head (==head_size) 个头 + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + + // K 和 V 的指针计算,必须使用 KV_Head 和映射后的 this_thread_kv_head + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + // Step 1: Q * K^T + mma0(tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + // Step 2: Softmax + softmax(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + // Step 3: Rescale O + rescale(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + // Step 4: P * V + mma1(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + // --- 处理 K/V 序列的剩余部分 (Tc_left) --- + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Br, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Br, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Br, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Br, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + // Step 5: 将最终结果缩放并存回输出 O + scale_and_store(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + // --- 处理 Q 序列的剩余部分 (Tr_left) --- + if (Tr_left) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Bc, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Bc, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Bc, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Bc, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_pa_n_fixed(Tr_left, acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size, Tr, head_size, dim_size); + } + } + } + } + inline void __fa2_decode(const dtype_t *__restrict__ Q, const dtype_t *__restrict__ K, + const dtype_t *__restrict__ V, dtype_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + // 在 decode 模式下, Q 的序列长度固定为 1, 因此 Tr = 1, Br = 1 + const int32_t Tr = 1; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + const int t_r_idx = 0; + + init_temp_d(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + + mma0_d(tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_d(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_d(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_d(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + + if (Tc_left) { + const dtype_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_d_n_fixed(Tc_left, tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_d_n_fixed(Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_d_n_fixed(Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_d_n_fixed(Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + + scale_and_store_d(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + } + } + + // 初始化临时工作区 (NEON 版本) + inline void init_temp(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + float32x4_t zero_vec = vdupq_n_f32(0.0f); + float32x4_t neg_inf_vec = vdupq_n_f32(NEG_INF); + + int i = 0; + // NEON 一次处理4个 + for (; i <= Br - 4; i += 4) { + vst1q_f32(logsum + i, zero_vec); + vst1q_f32(scoremax + i, neg_inf_vec); + } + // 处理剩余的元素(如果Br不是4的倍数) + for (; i < Br; ++i) { + logsum[i] = 0.0f; + scoremax[i] = NEG_INF; + } + + // acc_o 的初始化, 调用者保证了 dim_size % 4 == 0 + for (int j = 0; j < Br * dim_size; j += 4) { + vst1q_f32(acc_o + j, zero_vec); + } + } + + // Q * K^T 计算 (NEON 版本) + inline void mma0(const dtype_t *__restrict__ q_block, const dtype_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + + for (int32_t b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + const dtype_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + // 为了性能,可以一次处理8个或更多元素,这里保持与AVX版本类似的逻辑,但向量宽度减半 + for (; i <= dim_size - 8; i += 8) { + // 预取数据到缓存 + __builtin_prefetch(q_block_line + i + 64); + __builtin_prefetch(k_block_line + i + 64); + // 加载两组4个float数据 + float32x4_t q_vec0 = vld1q_f32(q_block_line + i); + float32x4_t k_vec0 = vld1q_f32(k_block_line + i); + float32x4_t q_vec1 = vld1q_f32(q_block_line + i + 4); + float32x4_t k_vec1 = vld1q_f32(k_block_line + i + 4); + // 融合乘加 + sum_vec = vfmaq_f32(sum_vec, q_vec0, k_vec0); + sum_vec = vfmaq_f32(sum_vec, q_vec1, k_vec1); + } + // 处理 dim_size % 8 剩下的部分 + for (; i <= dim_size - 4; i += 4) { + float32x4_t q_vec = vld1q_f32(q_block_line + i); + float32x4_t k_vec = vld1q_f32(k_block_line + i); + sum_vec = vfmaq_f32(sum_vec, q_vec, k_vec); + } + + acc_dtype_t total = _vaddvq_f32_hadd(sum_vec); + // 处理最后不足4个的元素 + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + // 应用因果掩码 + if (causal_mask && (global_r_end == (t_c_idx * Bc + Bc) - delta_pos)) { + for (int i = 0; i < Br; ++i) { + for (int j = 0; j < Bc; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + // Softmax (NEON 版本) + inline void softmax(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + + memcpy(scoremax_prev, scoremax, Br * sizeof(acc_dtype_t)); + + // 1. 找到每行的最大值 m_i + for (int br = 0; br < Br; ++br) { + float32x4_t max_vec = vdupq_n_f32(scoremax[br]); + acc_dtype_t *row = acc_s + br * Bc; + int bc = 0; + for (; bc <= Bc - 4; bc += 4) { + max_vec = vmaxq_f32(max_vec, vld1q_f32(row + bc)); + } + float max_val = _vmaxvq_f32_hmax(max_vec); + for (; bc < Bc; ++bc) { max_val = fmaxf(max_val, row[bc]); } + scoremax[br] = max_val; + } + + // 2. 计算缩放因子 s_i = exp((m_i_prev - m_i) * scale) + for (int br = 0; br < Br; ++br) { + score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); + } + + // 3. 计算 P_ij = exp((S_ij - m_i) * scale) 和 l_i = sum(P_ij) + for (int br = 0; br < Br; ++br) { + const float sm = scoremax[br]; + acc_dtype_t *row = acc_s + br * Bc; + float sum = 0.0f; + // 这里可以进一步用NEON优化expf, 但expf的SIMD实现复杂,暂用标量 + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((row[bc] - sm) * scale); + row[bc] = val; // 更新 acc_s 为 P_ij + sum += val; + } + score_sum[br] = sum; + } + + // 4. 更新 logsum: l_i_new = l_i_prev * s_i + l_i + for (int br = 0; br < Br; ++br) { + logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; + } + } + + // 重缩放累加的输出 O (NEON 版本) + inline void rescale(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + + for (int i = 0; i < Br; ++i) { + float32x4_t scale_v = vdupq_n_f32(score_scale[i]); + float *row_ptr = acc_o + i * dim_size; + for (int j = 0; j < dim_size; j += 4) { + float32x4_t acc = vld1q_f32(row_ptr + j); + acc = vmulq_f32(acc, scale_v); + vst1q_f32(row_ptr + j, acc); + } + } + } + + // P * V 计算 (NEON 版本) + inline void mma1(const acc_dtype_t *__restrict__ w_block, const dtype_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + + const int32_t v_stride_size = kv_head_size * dim_size; + + for (int b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_r_idx * Bc + b_c_idx]); + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + float32x4_t v_vec = vld1q_f32(v_ptr); + acc = vfmaq_f32(acc, w_vec, v_vec); + } + vst1q_f32(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + // 缩放并存储最终结果 (NEON 版本) + inline void scale_and_store(const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br; ++i) { + dtype_t *o_block_line = o_block + i * head_size * dim_size; + float reciprocal_logsum = 1.0f / logsum[i]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + float32x4_t vec_acc_o = vld1q_f32(acc_o + i * dim_size + j); + float32x4_t result_vec = vmulq_f32(vec_acc_o, reciprocal_logsum_vec); + vst1q_f32(o_block_line + j, result_vec); + } + for (; j < dim_size; ++j) { + o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; + } + } + } + + // --- 处理剩余块的 N-fixed 函数 (NEON 版本) --- + inline void mma0_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const dtype_t *__restrict__ q_block, + const dtype_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br_n_fixed; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + + for (int32_t b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + const dtype_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + for (; i <= dim_size - 4; i += 4) { + sum_vec = vfmaq_f32(sum_vec, vld1q_f32(q_block_line + i), vld1q_f32(k_block_line + i)); + } + acc_dtype_t total = _vaddvq_f32_hadd(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + + if (causal_mask && (global_r_end == (global_c_start + Bc_n_fixed) - delta_pos)) { + for (int i = 0; i < Br_n_fixed; ++i) { + for (int j = 0; j < Bc_n_fixed; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + // (这里的逻辑主要是标量,和AVX版本基本一致) + inline void softmax_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + + memcpy(scoremax_prev, scoremax, Br_n_fixed * sizeof(acc_dtype_t)); + + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, row[bc]); + scoremax[br] = fmaxf(max_val, scoremax[br]); + } + for (int br = 0; br < Br_n_fixed; ++br) { + score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); + } + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((row[bc] - scoremax[br]) * scale); + row[bc] = val; + current_sum += val; + } + score_sum[br] = current_sum; + } + for (int br = 0; br < Br_n_fixed; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + inline void rescale_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + + for (int i = 0; i < Br_n_fixed; ++i) { + float *row_ptr = acc_o + i * dim_size; + float32x4_t scale_v = vdupq_n_f32(score_scale[i]); + for (int j = 0; j < dim_size; j += 4) { + vst1q_f32(row_ptr + j, vmulq_f32(vld1q_f32(row_ptr + j), scale_v)); + } + } + } + + inline void mma1_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const acc_dtype_t *__restrict__ w_block, + const dtype_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + + const int32_t v_stride_size = kv_head_size * dim_size; + + for (int b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_r_idx * Bc + b_c_idx]); + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = vfmaq_f32(acc, w_vec, vld1q_f32(v_ptr)); + } + vst1q_f32(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + inline void scale_and_store_pa_n_fixed(const int32_t Br_n_fixed, + const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br_n_fixed; ++i) { + dtype_t *o_block_line = o_block + i * head_size * dim_size; + float reciprocal_logsum = 1.0f / logsum[i]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + float32x4_t vec_acc_o = vld1q_f32(acc_o + i * dim_size + j); + vst1q_f32(o_block_line + j, vmulq_f32(vec_acc_o, reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { + o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; + } + } + } + + // --- Decode 模式的辅助函数 (NEON 版本) --- + + inline void init_temp_d(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, + const int32_t dim_size) { + logsum[0] = 0.0f; + scoremax[0] = NEG_INF; + float32x4_t zero_vec = vdupq_n_f32(0.0f); + // Br 在 decode 模式下为 1 + for (int i = 0; i < 1 * dim_size; i += 4) { + vst1q_f32(acc_o + i, zero_vec); + } + } + + inline void mma0_d(const dtype_t *__restrict__ q_block, const dtype_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t kv_stride_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const dtype_t *q_block_line = q_block; // q 只有一个向量 + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + for (; i <= dim_size - 4; i += 4) { + sum_vec = vfmaq_f32(sum_vec, vld1q_f32(q_block_line + i), vld1q_f32(k_block_line + i)); + } + acc_dtype_t total = _vaddvq_f32_hadd(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + + float max_val = scoremax[0]; + for (int bc = 0; bc < Bc; ++bc) { + max_val = fmaxf(max_val, acc_s[bc]); + } + scoremax[0] = max_val; + + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + + float current_sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + inline void rescale_d(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + float32x4_t scale_v = vdupq_n_f32(score_scale[0]); + for (int j = 0; j < dim_size; j += 4) { + float32x4_t acc = vld1q_f32(acc_o + j); + acc = vmulq_f32(acc, scale_v); + vst1q_f32(acc_o + j, acc); + } + } + + inline void mma1_d(const acc_dtype_t *__restrict__ w_block, const dtype_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_c_idx]); + const float *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = vfmaq_f32(acc, w_vec, vld1q_f32(v_ptr)); + } + vst1q_f32(acc_o + d_base, acc); + } + } + + inline void scale_and_store_d(const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, + dtype_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + float reciprocal_logsum = 1.0f / logsum[0]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + vst1q_f32(o_block + j, vmulq_f32(vld1q_f32(acc_o + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { + o_block[j] = acc_o[j] * reciprocal_logsum; + } + } + + // --- Decode N-fixed 函数 (NEON 版本, 逻辑多为标量) --- + inline void mma0_d_n_fixed(const int32_t Bc_n_fixed, const dtype_t *__restrict__ q_block, + const dtype_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const dtype_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float total = 0.0f; + for (int i = 0; i < dim_size; ++i) { total += q_block_line[i] * k_block_line[i]; } + acc_s[b_c_idx] = total; + } + } + + inline void softmax_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_s, + acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, + acc_dtype_t *logsum, + const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = scoremax[0]; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = max_val; + + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + inline void rescale_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + float scale = score_scale[0]; + for (int j = 0; j < dim_size; ++j) { + acc_o[j] *= scale; + } + } + + inline void mma1_d_n_fixed(const int32_t Bc_n_fixed, const acc_dtype_t *__restrict__ w_block, + const dtype_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, const int32_t t_r_idx, + const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; ++d_base) { + float acc = acc_o[d_base]; + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + acc += w_block[b_c_idx] * v_block[b_c_idx * v_stride_size + d_base]; + } + acc_o[d_base] = acc; + } + } + +private: + // 私有成员,用于存储工作空间的指针 + acc_dtype_t *acc_o_; + acc_dtype_t *acc_s_; + acc_dtype_t *logsum_; + acc_dtype_t *scoremax_; + acc_dtype_t *scoremax_prev_; + acc_dtype_t *score_scale_; + acc_dtype_t *score_sum_; +}; +// ======================================== +// FlashAttention2 核心实现 (Q FP32/KV FP16 输入, FP32 输出, NEON 版本) +// 【注意:本版本为完整、未省略代码的版本】 +// ======================================== +struct NEON_FA_2_GQA_Q_FP32_KV_FP16_BSHD_O_FP32_BSHD_ACC_FP32_IMPL { + // 定义不同的输入数据类型 + using dtype_q_in_t = float; + using dtype_kv_in_t = mllm_fp16_t; // K和V是FP16 + using dtype_out_t = float; + using acc_dtype_t = float; + + // 配置参数 + int32_t Br, Bc, Q_Head, KV_Head, threads; + bool high_precision; + + void configure(int32_t Br_, int32_t Bc_, int32_t Q_Head_, int32_t KV_Head_, int32_t threads_, bool high_precision_) { + Br = Br_; + Bc = Bc_; + Q_Head = Q_Head_; + KV_Head = KV_Head_; + threads = threads_; + high_precision = high_precision_; + } + + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + acc_o_ = acc_o; + acc_s_ = acc_s; + logsum_ = logsum; + scoremax_ = scoremax; + scoremax_prev_ = scoremax_prev; + score_scale_ = score_scale; + score_sum_ = score_sum; + } + + // 主函数 fa2, 注意 K 和 V 的类型是 dtype_kv_in_t + void fa2(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, const int32_t batch_size, + const int32_t head_size, const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + assert(Br == Bc); + assert(dim_size % 4 == 0); + assert(Q_Head % KV_Head == 0); + assert(head_size % threads == 0); + + if (seq_size_q != 1) { + __fa2_prefill_append(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } else { + __fa2_decode(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + } + +private: + // 定义一个宏,用于从内存加载4个fp16, 并转换为一个fp32向量 + // 这需要 ARMv8.2-A FP16 指令支持 +#define MLLM_NEON_F32x4_FROM_FP16(addr) vcvt_f32_f16(vld1_f16((const __fp16 *)(addr))) + + // Prefill/Append 主循环 (混合精度) - 完整版 + inline void __fa2_prefill_append(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = seq_size_q / Br; + const int32_t Tr_left = seq_size_q % Br; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + // --- 主循环 (Tr) --- + for (int t_r_idx = 0; t_r_idx < Tr; ++t_r_idx) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0(tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Br, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Br, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Br, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Br, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * Br * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + // --- 处理 Q 序列的剩余部分 (Tr_left) - 完整版 --- + if (Tr_left) { + init_temp(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Bc, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Bc, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Bc, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Bc, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_pa_n_fixed(Tr_left, Tc_left, tile_q, tile_k, tile_acc_s, dim_size, head_size * dim_size, KV_Head * dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_pa_n_fixed(Tr_left, Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_pa_n_fixed(Tr_left, Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, Tr, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_pa_n_fixed(Tr_left, acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + Tr * Br * head_size * dim_size + this_thread_head * dim_size, Tr, head_size, dim_size); + } + } + } + } + + // Decode 主循环 (混合精度) - 完整版 + inline void __fa2_decode(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + const int32_t Tr = 1; + const int32_t Tc = seq_size_k / Bc; + const int32_t Tc_left = seq_size_k % Bc; + + const float local_scale = 1.0f / sqrtf(static_cast(dim_size)); + const int32_t kv_group_size = (Q_Head > 0 && KV_Head > 0) ? Q_Head / KV_Head : 1; + + for (int32_t b_idx = 0; b_idx < batch_size; ++b_idx) { +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) if (threads > 1) + for (int32_t h_idx = 0; h_idx < head_size; ++h_idx) { + const int32_t thread_id = omp_get_thread_num(); + const int32_t this_thread_head = h_idx; + const int32_t this_thread_kv_head = this_thread_head / kv_group_size; + + const int t_r_idx = 0; + init_temp_d(logsum_ + thread_id * Br, scoremax_ + thread_id * Br, acc_o_ + thread_id * Br * dim_size, dim_size); + for (int t_c_idx = 0; t_c_idx < Tc; ++t_c_idx) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + t_c_idx * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_d(tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + softmax_d(tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + rescale_d(acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + mma1_d(tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, t_c_idx, seq_size_q, seq_size_k, causal_mask); + } + if (Tc_left) { + const dtype_q_in_t *tile_q = Q + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size; + const dtype_kv_in_t *tile_k = K + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + const dtype_kv_in_t *tile_v = V + b_idx * seq_size_k * KV_Head * dim_size + Tc * Bc * KV_Head * dim_size + this_thread_kv_head * dim_size; + acc_dtype_t *tile_acc_s = acc_s_ + thread_id * Br * Bc; + acc_dtype_t *acc_o = acc_o_ + thread_id * Br * dim_size; + mma0_d_n_fixed(Tc_left, tile_q, tile_k, tile_acc_s, dim_size, KV_Head * dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + softmax_d_n_fixed(Tc_left, tile_acc_s, scoremax_ + thread_id * Br, scoremax_prev_ + thread_id * Br, score_scale_ + thread_id * Br, score_sum_ + thread_id * Br, logsum_ + thread_id * Br, local_scale, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + rescale_d_n_fixed(Tc_left, acc_o, score_scale_ + thread_id * Br, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + mma1_d_n_fixed(Tc_left, tile_acc_s, tile_v, acc_o, KV_Head, dim_size, t_r_idx, Tc, seq_size_q, seq_size_k, causal_mask); + } + scale_and_store_d(acc_o_ + thread_id * Br * dim_size, logsum_ + thread_id * Br, O + b_idx * seq_size_q * head_size * dim_size + t_r_idx * 1 * head_size * dim_size + this_thread_head * dim_size, t_r_idx, head_size, dim_size); + } + } + } + + // --- 完整版辅助函数 --- + + // (与FP32版本相同) + inline void init_temp(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + float32x4_t zero_vec = vdupq_n_f32(0.0f); + float32x4_t neg_inf_vec = vdupq_n_f32(NEG_INF); + int i = 0; + for (; i <= Br - 4; i += 4) { + vst1q_f32(logsum + i, zero_vec); + vst1q_f32(scoremax + i, neg_inf_vec); + } + for (; i < Br; ++i) { + logsum[i] = 0.0f; + scoremax[i] = NEG_INF; + } + for (int j = 0; j < Br * dim_size; j += 4) { + vst1q_f32(acc_o + j, zero_vec); + } + } + + // 输入Q为FP32,但在函数内动态转为FP16,以使用最高效的 vfmlalq_f16 指令进行计算 + inline void mma0(const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + // 因果掩码的前置检查 + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_r_end = global_r_start + Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + + // 遍历Br x Bc的块 + for (int32_t b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + const dtype_q_in_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + + // 使用多个FP32累加器,借鉴自您提供的参考文件 + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + + int i = 0; + // 主循环,一次处理16个元素 (两个float16x8_t向量) + for (; i <= dim_size - 16; i += 16) { + __builtin_prefetch(q_block_line + i + 64); + __builtin_prefetch(k_block_line + i + 64); + + // 1. 加载16个 FP32 的 Q + float32x4_t q_f32_0 = vld1q_f32(q_block_line + i); + float32x4_t q_f32_1 = vld1q_f32(q_block_line + i + 4); + float32x4_t q_f32_2 = vld1q_f32(q_block_line + i + 8); + float32x4_t q_f32_3 = vld1q_f32(q_block_line + i + 12); + + // 2. 【核心修改】将加载的 FP32 Q 动态转换为 FP16 + float16x8_t q_f16_0 = vcombine_f16(vcvt_f16_f32(q_f32_0), vcvt_f16_f32(q_f32_1)); + float16x8_t q_f16_1 = vcombine_f16(vcvt_f16_f32(q_f32_2), vcvt_f16_f32(q_f32_3)); + + // 3. 加载16个 FP16 的 K + float16x8_t k_f16_0 = vld1q_f16((const __fp16 *)k_block_line + i); + float16x8_t k_f16_1 = vld1q_f16((const __fp16 *)k_block_line + i + 8); + + // 4. 【核心修改】直接对两个FP16向量进行乘加,结果累加到FP32寄存器 + sum0 = vfmlalq_low_f16(sum0, q_f16_0, k_f16_0); + sum0 = vfmlalq_high_f16(sum0, q_f16_0, k_f16_0); + sum1 = vfmlalq_low_f16(sum1, q_f16_1, k_f16_1); + sum1 = vfmlalq_high_f16(sum1, q_f16_1, k_f16_1); + } + + sum0 = vaddq_f32(sum0, sum1); // 合并累加器 + + // 处理剩余的8个元素 + if (i <= dim_size - 8) { + float32x4_t q_f32_0 = vld1q_f32(q_block_line + i); + float32x4_t q_f32_1 = vld1q_f32(q_block_line + i + 4); + float16x8_t q_f16 = vcombine_f16(vcvt_f16_f32(q_f32_0), vcvt_f16_f32(q_f32_1)); + float16x8_t k_f16 = vld1q_f16((const __fp16 *)k_block_line + i); + + sum0 = vfmlalq_low_f16(sum0, q_f16, k_f16); + sum0 = vfmlalq_high_f16(sum0, q_f16, k_f16); + i += 8; + } + + // 水平求和,得到最终的点积结果 + acc_dtype_t total = vaddvq_f32(sum0); + + // 用标量方式处理最后不足8个的元素 + for (; i < dim_size; ++i) { + total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); + } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + + // 应用因果掩码的后处理 + if (causal_mask && (global_r_end == (t_c_idx * Bc + Bc) - delta_pos)) { + for (int i = 0; i < Br; ++i) { + for (int j = 0; j < Bc; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + // (与FP32版本相同) + inline void softmax(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + memcpy(scoremax_prev, scoremax, Br * sizeof(acc_dtype_t)); + for (int br = 0; br < Br; ++br) { + float32x4_t max_vec = vdupq_n_f32(scoremax[br]); + acc_dtype_t *row = acc_s + br * Bc; + int bc = 0; + for (; bc <= Bc - 4; bc += 4) { max_vec = vmaxq_f32(max_vec, vld1q_f32(row + bc)); } + float max_val = _vmaxvq_f32_hmax(max_vec); + for (; bc < Bc; ++bc) { max_val = fmaxf(max_val, row[bc]); } + scoremax[br] = max_val; + } + for (int br = 0; br < Br; ++br) { score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); } + for (int br = 0; br < Br; ++br) { + const float sm = scoremax[br]; + acc_dtype_t *row = acc_s + br * Bc; + float sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((row[bc] - sm) * scale); + row[bc] = val; + sum += val; + } + score_sum[br] = sum; + } + for (int br = 0; br < Br; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + // (与FP32版本相同) + inline void rescale(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + for (int i = 0; i < Br; ++i) { + float32x4_t scale_v = vdupq_n_f32(score_scale[i]); + float *row_ptr = acc_o + i * dim_size; + for (int j = 0; j < dim_size; j += 4) { + float32x4_t acc = vld1q_f32(row_ptr + j); + acc = vmulq_f32(acc, scale_v); + vst1q_f32(row_ptr + j, acc); + } + } + } + + // (混合精度修改版) + inline void mma1(const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br - 1))) return; + const int32_t v_stride_size = kv_head_size * dim_size; + for (int b_r_idx = 0; b_r_idx < Br; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_r_idx * Bc + b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = vfmaq_f32(acc, w_vec, MLLM_NEON_F32x4_FROM_FP16(v_ptr)); + } + vst1q_f32(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + // (与FP32版本相同) + inline void scale_and_store(const acc_dtype_t *__restrict__ acc_o, const acc_dtype_t *__restrict__ logsum, + dtype_out_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br; ++i) { + dtype_out_t *o_block_line = o_block + i * head_size * dim_size; + float reciprocal_logsum = 1.0f / logsum[i]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + vst1q_f32(o_block_line + j, vmulq_f32(vld1q_f32(acc_o + i * dim_size + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; } + } + } + + // (混合精度修改版) + inline void mma0_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t q_stride_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br, global_r_end = global_r_start + Br_n_fixed; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_end - 1))) { return; } + for (int32_t b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + const dtype_q_in_t *q_block_line = q_block + b_r_idx * q_stride_size; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + for (; i <= dim_size - 4; i += 4) { + sum_vec = vfmaq_f32(sum_vec, vld1q_f32(q_block_line + i), MLLM_NEON_F32x4_FROM_FP16(k_block_line + i)); + } + acc_dtype_t total = _vaddvq_f32_hadd(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_r_idx * Bc + b_c_idx] = total; + } + } + if (causal_mask && (global_r_end == (global_c_start + Bc_n_fixed) - delta_pos)) { + for (int i = 0; i < Br_n_fixed; ++i) { + for (int j = 0; j < Bc_n_fixed; ++j) { + if (j > i) { acc_s[i * Bc + j] = NEG_INF; } + } + } + } + } + + // (与FP32版本相同) + inline void softmax_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, + acc_dtype_t *scoremax_prev, acc_dtype_t *score_scale, + acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + memcpy(scoremax_prev, scoremax, Br_n_fixed * sizeof(acc_dtype_t)); + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float max_val = NEG_INF; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, row[bc]); + scoremax[br] = fmaxf(max_val, scoremax[br]); + } + for (int br = 0; br < Br_n_fixed; ++br) { score_scale[br] = expf((scoremax_prev[br] - scoremax[br]) * scale); } + for (int br = 0; br < Br_n_fixed; ++br) { + acc_dtype_t *row = acc_s + br * Bc; + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((row[bc] - scoremax[br]) * scale); + row[bc] = val; + current_sum += val; + } + score_sum[br] = current_sum; + } + for (int br = 0; br < Br_n_fixed; ++br) { logsum[br] = logsum[br] * score_scale[br] + score_sum[br]; } + } + + // (与FP32版本相同) + inline void rescale_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + for (int i = 0; i < Br_n_fixed; ++i) { + float *row_ptr = acc_o + i * dim_size; + float32x4_t scale_v = vdupq_n_f32(score_scale[i]); + for (int j = 0; j < dim_size; j += 4) { + vst1q_f32(row_ptr + j, vmulq_f32(vld1q_f32(row_ptr + j), scale_v)); + } + } + } + + // (混合精度修改版) + inline void mma1_pa_n_fixed(const int32_t Br_n_fixed, const int32_t Bc_n_fixed, + const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t global_r_start = t_r_idx * Br; + const int32_t global_c_start = t_c_idx * Bc; + int delta_pos = seq_size_k - seq_size_q; + if (causal_mask && (global_c_start - delta_pos > (global_r_start + Br_n_fixed - 1))) return; + const int32_t v_stride_size = kv_head_size * dim_size; + for (int b_r_idx = 0; b_r_idx < Br_n_fixed; ++b_r_idx) { + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + b_r_idx * dim_size + d_base); + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_r_idx * Bc + b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = vfmaq_f32(acc, w_vec, MLLM_NEON_F32x4_FROM_FP16(v_ptr)); + } + vst1q_f32(acc_o + b_r_idx * dim_size + d_base, acc); + } + } + } + + // (与FP32版本相同) + inline void scale_and_store_pa_n_fixed(const int32_t Br_n_fixed, const acc_dtype_t *__restrict__ acc_o, + const acc_dtype_t *__restrict__ logsum, dtype_out_t *__restrict__ o_block, + const int32_t t_r_idx, const int32_t head_size, const int32_t dim_size) { + for (int i = 0; i < Br_n_fixed; ++i) { + dtype_out_t *o_block_line = o_block + i * head_size * dim_size; + float reciprocal_logsum = 1.0f / logsum[i]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + vst1q_f32(o_block_line + j, vmulq_f32(vld1q_f32(acc_o + i * dim_size + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { o_block_line[j] = acc_o[i * dim_size + j] * reciprocal_logsum; } + } + } + + // --- Decode 模式函数 (完整版) --- + + // (与FP32版本相同) + inline void init_temp_d(acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *acc_o, const int32_t dim_size) { + logsum[0] = 0.0f; + scoremax[0] = NEG_INF; + float32x4_t zero_vec = vdupq_n_f32(0.0f); + for (int i = 0; i < 1 * dim_size; i += 4) { vst1q_f32(acc_o + i, zero_vec); } + } + + // (混合精度修改版) + inline void mma0_d(const dtype_q_in_t *__restrict__ q_block, const dtype_kv_in_t *__restrict__ k_block, + acc_dtype_t *__restrict__ acc_s, const int32_t dim_size, + const int32_t kv_stride_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + const dtype_q_in_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + for (; i <= dim_size - 4; i += 4) { + sum_vec = vfmaq_f32(sum_vec, vld1q_f32(q_block_line + i), MLLM_NEON_F32x4_FROM_FP16(k_block_line + i)); + } + acc_dtype_t total = _vaddvq_f32_hadd(sum_vec); + for (; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_c_idx] = total; + } + } + + // (与FP32版本相同) + inline void softmax_d(acc_dtype_t *__restrict__ acc_s, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, acc_dtype_t *logsum, + const float scale, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = scoremax[0]; + for (int bc = 0; bc < Bc; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = max_val; + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + // (与FP32版本相同) + inline void rescale_d(acc_dtype_t *__restrict__ acc_o, acc_dtype_t *__restrict__ score_scale, + const int32_t dim_size, const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + float32x4_t scale_v = vdupq_n_f32(score_scale[0]); + for (int j = 0; j < dim_size; j += 4) { + vst1q_f32(acc_o + j, vmulq_f32(vld1q_f32(acc_o + j), scale_v)); + } + } + + // (混合精度修改版) + inline void mma1_d(const acc_dtype_t *__restrict__ w_block, const dtype_kv_in_t *__restrict__ v_block, + acc_dtype_t *__restrict__ acc_o, const int32_t kv_head_size, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; d_base += 4) { + float32x4_t acc = vld1q_f32(acc_o + d_base); + for (int b_c_idx = 0; b_c_idx < Bc; ++b_c_idx) { + float32x4_t w_vec = vdupq_n_f32(w_block[b_c_idx]); + const dtype_kv_in_t *v_ptr = v_block + b_c_idx * v_stride_size + d_base; + acc = vfmaq_f32(acc, w_vec, MLLM_NEON_F32x4_FROM_FP16(v_ptr)); + } + vst1q_f32(acc_o + d_base, acc); + } + } + + // (与FP32版本相同) + inline void scale_and_store_d(const acc_dtype_t *__restrict__ acc_o, const acc_dtype_t *__restrict__ logsum, + dtype_out_t *__restrict__ o_block, const int32_t t_r_idx, + const int32_t head_size, const int32_t dim_size) { + float reciprocal_logsum = 1.0f / logsum[0]; + float32x4_t reciprocal_logsum_vec = vdupq_n_f32(reciprocal_logsum); + int j = 0; + for (; j <= dim_size - 4; j += 4) { + vst1q_f32(o_block + j, vmulq_f32(vld1q_f32(acc_o + j), reciprocal_logsum_vec)); + } + for (; j < dim_size; ++j) { o_block[j] = acc_o[j] * reciprocal_logsum; } + } + + // (混合精度修改版) + inline void mma0_d_n_fixed(const int32_t Bc_n_fixed, const dtype_q_in_t *__restrict__ q_block, + const dtype_kv_in_t *__restrict__ k_block, acc_dtype_t *__restrict__ acc_s, + const int32_t dim_size, const int32_t kv_stride_size, + const int32_t t_r_idx, const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const dtype_q_in_t *q_block_line = q_block; + for (int32_t b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + const dtype_kv_in_t *k_block_line = k_block + b_c_idx * kv_stride_size; + float total = 0.0f; + for (int i = 0; i < dim_size; ++i) { total += q_block_line[i] * MLLM_FP16_TO_FP32(k_block_line[i]); } + acc_s[b_c_idx] = total; + } + } + + // (与FP32版本相同) + inline void softmax_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_s, + acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum, + acc_dtype_t *logsum, const float scale, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, bool causal_mask) { + scoremax_prev[0] = scoremax[0]; + float max_val = scoremax[0]; + for (int bc = 0; bc < Bc_n_fixed; ++bc) max_val = fmaxf(max_val, acc_s[bc]); + scoremax[0] = max_val; + score_scale[0] = expf((scoremax_prev[0] - scoremax[0]) * scale); + float current_sum = 0.0f; + for (int bc = 0; bc < Bc_n_fixed; ++bc) { + float val = expf((acc_s[bc] - scoremax[0]) * scale); + acc_s[bc] = val; + current_sum += val; + } + score_sum[0] = current_sum; + logsum[0] = logsum[0] * score_scale[0] + score_sum[0]; + } + + // (与FP32版本相同) + inline void rescale_d_n_fixed(const int32_t Bc_n_fixed, acc_dtype_t *__restrict__ acc_o, + acc_dtype_t *__restrict__ score_scale, const int32_t dim_size, + const int32_t t_r_idx, const int32_t t_c_idx, + const int32_t seq_size_q, const int32_t seq_size_k, + bool causal_mask) { + float scale = score_scale[0]; + for (int j = 0; j < dim_size; ++j) { acc_o[j] *= scale; } + } + + // (混合精度修改版) + inline void mma1_d_n_fixed(const int32_t Bc_n_fixed, const acc_dtype_t *__restrict__ w_block, + const dtype_kv_in_t *__restrict__ v_block, acc_dtype_t *__restrict__ acc_o, + const int32_t kv_head_size, const int32_t dim_size, const int32_t t_r_idx, + const int32_t t_c_idx, const int32_t seq_size_q, + const int32_t seq_size_k, bool causal_mask) { + const int32_t v_stride_size = kv_head_size * dim_size; + for (int d_base = 0; d_base < dim_size; ++d_base) { + float acc = acc_o[d_base]; + for (int b_c_idx = 0; b_c_idx < Bc_n_fixed; ++b_c_idx) { + acc += w_block[b_c_idx] * MLLM_FP16_TO_FP32(v_block[b_c_idx * v_stride_size + d_base]); + } + acc_o[d_base] = acc; + } + } + +private: + acc_dtype_t *acc_o_; + acc_dtype_t *acc_s_; + acc_dtype_t *logsum_; + acc_dtype_t *scoremax_; + acc_dtype_t *scoremax_prev_; + acc_dtype_t *score_scale_; + acc_dtype_t *score_sum_; +}; + +template +struct FlashAttn2T { +public: + using dtype_q_in_t = typename Impl::dtype_q_in_t; + using dtype_kv_in_t = typename Impl::dtype_kv_in_t; + using dtype_out_t = typename Impl::dtype_out_t; + using acc_dtype_t = typename Impl::acc_dtype_t; + + void configure(int32_t Br, int32_t Bc, int32_t Q_Head, int32_t KV_Head, int32_t threads, bool high_precision) { + impl_.configure(Br, Bc, Q_Head, KV_Head, threads, high_precision); + } + + void init_workspace(acc_dtype_t *acc_o, acc_dtype_t *acc_s, + acc_dtype_t *logsum, acc_dtype_t *scoremax, acc_dtype_t *scoremax_prev, + acc_dtype_t *score_scale, acc_dtype_t *score_sum) { + impl_.init_workspace(acc_o, acc_s, logsum, scoremax, scoremax_prev, score_scale, score_sum); + } + + void operator()(const dtype_q_in_t *__restrict__ Q, const dtype_kv_in_t *__restrict__ K, + const dtype_kv_in_t *__restrict__ V, dtype_out_t *__restrict__ O, + const int32_t batch_size, const int32_t head_size, + const int32_t seq_size_q, const int32_t seq_size_k, + const int32_t dim_size, bool causal_mask = true) { + impl_.fa2(Q, K, V, O, batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + +private: + Impl impl_; +}; + +} // namespace mobi_attn + +void flash_attention_2_forward( + const void *Q, const void *K, const void *V, void *O, + int32_t batch_size, int32_t head_size, int32_t seq_size_q, int32_t seq_size_k, int32_t dim_size, + bool causal_mask, bool use_fp32, int32_t threads, int32_t br, int32_t bc, + int32_t q_head, int32_t kv_head, bool high_precision_exp) { + const size_t acc_o_size = threads * br * dim_size * sizeof(float); + const size_t acc_s_size = threads * br * bc * sizeof(float); + const size_t logsum_size = threads * br * sizeof(float); + const size_t scoremax_size = threads * br * sizeof(float); + const size_t scoremax_prev_size = threads * br * sizeof(float); + const size_t score_scale_size = threads * br * sizeof(float); + const size_t score_sum_size = threads * br * sizeof(float); + + // TODO 改为只分配一次 + void *workspace[7]; + mobi_attn::aligned_alloc(&workspace[0], acc_o_size, 32); + mobi_attn::aligned_alloc(&workspace[1], acc_s_size, 32); + mobi_attn::aligned_alloc(&workspace[2], logsum_size, 32); + mobi_attn::aligned_alloc(&workspace[3], scoremax_size, 32); + mobi_attn::aligned_alloc(&workspace[4], scoremax_prev_size, 32); + mobi_attn::aligned_alloc(&workspace[5], score_scale_size, 32); + mobi_attn::aligned_alloc(&workspace[6], score_sum_size, 32); + + if (use_fp32) { + // 使用纯FP32 NEON实现 + mobi_attn::FlashAttn2T op; + op.configure(br, bc, q_head, kv_head, threads, high_precision_exp); + + op.init_workspace( + static_cast(workspace[0]), static_cast(workspace[1]), + static_cast(workspace[2]), static_cast(workspace[3]), + static_cast(workspace[4]), static_cast(workspace[5]), + static_cast(workspace[6])); + + op(static_cast(Q), static_cast(K), static_cast(V), + static_cast(O), + batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } else { + // 使用FP16输入,FP32输出的NEON实现 + mobi_attn::FlashAttn2T op; + op.configure(br, bc, q_head, kv_head, threads, high_precision_exp); + + op.init_workspace( + static_cast(workspace[0]), static_cast(workspace[1]), + static_cast(workspace[2]), static_cast(workspace[3]), + static_cast(workspace[4]), static_cast(workspace[5]), + static_cast(workspace[6])); + + op(static_cast(Q), static_cast(K), static_cast(V), + static_cast(O), + batch_size, head_size, seq_size_q, seq_size_k, dim_size, causal_mask); + } + + for (void *ptr : workspace) { + if (ptr) mobi_attn::aligned_free(ptr); + } +} + +#endif // AVX2 + +#endif // MLLM_FA2_CAL_HPP \ No newline at end of file diff --git a/src/backends/cpu/compute/VecDot.hpp b/src/backends/cpu/compute/VecDot.hpp index 4862a3633..f96e7a2d8 100644 --- a/src/backends/cpu/compute/VecDot.hpp +++ b/src/backends/cpu/compute/VecDot.hpp @@ -225,7 +225,7 @@ #define MLLM_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) #define MLLM_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else -static inline __m256 __avx_f32cx8_load(MLLM_fp16_t *x) { +static inline __m256 __avx_f32cx8_load(mllm_fp16_t *x) { float tmp[8]; for (int i = 0; i < 8; i++) { @@ -234,7 +234,7 @@ static inline __m256 __avx_f32cx8_load(MLLM_fp16_t *x) { return _mm256_loadu_ps(tmp); } -static inline void __avx_f32cx8_store(MLLM_fp16_t *x, __m256 y) { +static inline void __avx_f32cx8_store(mllm_fp16_t *x, __m256 y) { float arr[8]; _mm256_storeu_ps(arr, y); diff --git a/src/backends/cpu/function/CPUClipFunc.hpp b/src/backends/cpu/function/CPUClipFunc.hpp index 08e6d9e94..d43cb7c72 100644 --- a/src/backends/cpu/function/CPUClipFunc.hpp +++ b/src/backends/cpu/function/CPUClipFunc.hpp @@ -132,6 +132,21 @@ class CPUclipFunction : public TensorFunction { } } } + } else if (d.size() == 1) { + int seq_idx = d[0]; + if (seq_idx < 0) { + seq_idx = inputs[0]->dimension() + seq_idx; + } +#pragma omp parallel for collapse(1) num_threads(CPUBackend::cpu_threads) + for (int b = 0; b < inputs[0]->batch(); ++b) { + for (int s = 0; s < inputs[0]->sequence(); ++s) { + for (int h = 0; h < inputs[0]->head(); ++h) { + memcpy(outputs[0]->hostPtr() + outputs[0]->offset(b, h, s, 0), + inputs[0]->hostPtr() + inputs[0]->offset(b, h, s, seq_idx), + sizeof(float)); + } + } + } } else { std::cout << "[TODO]Tensor.CLip not support!!!!" << std::endl; } @@ -266,6 +281,28 @@ class CPUcliptensorFunction : public TensorFunction { void execute(vector> outputs, vector> inputs, vector args) override { Chl dim = (Chl)args[0]; if (dim == SEQUENCE) { + if (inputs[0]->ctype() == BHDS) { + outputs[0]->chls() = inputs[0]->chls(); + outputs[0]->setCtype(BHDS); + int new_seq = inputs[1]->dimension(); + if (outputs[0]->sequence() == 0 || outputs[0]->shape().empty() + || new_seq != outputs[0]->sequence()) { + outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), new_seq, inputs[0]->dimension()); + outputs[0]->alloc(); + } + +#pragma omp parallel for collapse(3) num_threads(CPUBackend::cpu_threads) + for (int b = 0; b < inputs[0]->batch(); ++b) { + for (int d = 0; d < inputs[0]->dimension(); ++d) { + for (int s = 0; s < new_seq; ++s) { + auto selected_idx = (int)inputs[1]->dataAt(0, 0, 0, s); + outputs[0]->setDataAt(b, 0, s, d, + inputs[0]->dataAt(b, 0, selected_idx, d)); + } + } + } + return; + } int new_seq = inputs[1]->dimension(); if (outputs[0]->sequence() == 0 || outputs[0]->shape().empty() || new_seq != outputs[0]->sequence()) { diff --git a/src/backends/cpu/function/CPUFlashAttention2Func.hpp b/src/backends/cpu/function/CPUFlashAttention2Func.hpp new file mode 100644 index 000000000..6ae558982 --- /dev/null +++ b/src/backends/cpu/function/CPUFlashAttention2Func.hpp @@ -0,0 +1,74 @@ +// +// Created by Rongjie Yi on 25-2-16. +// + +#ifndef CPUFA2FUNC_HPP +#define CPUFA2FUNC_HPP +#include "CPUBackend.hpp" +#include "Tensor.hpp" +#include "Types.hpp" +#include "../compute/FlashAttention2.hpp" + +namespace mllm { +class Tensor; + +class CPUFlashAttention2Func : public TensorFunction { +public: + void reshape(vector> outputs, vector> inputs, vector args) override { + auto q_tensor = inputs[0]; + auto k_tensor = inputs[1]; + auto v_tensor = inputs[2]; + auto o_tensor = outputs[0]; + int batch_size = q_tensor->batch(); + int q_head = q_tensor->head(); + int q_sequence = q_tensor->sequence(); + int dimension = q_tensor->dimension(); + o_tensor->reshape(batch_size, q_head, q_sequence, dimension); + o_tensor->setDtype(inputs[0]->dtype()); + o_tensor->alloc(); + } + void execute(vector> outputs, vector> inputs, vector args) override { + auto q_tensor = inputs[0]; + auto k_tensor = inputs[1]; + auto v_tensor = inputs[2]; + auto o_tensor = outputs[0]; + bool causal_mask = (bool)args[0]; + int batch_size = q_tensor->batch(); + int q_head = q_tensor->head(); + int q_sequence = q_tensor->sequence(); + int dimension = q_tensor->dimension(); + int k_head = k_tensor->head(); + int k_sequence = k_tensor->sequence(); + int v_head = v_tensor->head(); + int v_sequence = v_tensor->sequence(); + assert(v_head == k_head && v_sequence == k_sequence); + bool kv_use_fp32 = k_tensor->dtype() == MLLM_TYPE_F32 ? true : false; // x86只支持FP32 + int threads = CPUBackend::cpu_threads; + if (threads > v_head) { + threads = v_head; // 线程数不能超过头数 + } + int32_t br = q_sequence >= 4 ? 4 : 1; + int32_t bc = q_sequence >= 4 ? 4 : 1; + constexpr bool high_precision_exp = false; + // q_tensor->saveData(); + // k_tensor->saveData(); + // v_tensor->saveData(); + // GQA is not ready + flash_attention_2_forward( + q_tensor->hostPtr(), k_tensor->hostPtr(), v_tensor->hostPtr(), + o_tensor->hostPtr(), // 输入输出张量 + batch_size, q_head, q_sequence, k_sequence, dimension, // 基本维度 + causal_mask, // 使用因果掩码 + kv_use_fp32, // 使用FP32(x86必须) + threads, // 使用4线程 + br, // 查询分块大小64 + bc, // 键值分块大小128 + q_head, // 查询头数12 + k_head, // 键值头数4 + high_precision_exp // 使用快速指数近似 + ); + // o_tensor->saveData(); + } +}; +} // namespace mllm +#endif // CPUFA2FUNC_HPP \ No newline at end of file diff --git a/src/backends/cpu/function/CPUApplyVisionRoPE.hpp b/src/backends/cpu/function/CPUVisionRoPEFunc.hpp similarity index 99% rename from src/backends/cpu/function/CPUApplyVisionRoPE.hpp rename to src/backends/cpu/function/CPUVisionRoPEFunc.hpp index 121cf71a8..9b1337a10 100644 --- a/src/backends/cpu/function/CPUApplyVisionRoPE.hpp +++ b/src/backends/cpu/function/CPUVisionRoPEFunc.hpp @@ -11,7 +11,7 @@ namespace mllm { class Tensor; -class CPUApplyVisionRoPEFunction : public TensorFunction { +class CPUVisionRoPEFuncFunction : public TensorFunction { void rope_hf(shared_ptr input, shared_ptr rotary_pos_emb, shared_ptr output, int thread_count = 4) { auto out_dtype = output->dtype(); diff --git a/src/backends/cpu/op/CPUElasticLinear.cpp b/src/backends/cpu/op/CPUElasticLinear.cpp index b2f9462cf..8314a0362 100644 --- a/src/backends/cpu/op/CPUElasticLinear.cpp +++ b/src/backends/cpu/op/CPUElasticLinear.cpp @@ -69,32 +69,6 @@ ErrorCode CPUElasticLinear::execute(vector> inputs, vector> inputs, vector> outputs) { diff --git a/src/backends/cpu/op/CPUKVCache.cpp b/src/backends/cpu/op/CPUKVCache.cpp index c6985b017..16f610722 100644 --- a/src/backends/cpu/op/CPUKVCache.cpp +++ b/src/backends/cpu/op/CPUKVCache.cpp @@ -6,8 +6,9 @@ int n_pack = 16; namespace mllm { -CPUKVCache::CPUKVCache(Backend *bn, string opName, int hidden, int head, int n_rep, int cache_max, int threadCount) : +CPUKVCache::CPUKVCache(Backend *bn, string opName, int hidden, int head, int n_rep, bool fa2, int cache_max, int threadCount) : thread_count(threadCount), Op(bn, opName) { + fa2_ = fa2; cache_.setBackend(bn); switch (KVCache_TYPE) { case 16: { @@ -32,10 +33,13 @@ CPUKVCache::CPUKVCache(Backend *bn, string opName, int hidden, int head, int n_r break; } } -// #endif + if (!fa2) { // not fa2 #ifdef LLAMAFILE_SGEMM - cache_max = ((cache_max + (n_pack - 1)) / n_pack) * n_pack; + cache_max = ((cache_max + (n_pack - 1)) / n_pack) * n_pack; #endif + } else { // fa2 + n_rep = 1; + } cache_limit_ = cache_max; n_rep_ = n_rep; if (head > 0) { @@ -93,7 +97,9 @@ ErrorCode CPUKVCache::reshape(vector> inputs, } int sequence = inputs[0]->sequence() + cache_seq_len_; #ifdef LLAMAFILE_SGEMM - if (!for_xnn_ && sequence % n_pack != 0) sequence = ((sequence + (n_pack - 1)) / n_pack) * n_pack; + if (!fa2_) { + if (!for_xnn_ && sequence % n_pack != 0) sequence = ((sequence + (n_pack - 1)) / n_pack) * n_pack; + } #endif outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head() * n_rep_, sequence, inputs[0]->dimension()); diff --git a/src/backends/cpu/op/CPUKVCache.hpp b/src/backends/cpu/op/CPUKVCache.hpp index 4e06c3989..7ff67c305 100644 --- a/src/backends/cpu/op/CPUKVCache.hpp +++ b/src/backends/cpu/op/CPUKVCache.hpp @@ -10,7 +10,7 @@ namespace mllm { class CPUKVCache final : public Op { public: - CPUKVCache(Backend *bn, string opName, int hidden, int head, int n_rep, int cache_max = 100, int threadCount = 4); + CPUKVCache(Backend *bn, string opName, int hidden, int head, int n_rep, bool fa2, int cache_max = 100, int threadCount = 4); virtual ~CPUKVCache() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode load(AbstructLoader &loader) override; @@ -40,6 +40,8 @@ class CPUKVCache final : public Op { bool for_xnn_ = false; int cache_limit_; + + bool fa2_ = false; // not_fa2 }; class CPUKVCacheCreator : public CPUBackend::Creator { @@ -50,7 +52,8 @@ class CPUKVCacheCreator : public CPUBackend::Creator { bool for_xnn = (bool)op_param["for_xnn"]; int hidden = (int)op_param["hidden"]; int head = (int)op_param["head"]; - auto ret = new CPUKVCache(bn, name, hidden, head, n_rep, cache_max, threadCount); + bool fa2 = (bool)op_param["fa2"]; + auto ret = new CPUKVCache(bn, name, hidden, head, n_rep, fa2, cache_max, threadCount); ret->setForXnn(for_xnn); return ret; } diff --git a/src/backends/cpu/op/CPULinearINT8Shadow.cpp b/src/backends/cpu/op/CPULinearINT8Shadow.cpp index f510609c4..461612002 100755 --- a/src/backends/cpu/op/CPULinearINT8Shadow.cpp +++ b/src/backends/cpu/op/CPULinearINT8Shadow.cpp @@ -155,9 +155,9 @@ ErrorCode CPULinearINT8Shadow::execute(vector> inputs, vector int8_t output_clip = outputClip_.dataAt(0, 0, 0, 0); input_scale = input_scale / 127.0; - input_scale = roundf(input_scale * 100000) / 100000; + // input_scale = roundf(input_scale * 100000) / 100000; - output_scale = roundf(output_scale * 100000) / 100000; + // output_scale = roundf(output_scale * 100000) / 100000; memcpy(outputs[0]->hostPtr(), inputs[2]->hostPtr(), inputs[2]->cntSize()); @@ -173,7 +173,7 @@ ErrorCode CPULinearINT8Shadow::execute(vector> inputs, vector for (int j = 0; j < input0_buffer_.sequence(); j++) { for (int k = 0; k < input0_buffer_.dimension(); k++) { float round_value = roundf(input0_buffer_.dataAt(i, h, j, k) / input_scale); - if (round_value > (127.0 * 8) || round_value < (-128.0 * 8)) { + if (round_value > (127.0) || round_value < (-128.0)) { #if defined(__ARM_NEON) float origin_value = round_value * input_scale * weight_scale; float clip_value = std::fmax(std::fmin(round_value, 127), -128) * input_scale * weight_scale; @@ -213,7 +213,7 @@ ErrorCode CPULinearINT8Shadow::execute(vector> inputs, vector } #else - mllm_fp16_t origin_value = round_value * input_scale * weight_scale; + float origin_value = round_value * input_scale * weight_scale; float clip_value = std::fmax(std::fmin(round_value, 127), -128) * input_scale * weight_scale; #pragma omp parallel for collapse(1) num_threads(4) diff --git a/src/backends/cpu/op/CPULinearInt8.cpp b/src/backends/cpu/op/CPULinearInt8.cpp index fdbc09f5a..fa0431007 100644 --- a/src/backends/cpu/op/CPULinearInt8.cpp +++ b/src/backends/cpu/op/CPULinearInt8.cpp @@ -138,7 +138,7 @@ ErrorCode CPULinearInt8::free(vector> inputs, vector()[0] / 127.0; - scale1 = roundf(scale1 * 100000) / 100000; + // scale1 = roundf(scale1 * 100000) / 100000; float scale2 = weightScale_.hostPtr()[0]; @@ -147,7 +147,7 @@ ErrorCode CPULinearInt8::mat_mul_fp32_i8(Tensor *src0_, Tensor *src1, Tensor *ds scale3 = biasScale_.hostPtr()[0]; float scale4 = outputActivatationScale_.hostPtr()[0] / 127.0; - scale4 = roundf(scale4 * 100000) / 100000; + // scale4 = roundf(scale4 * 100000) / 100000; assert(src1->dtype() == MLLM_TYPE_I8); assert(src0_->dtype() == MLLM_TYPE_F32); diff --git a/src/backends/cpu/op/CPUMultimodalRoPE.cpp b/src/backends/cpu/op/CPUMultimodalRoPE.cpp index 4985a6014..c00f4ceb7 100644 --- a/src/backends/cpu/op/CPUMultimodalRoPE.cpp +++ b/src/backends/cpu/op/CPUMultimodalRoPE.cpp @@ -132,12 +132,7 @@ ErrorCode CPUMultimodalRoPE::reshape(vector> inputs, vector(backend_); - if (cpuBackend->isStageSwitching()) { - h_cnt_ = cpuBackend->getCurSequenceLength(); - } -#endif + return Op::reshape(inputs, outputs); } @@ -317,10 +312,7 @@ ErrorCode CPUMultimodalRoPE::doExecute(vector> inputs, vector } } } - h_cnt_ += input->sequence(); - if (h_cnt_ >= pos_max_) { - h_cnt_ = 0; - } + return Op::execute(inputs, outputs); } diff --git a/src/backends/cpu/op/CPUMultimodalRoPEPipeline.cpp b/src/backends/cpu/op/CPUMultimodalRoPEPipeline.cpp new file mode 100644 index 000000000..c23b55f27 --- /dev/null +++ b/src/backends/cpu/op/CPUMultimodalRoPEPipeline.cpp @@ -0,0 +1,330 @@ + +#include "CPUMultimodalRoPEPipeline.hpp" +// #include "Timing.hpp" +#include "Types.hpp" +#include +#include +#include +// #include +#include "backends/cpu/quantize/QuantizeQ8.hpp" + +namespace mllm { + +vector CPUMultimodalRoPEPipeline::theta_; // inv_freq + +vector> CPUMultimodalRoPEPipeline::sin_; +vector> CPUMultimodalRoPEPipeline::cos_; +int CPUMultimodalRoPEPipeline::ishape_old; +int CPUMultimodalRoPEPipeline::last_pos; + +// to avoid conflict with CPUMultimodalRoPE +namespace pipeline_rope { +typedef float (*mllm_rope_init_func)(const OpParam &, std::vector &); + +float multimodal_default_init_rope(const OpParam &config, vector &theta) { + auto base = config.at("base"); // theta_i = base^-(2i/dim) = 1 / base^(2i/dim) i from 0 to (dim/2 - 1) + auto dim = config.at("dim"); + + theta.resize((int)(dim / 2)); +#pragma omp parallel for num_threads(4) + for (int i = 0; i < theta.size(); i++) + theta[i] = 1.0 / pow(base, 2.0 * i / dim); + + return 1.0; +} + +void apply_multimodal_rotary_pos_emb( + const std::vector>> &in_cos, + const std::vector>> &in_sin, + std::vector> &out_cos, + std::vector> &out_sin, + const std::vector &mrope_section) { + int num_rows = in_cos[0].size(); + int num_cols = in_cos[0][0].size(); + // 初始化输出向量大小 + out_cos.resize(num_rows, std::vector(num_cols)); + out_sin.resize(num_rows, std::vector(num_cols)); + // 计算每个块的起始列索引 + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : mrope_section) { + current_start += s; + start_cols.push_back(current_start); + } + // 遍历每个块 + for (int j = 0; j < mrope_section.size(); ++j) { + int layer = j % 3; + int s_j = mrope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; // 输出和输入的起始列相同 + for (int row = 0; row < num_rows; ++row) { + // 处理cos + const auto &in_cos_row = in_cos[layer][row]; + auto &out_cos_row = out_cos[row]; + for (int c = 0; c < s_j; ++c) { + out_cos_row[start_col_out + c] = in_cos_row[start_col_in + c]; + } + // 处理sin + const auto &in_sin_row = in_sin[layer][row]; + auto &out_sin_row = out_sin[row]; + for (int c = 0; c < s_j; ++c) { + out_sin_row[start_col_out + c] = in_sin_row[start_col_in + c]; + } + } + } +} + +void multimodal_sinusoidal_position_embedding(shared_ptr position_ids, int seq_len, int output_dim, const vector &theta, + vector> &sin, vector> &cos, float attention_scaling = 1.0, + const std::vector &mrope_section = {}) { + vector>> tmp_sin; + vector>> tmp_cos; + for (int b = 0; b < position_ids->batch(); ++b) { + vector> cos_freqs(position_ids->dimension(), std::vector(theta.size() * 2, 0)); + vector> sin_freqs(position_ids->dimension(), std::vector(theta.size() * 2, 0)); + for (int i = 0; i < theta.size(); ++i) { + for (int j = 0; j < position_ids->dimension(); ++j) { + auto value = theta[i] * position_ids->dataAt(b, 0, 0, j); + cos_freqs[j][i] = cosf(value) * attention_scaling; + cos_freqs[j][i + theta.size()] = cosf(value) * attention_scaling; + sin_freqs[j][i] = sinf(value) * attention_scaling; + sin_freqs[j][i + theta.size()] = sinf(value) * attention_scaling; + } + } + tmp_cos.push_back(cos_freqs); + tmp_sin.push_back(sin_freqs); + } + if (!mrope_section.empty()) { + apply_multimodal_rotary_pos_emb(tmp_cos, tmp_sin, cos, sin, mrope_section); + } +} +} // namespace pipeline_rope + +CPUMultimodalRoPEPipeline::CPUMultimodalRoPEPipeline(Backend *bn, string opName, float rope_theta, int max_position_embeddings, vector mrope_section, int threadCount) : + thread_count(threadCount), + Op(bn, opName) { + rope_theta_ = rope_theta; + pos_max_ = max_position_embeddings; + mrope_section_ = mrope_section; + for (int i = 0; i < mrope_section.size(); i++) { + mrope_section_.push_back(mrope_section[i]); + } +} + +ErrorCode CPUMultimodalRoPEPipeline::reshape(vector> inputs, vector> outputs) { + assert(inputs.size() == 2); + assert(outputs.size() == 1); + outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension()); + ishape = inputs[0]->dimension() * partial_rotary_factor_; + // pos_max_ = 16384; + auto position_ids = inputs[1]; + + if (sin_.empty() || ishape_old < ishape || position_ids->dataAt(0, 0, 0, position_ids->dimension() - 1) != last_pos) { + auto config = config_; + config["base"] = (float)rope_theta_; + config["dim"] = ishape; + float attention_scaling = pipeline_rope::multimodal_default_init_rope(config, theta_); + ishape_old = ishape; + last_pos = position_ids->dataAt(0, 0, 0, position_ids->dimension() - 1); + pipeline_rope::multimodal_sinusoidal_position_embedding(position_ids, pos_max_, ishape, theta_, sin_, cos_, attention_scaling, mrope_section_); + } + + // if in switching, reset the h_cnt_ + auto cpuBackend = static_cast(backend_); + if (cpuBackend->isStageSwitching()) { + if(cpuBackend->getExecutionType() == PROMPT) { + // set to 0/chunk_size*iter when in prefill stage + h_cnt_ = cpuBackend->getCurSequenceLength(); + } else { + // when switch to decoding, reset the h_cnt_ to 0 + h_cnt_ = 0; + } + + } + return Op::reshape(inputs, outputs); +} + +void CPUMultimodalRoPEPipeline::multimodal_rope_hf(shared_ptr input, shared_ptr output) { + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; + int half = (int)(partial_dimension / 2); + assert(partial_dimension % 2 == 0); + + const int seq_offset = h_cnt_; + if (static_cast(backend_)->getExecutionType() == PROMPT) { + // increment the h_cnt_ when in prefill stage + h_cnt_ += input->sequence(); + } + + if (output->ctype() == BSHD) { + if (input->dtype() == MLLM_TYPE_F16) { +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = static_cast(v[0]); + float in_value_2 = static_cast(v[half]); + float sin_value = sin_[s + seq_offset][d]; + float cos_value = cos_[s + seq_offset][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + + } else { + if (out_dtype == MLLM_TYPE_F32) { +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + seq_offset][d]; + float cos_value = cos_[s + seq_offset][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = value; + o[half] = value2; + } + } + } + } + } else if (out_dtype == MLLM_TYPE_F16) { +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + seq_offset][d]; + float cos_value = cos_[s + seq_offset][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + } + } + return; + } +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension / 2; ++d) { + if (input->dtype() == MLLM_TYPE_F16) { + float in_value = static_cast(input->dataAt(n, h, s, d)); + float in_value_2 = static_cast(input->dataAt(n, h, s, d + partial_dimension / 2)); + float sin_value = sin_[s + seq_offset][d]; + float cos_value = cos_[s + seq_offset][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d + partial_dimension / 2, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); + } + + } else { + float in_value = input->dataAt(n, h, s, d); + float in_value_2 = input->dataAt(n, h, s, d + partial_dimension / 2); + float sin_value = sin_[s + seq_offset][d]; + float cos_value = cos_[s + seq_offset][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d + partial_dimension / 2, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); + } + } + } + } + } + } +} + +// TODO: Q8_0 KVCache can not use!! +ErrorCode CPUMultimodalRoPEPipeline::execute(vector> inputs, vector> outputs) { + if (outputs[0]->dtype() == MLLM_TYPE_Q8_0) { + auto tmp_out = std::make_shared(outputs[0]->backend()); + // tmp_out->setBackend(outputs[0]->backend()); + auto b = outputs[0]->batch(); + auto h = outputs[0]->head(); + auto d = outputs[0]->dimension(); + auto s = outputs[0]->sequence(); + tmp_out->chls() = outputs[0]->chls(); + tmp_out->setCtype(outputs[0]->ctype()); + tmp_out->reshape(b, h, s, d); + tmp_out->setDtype(MLLM_TYPE_F32); + tmp_out->alloc(); + doExecute(inputs, {tmp_out}); +#pragma omp parallel for collapse(3) num_threads(thread_count) + for (int b = 0; b < tmp_out->batch(); b++) { + for (int h = 0; h < tmp_out->head(); h++) { + for (int s = 0; s < tmp_out->sequence(); s++) { + quantize_row_q8_0(tmp_out->hostPtr() + tmp_out->offset(b, h, s, 0), + (char *)outputs[0]->rawHostPtr() + + outputs[0]->offset(b, h, s, 0) * sizeof(block_q8_0) / QK8_0, + tmp_out->dimension()); + } + } + } + return MLLM_NO_ERROR; + } else { + return doExecute(inputs, outputs); + } +} +ErrorCode CPUMultimodalRoPEPipeline::doExecute(vector> inputs, vector> outputs) { + auto &input = inputs[0]; + auto &output = outputs[0]; + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; + // auto start_t = mllm_time_us(); + multimodal_rope_hf(input, output); +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { + for (int d = partial_dimension; d < input->dimension(); ++d) { + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, input->dataAt(n, h, s, d)); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(input->dataAt(n, h, s, d))); + } + } + } + } + } + + return Op::execute(inputs, outputs); +} + +ErrorCode CPUMultimodalRoPEPipeline::load(AbstructLoader &loader) { + return Op::load(loader); +} +ErrorCode CPUMultimodalRoPEPipeline::free(vector> inputs, vector> outputs) { + return Op::free(inputs, outputs); +} +} // namespace mllm diff --git a/src/backends/cpu/op/CPUMultimodalRoPEPipeline.hpp b/src/backends/cpu/op/CPUMultimodalRoPEPipeline.hpp new file mode 100644 index 000000000..9b37e0948 --- /dev/null +++ b/src/backends/cpu/op/CPUMultimodalRoPEPipeline.hpp @@ -0,0 +1,72 @@ +#ifndef MLLM_CPUMULTIMODALROPE_PIPELINE_H +#define MLLM_CPUMULTIMODALROPE_PIPELINE_H + +#include "Op.hpp" +#include "../CPUBackend.hpp" + +namespace mllm { + +class CPUMultimodalRoPEPipeline final : public Op { +public: + CPUMultimodalRoPEPipeline(Backend *bn, string opName, float rope_theta, int max_position_embeddings, vector mrope_section, int threadCount); + + virtual ~CPUMultimodalRoPEPipeline() = default; + virtual ErrorCode reshape(vector> inputs, vector> outputs) override; + virtual ErrorCode load(AbstructLoader &loader) override; + virtual ErrorCode execute(vector> inputs, vector> outputs) override; + virtual ErrorCode free(vector> inputs, vector> outputs) override; + ErrorCode doExecute(vector> inputs, vector> outputs); + +private: + static vector theta_; // inv_freq + static vector> sin_; + static vector> cos_; + static int ishape_old; + static int last_pos; + vector mrope_section_; + int rope_theta_ = 10000; + int h_cnt_ = 0; + int pos_max_ = 16384; + int ishape; + int thread_count = 4; + float partial_rotary_factor_ = 1; + + OpParam config_; + + RoPEThetaType rope_type = DEFAULT; + + void multimodal_rope_hf(shared_ptr input, shared_ptr output); + void clearCache() override { + h_cnt_ = 0; + } +}; + +class CPUMultimodalRoPEPipelineCreator : public CPUBackend::Creator { +public: + virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const { + // int pose_type = op_param["pose_type"]; + // if (op_param.find("rope_theta") == op_param.end()) { + // return new CPUMultimodalRoPEPipeline(bn, name, pose_type, threadCount); + // } + // float rope_theta = op_param["rope_theta"]; + // int max_position_embeddings = op_param["max_position_embeddings"]; + // if (op_param.find("partial_rotary_factor") == op_param.end()) { + // return new CPUMultimodalRoPEPipeline(bn, name, pose_type, rope_theta, max_position_embeddings, threadCount); + // } + // float partial_rotary_factor = op_param["partial_rotary_factor"]; + // return new CPUMultimodalRoPEPipeline(bn, name, pose_type, rope_theta, partial_rotary_factor, max_position_embeddings, threadCount); + + // int pose_type = op_param["pose_type"]; + float rope_theta = op_param["rope_theta"]; + int max_position_embeddings = op_param["max_position_embeddings"]; + int length = op_param.size() - 3; + vector mrope_section; + for (int i = 0; i < length; i++) { + mrope_section.push_back((int)op_param["mrope_section_" + std::to_string(i)]); + } + return new CPUMultimodalRoPEPipeline(bn, name, rope_theta, max_position_embeddings, mrope_section, threadCount); + } +}; +} // namespace mllm + +#endif // MLLM_CPUMULTIMODALROPE_PIPELINE_H \ No newline at end of file diff --git a/src/backends/cpu/op/CPUQuantize.cpp b/src/backends/cpu/op/CPUQuantize.cpp index fa4060f8f..11207ea9d 100644 --- a/src/backends/cpu/op/CPUQuantize.cpp +++ b/src/backends/cpu/op/CPUQuantize.cpp @@ -6,11 +6,16 @@ #include "Types.hpp" #include "backends/cpu/quantize/QuantizeQ8.hpp" +#include +#include #include namespace mllm { -CPUQuantize::CPUQuantize(Backend *bn, string opName, int threadCount):thread_count(threadCount), Op(bn, std::move(opName)) { - activation_dtype_ = MLLM_TYPE_I8; +CPUQuantize::CPUQuantize(Backend *bn, string opName, DataType type, int threadCount) : + thread_count(threadCount), + Op(bn, std::move(opName)) { + assert(type == MLLM_TYPE_I8 || type == MLLM_TYPE_I16); + activation_dtype_ = type; scale_.setBackend(bn); } @@ -30,46 +35,53 @@ ErrorCode CPUQuantize::execute(vector> inputs, vectordimension(); float quantScale = 0; - quantScale = scale_.hostPtr()[0] / 127.0; - quantScale = roundf(quantScale * 100000) / 100000; + // quantScale = scale_.hostPtr()[0] / 127.0; + // quantScale = roundf(quantScale * 100000) / 100000; + switch (activation_dtype_) { + case MLLM_TYPE_I8: + quantScale = scale_.hostPtr()[0] / (pow(2, 7) - 1); + break; + case MLLM_TYPE_I16: + quantScale = scale_.hostPtr()[0] / (pow(2, 15) - 1); + break; + default: + return NOT_SUPPORT; + } + // quantScale = roundf(quantScale * 100000) / 100000; auto src0 = inputs[0]; - auto src0_i8 = outputs[0]; - -// #pragma omp parallel for collapse(4) - // for (int b = 0; b dataAt(b, h, s, d); - // int32_t v = static_cast(Round(value / quantScale)); - // v = std::max (std::min(v, 127), -128); - // output->setDataAt(b, h, s, d, static_cast(v)); - // } - // } - // std::cout << std::endl; - // } - // } + auto out0 = outputs[0]; + if (activation_dtype_ == MLLM_TYPE_I8) { #pragma omp parallel for collapse(3) num_threads(thread_count) - for (int b = 0; b < batch; b++) { - for (int h = 0; h hostPtr() + src0->offset(b, h, s, 0), - src0_i8->hostPtr() + src0_i8->offset(b, h, s, 0), + for (int b = 0; b < batch; b++) { + for (int h = 0; h < head; h++) { + for (int s = 0; s < seq; s++) { + quantize_row_i8(src0->hostPtr() + src0->offset(b, h, s, 0), + out0->hostPtr() + out0->offset(b, h, s, 0), dim, quantScale); + } + } + } + } else if (activation_dtype_ == MLLM_TYPE_I16) { +#pragma omp parallel for collapse(3) num_threads(thread_count) + for (int b = 0; b < batch; b++) { + for (int h = 0; h < head; h++) { + for (int s = 0; s < seq; s++) { + quantize_row_i16(src0->hostPtr() + src0->offset(b, h, s, 0), + out0->hostPtr() + out0->offset(b, h, s, 0), + dim, quantScale); + } } } + } else { + return NOT_SUPPORT; } - // outputs[0]->printData(); - - return Op::execute(inputs, outputs); } ErrorCode CPUQuantize::setUp(vector> inputs, vector> outputs) { - activation_dtype_ = MLLM_TYPE_I8; return Op::setUp(inputs, outputs); } diff --git a/src/backends/cpu/op/CPUQuantize.hpp b/src/backends/cpu/op/CPUQuantize.hpp index df3751ee1..13df2b2e0 100644 --- a/src/backends/cpu/op/CPUQuantize.hpp +++ b/src/backends/cpu/op/CPUQuantize.hpp @@ -7,10 +7,11 @@ #include "Op.hpp" #include "../CPUBackend.hpp" +#include "Types.hpp" namespace mllm { class CPUQuantize final : public Op { public: - CPUQuantize(Backend *bn, string opName, int threadCount); + CPUQuantize(Backend *bn, string opName, DataType type, int threadCount); virtual ~CPUQuantize() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode execute(vector> inputs, vector> outputs) override; @@ -36,7 +37,7 @@ class CPUQuantize final : public Op { class CPUQuantizeCreator : public CPUBackend::Creator { public: virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const { - return new CPUQuantize(bn, name, threadCount); + return new CPUQuantize(bn, name, (DataType)op_param["dtype"], threadCount); } }; } // namespace mllm diff --git a/src/backends/cpu/quantize/QuantizeQ8.cpp b/src/backends/cpu/quantize/QuantizeQ8.cpp index 44e16292d..35f51f7e8 100644 --- a/src/backends/cpu/quantize/QuantizeQ8.cpp +++ b/src/backends/cpu/quantize/QuantizeQ8.cpp @@ -541,4 +541,84 @@ void quantize_round_dequantize_row_i8(const float *__restrict vx, float *__restr } } +// per-tensor int16 +void quantize_row_i16(const float *__restrict x, void *__restrict vy, int k, float scale) { + const int BLOCK_SIZE = 32; + assert(k % BLOCK_SIZE == 0); + const int nb = k / BLOCK_SIZE; + + int16_t *__restrict y = (int16_t *)vy; + + const float d = scale; + const float id = d ? 1.0f / d : 0.0f; + +#if defined(__ARM_NEON) + const int32x4_t min_32768 = vdupq_n_s32(-32768); + const int32x4_t max32767 = vdupq_n_s32(32767); + + for (int i = 0; i < nb; i++) { + float32x4_t srcv[8]; + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i * 32 + 4 * j); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + int32x4_t vi = vcvtnq_s32_f32(v); + + vi = vminq_s32(vi, max32767); + vi = vmaxq_s32(vi, min_32768); + + y[i * 32 + 4 * j + 0] = (int16_t)vgetq_lane_s32(vi, 0); + y[i * 32 + 4 * j + 1] = (int16_t)vgetq_lane_s32(vi, 1); + y[i * 32 + 4 * j + 2] = (int16_t)vgetq_lane_s32(vi, 2); + y[i * 32 + 4 * j + 3] = (int16_t)vgetq_lane_s32(vi, 3); + } + } +#else + // fallback scalar version + for (int i = 0; i < k; i++) { + int v = (int)roundf(x[i] * id); + if (v < -32768) v = -32768; + if (v > 32767) v = 32767; + y[i] = (int16_t)v; + } +#endif +} + +void dequantize_row_i16(const void *__restrict vx, float *__restrict y, int k, float scale) { +#if defined(__ARM_NEON) + const int16_t *__restrict x = (int16_t *)vx; + + float32x4_t scale_vec = vdupq_n_f32(scale); + + int i; + for (i = 0; i <= k - 8; i += 8) { + // Load 8 int16_t values + int16x8_t x_vec = vld1q_s16(&x[i]); + + // Split into lower and upper 4 elements + int32x4_t x_lo = vmovl_s16(vget_low_s16(x_vec)); // 前4个 int16 -> int32 + int32x4_t x_hi = vmovl_s16(vget_high_s16(x_vec)); // 后4个 int16 -> int32 + + // Convert to float32 + float32x4_t x_f32_lo = vcvtq_f32_s32(x_lo); + float32x4_t x_f32_hi = vcvtq_f32_s32(x_hi); + + // Multiply by scale + x_f32_lo = vmulq_f32(x_f32_lo, scale_vec); + x_f32_hi = vmulq_f32(x_f32_hi, scale_vec); + + // Store result + vst1q_f32(&y[i], x_f32_lo); + vst1q_f32(&y[i + 4], x_f32_hi); + } + + // Handle remaining elements + for (; i < k; i++) { + y[i] = x[i] * scale; + } +#else +// TODO: avx +#endif +} + // #endif \ No newline at end of file diff --git a/src/backends/cpu/quantize/QuantizeQ8.hpp b/src/backends/cpu/quantize/QuantizeQ8.hpp index defb69943..77ead9267 100644 --- a/src/backends/cpu/quantize/QuantizeQ8.hpp +++ b/src/backends/cpu/quantize/QuantizeQ8.hpp @@ -40,5 +40,8 @@ void quantize_row_i8(const float *__restrict x, void *__restrict y, int k, float void dequantize_row_i8(const void *__restrict vx, float *__restrict y, int k, float scale = 1.f); void dequantize_row_i8_to_fp16(const void *__restrict vx, void *__restrict vy, int k, float scale = 1.f); void quantize_round_dequantize_row_i8(const float *__restrict vx, float *__restrict y, int k, float scale = 1.f); +// per-tensor int16 quantize +void quantize_row_i16(const float *__restrict x, void *__restrict y, int k, float scale = 1.f); +void dequantize_row_i16(const void *__restrict vx, float *__restrict y, int k, float scale = 1.f); #endif // MLLM_QUANTIZEQ8_HPP diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/config/LLaMAOpPackageHtp.xml b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/config/LLaMAOpPackageHtp.xml index 259f786ef..b302f2513 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/config/LLaMAOpPackageHtp.xml +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/config/LLaMAOpPackageHtp.xml @@ -729,6 +729,70 @@ Confidential and Proprietary - Qualcomm Technologies, Inc. HTP + + + LLaMADequantizeAdd + + + LLaMA Dequantize and Add + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + input bias + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + HTP + + LLaMAQuantize @@ -1554,16 +1618,39 @@ Confidential and Proprietary - Qualcomm Technologies, Inc. out[0] QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_16 - + LLaMADequantize in[0] QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_16 + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + LLaMADequantizeAdd + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_16 + + + in[1] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 out[0] diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/LLaMAPackageInterface.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/LLaMAPackageInterface.cpp index de7261b56..85448fb63 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/LLaMAPackageInterface.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/LLaMAPackageInterface.cpp @@ -19,32 +19,33 @@ BEGIN_PKG_OPS_OPTS_LIST() * registered to the HTP Core. * Append the latest OpName at the bottom */ -DECLARE_PKG_OPS_OPTS_LIST(PKG_IRoPE) -DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMALinear) -DECLARE_PKG_OPS_OPTS_LIST(PKG_SplitInput) -DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAReLU) -DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMASuperSiLU) -DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAQuantize) -DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAMul) +DECLARE_PKG_OPS_OPTS_LIST(PKG_RMSNorm) DECLARE_PKG_OPS_OPTS_LIST(PKG_KVCache) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMADequantizeAdd) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAMul) +DECLARE_PKG_OPS_OPTS_LIST(PKG_MergeOutput) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAReLU) +DECLARE_PKG_OPS_OPTS_LIST(PKG_CausalMask) +DECLARE_PKG_OPS_OPTS_LIST(PKG_SiLU) DECLARE_PKG_OPS_OPTS_LIST(PKG_Attention) DECLARE_PKG_OPS_OPTS_LIST(PKG_QLayerNorm) +DECLARE_PKG_OPS_OPTS_LIST(PKG_RoPE) +DECLARE_PKG_OPS_OPTS_LIST(PKG_WNop) DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAAdd) -DECLARE_PKG_OPS_OPTS_LIST(PKG_CausalMask) +DECLARE_PKG_OPS_OPTS_LIST(PKG_IRoPE) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMALinear) +DECLARE_PKG_OPS_OPTS_LIST(PKG_SplitInput) DECLARE_PKG_OPS_OPTS_LIST(PKG_HeadMatmul) -DECLARE_PKG_OPS_OPTS_LIST(PKG_RoPE) DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMADequantize) -DECLARE_PKG_OPS_OPTS_LIST(PKG_WNop) -DECLARE_PKG_OPS_OPTS_LIST(PKG_MergeOutput) -DECLARE_PKG_OPS_OPTS_LIST(PKG_RMSNorm) -DECLARE_PKG_OPS_OPTS_LIST(PKG_SiLU) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMASuperSiLU) +DECLARE_PKG_OPS_OPTS_LIST(PKG_LLaMAQuantize) END_PKG_OPS_OPTS_LIST() // op package info static constexpr auto sg_packageName = THIS_PKG_NAME_STR; // package name passed in as compile flag -static std::array sg_opNames{{"IRoPE", "LLaMALinear", "SplitInput", "LLaMAReLU", "LLaMASuperSiLU", "LLaMAQuantize", "LLaMAMul", "KVCache", "Attention", "QLayerNorm", "LLaMAAdd", "CausalMask", "HeadMatmul", "RoPE", "LLaMADequantize", "WNop", "MergeOutput", "RMSNorm", "SiLU"}}; +static std::array sg_opNames{{"RMSNorm", "KVCache", "LLaMADequantizeAdd", "LLaMAMul", "MergeOutput", "LLaMAReLU", "CausalMask", "SiLU", "Attention", "QLayerNorm", "RoPE", "WNop", "LLaMAAdd", "IRoPE", "LLaMALinear", "SplitInput", "HeadMatmul", "LLaMADequantize", "LLaMASuperSiLU", "LLaMAQuantize"}}; static Qnn_ApiVersion_t sg_sdkApiVersion = QNN_HTP_API_VERSION_INIT; static QnnOpPackage_Info_t sg_packageInfo = QNN_OP_PACKAGE_INFO_INIT; @@ -228,43 +229,43 @@ Qnn_ErrorHandle_t LLaMAPackageValidateOpConfig (Qnn_OpConfig_t opConfig){ * Check if op config type matches any registered ops * If a match is found, check number of inputs, outputs and params */ - if (std::string(opConfig.v1.typeName) == "IRoPE"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 4 || opConfig.v1.numOfOutputs != 1){ + if (std::string(opConfig.v1.typeName) == "RMSNorm"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMALinear"){ - if (opConfig.v1.numOfParams != 4 || opConfig.v1.numOfInputs != 3 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "KVCache"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "SplitInput"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 2){ + else if (std::string(opConfig.v1.typeName) == "LLaMADequantizeAdd"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMAReLU"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMAMul"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMASuperSiLU"){ - if (opConfig.v1.numOfParams != 3 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "MergeOutput"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 4 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMAQuantize"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMAReLU"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMAMul"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "CausalMask"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "KVCache"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "SiLU"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } @@ -278,48 +279,53 @@ Qnn_ErrorHandle_t LLaMAPackageValidateOpConfig (Qnn_OpConfig_t opConfig){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMAAdd"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "RoPE"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 4 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "CausalMask"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "WNop"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 2){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "HeadMatmul"){ - if (opConfig.v1.numOfParams != 2 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMAAdd"){ + if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "RoPE"){ + else if (std::string(opConfig.v1.typeName) == "IRoPE"){ if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 4 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "LLaMADequantize"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMALinear"){ + if (opConfig.v1.numOfParams != 4 || opConfig.v1.numOfInputs != 3 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "WNop"){ + else if (std::string(opConfig.v1.typeName) == "SplitInput"){ if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 2){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "MergeOutput"){ - if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 4 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "HeadMatmul"){ + if (opConfig.v1.numOfParams != 2 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "RMSNorm"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMADequantize"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } - else if (std::string(opConfig.v1.typeName) == "SiLU"){ - if (opConfig.v1.numOfParams != 0 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + else if (std::string(opConfig.v1.typeName) == "LLaMASuperSiLU"){ + if (opConfig.v1.numOfParams != 3 || opConfig.v1.numOfInputs != 2 || opConfig.v1.numOfOutputs != 1){ + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + } + else if (std::string(opConfig.v1.typeName) == "LLaMAQuantize"){ + if (opConfig.v1.numOfParams != 1 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; } } diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/Attention.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/Attention.cpp index e3db65468..559985e48 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/Attention.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/Attention.cpp @@ -9,18 +9,16 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_Attention); - // op execute function declarations -template +template GraphStatus attentionImpl(TensorType1 &out_0, const TensorType1 &in_0, const TensorType1 &in_1, - const TensorType& in_2, - const TensorType& in_3, - const TensorType& in_4); + const TensorType &in_2, + const TensorType &in_3, + const TensorType &in_4); // forward declaration of sample cost function static float attentionCostFunc(const Op *op); @@ -65,11 +63,11 @@ DEF_PACKAGE_OP((attentionImpl), "Attention") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -83,47 +81,41 @@ DEF_PACKAGE_OP((attentionImpl), "Attention") * Qnn_addNode */ - /* execute functions for ops */ -template +template GraphStatus attentionImpl(TensorType1 &out_0, const TensorType1 &in_0, const TensorType1 &in_1, - const TensorType& in_2, - const TensorType& in_3, - const TensorType& in_4) + const TensorType &in_2, + const TensorType &in_3, + const TensorType &in_4) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - return GraphStatus::Success; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + return GraphStatus::Success; } -__attribute__((unused)) static float attentionCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float attentionCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/CausalMask.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/CausalMask.cpp index c3100cea1..a9ab7cd61 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/CausalMask.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/CausalMask.cpp @@ -13,11 +13,10 @@ BEGIN_PKG_OP_DEFINITION(PKG_CausalMask); - // op execute function declarations -template -GraphStatus causalmaskImpl(TensorType& out_0, - const TensorType& in_0); +template +GraphStatus causalmaskImpl(TensorType &out_0, + const TensorType &in_0); // forward declaration of sample cost function static float causalmaskCostFunc(const Op *op); @@ -62,11 +61,11 @@ DEF_PACKAGE_OP((causalmaskImpl), "CausalMask") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -80,78 +79,67 @@ DEF_PACKAGE_OP((causalmaskImpl), "CausalMask") * Qnn_addNode */ - /* execute functions for ops */ -template -GraphStatus causalmaskImpl(TensorType& out_0, - const TensorType& in_0) +template +GraphStatus causalmaskImpl(TensorType &out_0, + const TensorType &in_0) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); - - int old_dim = 0; - - // NHSD - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - // S > 1 => mask - if (w_in > 1) { - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // CausalMask - for (Idx d = 0; d < d_in; d++) { - - float in_value = in_0(b, h, w, d); - - if (d > w + old_dim) - out_0(b, h, w, d) = in_value - MASK_INFINITY; - else - out_0(b, h, w, d) = in_value; - - } + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); + + int old_dim = 0; + + // NHSD + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + // S > 1 => mask + if (w_in > 1) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // CausalMask + for (Idx d = 0; d < d_in; d++) { + float in_value = in_0(b, h, w, d); + + if (d > w + old_dim) + out_0(b, h, w, d) = in_value - MASK_INFINITY; + else + out_0(b, h, w, d) = in_value; + } + } + } } - } + } else { + auto in_ptr = in_0.raw_data_const(); + auto out_ptr = out_0.raw_data(); + memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * 4); } - } else { - auto in_ptr = in_0.raw_data_const(); - auto out_ptr = out_0.raw_data(); - memcpy(out_ptr, in_ptr, b_in*h_in*w_in*d_in*4); - } - - - - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float causalmaskCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float causalmaskCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/HeadMatmul.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/HeadMatmul.cpp index 18440880c..8dbd62f29 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/HeadMatmul.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/HeadMatmul.cpp @@ -9,25 +9,24 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_HeadMatmul); static Qnn_Scalar_t sg_opDefaultTranspose_In0Scalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_BOOL_8, - .bool8Value = false}; + .bool8Value = false}; static Qnn_Param_t sg_opDefaultTranspose_In0 = {.paramType = QNN_PARAMTYPE_SCALAR, - .scalarParam = sg_opDefaultTranspose_In0Scalar}; + .scalarParam = sg_opDefaultTranspose_In0Scalar}; static Qnn_Scalar_t sg_opDefaultTranspose_In1Scalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_BOOL_8, - .bool8Value = false}; + .bool8Value = false}; static Qnn_Param_t sg_opDefaultTranspose_In1 = {.paramType = QNN_PARAMTYPE_SCALAR, - .scalarParam = sg_opDefaultTranspose_In1Scalar}; + .scalarParam = sg_opDefaultTranspose_In1Scalar}; // op execute function declarations -template -GraphStatus headmatmulImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const QuantUint16Tensor& transpose_in0, - const QuantUint16Tensor& transpose_in1); +template +GraphStatus headmatmulImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const QuantUint16Tensor &transpose_in0, + const QuantUint16Tensor &transpose_in1); // forward declaration of sample cost function static float headmatmulCostFunc(const Op *op); @@ -72,11 +71,11 @@ DEF_PACKAGE_OP((headmatmulImpl), "HeadMatmul") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -89,7 +88,7 @@ DEF_PACKAGE_OP((headmatmulImpl), "HeadMatmul") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("HeadMatmul", +DEF_PACKAGE_PARAM_ORDER("HeadMatmul", "transpose_in0", false, &sg_opDefaultTranspose_In0, @@ -97,77 +96,65 @@ DEF_PACKAGE_PARAM_ORDER("HeadMatmul", false, &sg_opDefaultTranspose_In1) - /* execute functions for ops */ -template -GraphStatus headmatmulImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const QuantUint16Tensor& transpose_in0, - const QuantUint16Tensor& transpose_in1) +template +GraphStatus headmatmulImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const QuantUint16Tensor &transpose_in0, + const QuantUint16Tensor &transpose_in1) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - auto transpose_in0_ = transpose_in0(0,0,0,0); - auto transpose_in1_ = transpose_in1(0,0,0,0); + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + auto transpose_in0_ = transpose_in0(0, 0, 0, 0); + auto transpose_in1_ = transpose_in1(0, 0, 0, 0); auto [b_in, h_in, w_in, d_in] = in_0.dims(); auto [b_in2, h_in2, w_in2, d_in2] = in_1.dims(); if (transpose_in0_ && transpose_in1_) { - - // Q KT head matmul - const size_t dims[] = {b_in, w_in, h_in, h_in}; - out_0.set_dims(dims); - debuglog("HeadMatmul execute... dims=(%zdx%zdx%zdx%zd)", out_0.dim(0), out_0.dim(1), out_0.dim(2), out_0.dim(3)); - + // Q KT head matmul + const size_t dims[] = {b_in, w_in, h_in, h_in}; + out_0.set_dims(dims); + debuglog("HeadMatmul execute... dims=(%zdx%zdx%zdx%zd)", out_0.dim(0), out_0.dim(1), out_0.dim(2), out_0.dim(3)); } else if (transpose_in0_) { - } else if (transpose_in1_) { + // QKT V head matmul + const size_t dims[] = {b_in, w_in, h_in, d_in2}; + out_0.set_dims(dims); + debuglog("HeadMatmul execute... dims=(%zdx%zdx%zdx%zd)", out_0.dim(0), out_0.dim(1), out_0.dim(2), out_0.dim(3)); - // QKT V head matmul - const size_t dims[] = {b_in, w_in, h_in, d_in2}; - out_0.set_dims(dims); - debuglog("HeadMatmul execute... dims=(%zdx%zdx%zdx%zd)", out_0.dim(0), out_0.dim(1), out_0.dim(2), out_0.dim(3)); - - // Todo out matrix needs transpose, we directly calculate the final dimensions. + // Todo out matrix needs transpose, we directly calculate the final dimensions. } else { - } - - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float headmatmulCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float headmatmulCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/IRoPE.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/IRoPE.cpp index b237b70af..a4fc6d757 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/IRoPE.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/IRoPE.cpp @@ -9,18 +9,16 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_IRoPE); - // op execute function declarations -template -GraphStatus iropeImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const TensorType& cos, +template +GraphStatus iropeImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const TensorType &cos, const TensorType1 &h_cnt, - const Tensor& pose_type); + const Tensor &pose_type); // forward declaration of sample cost function static float iropeCostFunc(const Op *op); @@ -65,11 +63,11 @@ DEF_PACKAGE_OP((iropeImpl), "IRoPE") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -82,142 +80,130 @@ DEF_PACKAGE_OP((iropeImpl), "IRoPE") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("IRoPE", +DEF_PACKAGE_PARAM_ORDER("IRoPE", "pose_type", true, nullptr) - /* execute functions for ops */ -template -GraphStatus iropeImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& sin, - const TensorType& cos, +template +GraphStatus iropeImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &sin, + const TensorType &cos, const TensorType1 &h_cnt, - const Tensor& pose_type) + const Tensor &pose_type) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - auto pose_type_ = pose_type(0,0,0,0); - auto h_cnt_ = static_cast(h_cnt(0,0,0,0)); - - out_0.set_dims(in_0); - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - uint32_t half_dimension = d_in / 2; - - auto sin_ptr = (uint8_t*)sin.raw_data_const(); - auto cos_ptr = (uint8_t*)cos.raw_data_const(); - - auto in_ptr = (uint8_t*)in_0.raw_data_const(); - - sin_ptr += half_dimension * h_cnt_; - cos_ptr += half_dimension * h_cnt_; - - // float scale_ = in_0.get_interface_scale() * sin.get_interface_scale() * cos.get_interface_scale(); - - if (pose_type_ == 4) { - DType dtype = out_0.get_dtype(); - - if (dtype == DType::Float32) { - - auto out_ptr = (float*)out_0.raw_data(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - - int partial_dimension = d_in; - for (Idx d = 0; d < partial_dimension / 2; ++d) { - int in_value = *in_ptr; - int in_value_2 = *(in_ptr + half_dimension); - - int sin_value = *(sin_ptr+d); - int cos_value = *(cos_ptr+d); - float value = (in_value-128) * (cos_value-128) * cos.get_interface_scale() - (in_value_2-128) * (sin_value-128) * sin.get_interface_scale(); - float value2 = (in_value-128) * (sin_value-128) * sin.get_interface_scale() + (in_value_2-128) * (cos_value-128) * cos.get_interface_scale(); - - *out_ptr = value; - *(out_ptr + half_dimension) = value2; - - out_ptr++; - in_ptr++; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); + + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + + auto sin_ptr = (uint8_t *)sin.raw_data_const(); + auto cos_ptr = (uint8_t *)cos.raw_data_const(); + + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + // float scale_ = in_0.get_interface_scale() * sin.get_interface_scale() * cos.get_interface_scale(); + + if (pose_type_ == 4) { + DType dtype = out_0.get_dtype(); + + if (dtype == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + int partial_dimension = d_in; + for (Idx d = 0; d < partial_dimension / 2; ++d) { + int in_value = *in_ptr; + int in_value_2 = *(in_ptr + half_dimension); + + int sin_value = *(sin_ptr + d); + int cos_value = *(cos_ptr + d); + float value = (in_value - 128) * (cos_value - 128) * cos.get_interface_scale() - (in_value_2 - 128) * (sin_value - 128) * sin.get_interface_scale(); + float value2 = (in_value - 128) * (sin_value - 128) * sin.get_interface_scale() + (in_value_2 - 128) * (cos_value - 128) * cos.get_interface_scale(); + + *out_ptr = value; + *(out_ptr + half_dimension) = value2; + + out_ptr++; + in_ptr++; + } + + in_ptr += half_dimension; + out_ptr += half_dimension; + } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } } - - in_ptr += half_dimension; - out_ptr += half_dimension; - } - - sin_ptr += half_dimension; - cos_ptr += half_dimension; - - } - } - } else if (dtype == DType::Float16) { - - auto out_ptr = (__fp16*)out_0.raw_data(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - - int partial_dimension = d_in; - for (Idx d = 0; d < partial_dimension / 2; ++d) { - int in_value = *in_ptr; - int in_value_2 = *(in_ptr + half_dimension); - - int sin_value = *(sin_ptr+d); - int cos_value = *(cos_ptr+d); - float value = (in_value-128) * (cos_value-128) * cos.get_interface_scale() - (in_value_2-128) * (sin_value-128) * sin.get_interface_scale(); - float value2 = (in_value-128) * (sin_value-128) * sin.get_interface_scale() + (in_value_2-128) * (cos_value-128) * cos.get_interface_scale(); - - *out_ptr = static_cast<__fp16>(value); - *(out_ptr + half_dimension) = static_cast<__fp16>(value2); - - out_ptr++; - in_ptr++; + } else if (dtype == DType::Float16) { + auto out_ptr = (__fp16 *)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + int partial_dimension = d_in; + for (Idx d = 0; d < partial_dimension / 2; ++d) { + int in_value = *in_ptr; + int in_value_2 = *(in_ptr + half_dimension); + + int sin_value = *(sin_ptr + d); + int cos_value = *(cos_ptr + d); + float value = (in_value - 128) * (cos_value - 128) * cos.get_interface_scale() - (in_value_2 - 128) * (sin_value - 128) * sin.get_interface_scale(); + float value2 = (in_value - 128) * (sin_value - 128) * sin.get_interface_scale() + (in_value_2 - 128) * (cos_value - 128) * cos.get_interface_scale(); + + *out_ptr = static_cast<__fp16>(value); + *(out_ptr + half_dimension) = static_cast<__fp16>(value2); + + out_ptr++; + in_ptr++; + } + + in_ptr += half_dimension; + out_ptr += half_dimension; + } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } } - - in_ptr += half_dimension; - out_ptr += half_dimension; - } - - sin_ptr += half_dimension; - cos_ptr += half_dimension; - } - } - } - } + } - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float iropeCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float iropeCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/KVCache.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/KVCache.cpp index 6d3ecb2d3..bf11ce8c8 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/KVCache.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/KVCache.cpp @@ -9,16 +9,14 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_KVCache); - // op execute function declarations -template -GraphStatus kvcacheImpl(TensorType& out_0, - const TensorType& in_0, +template +GraphStatus kvcacheImpl(TensorType &out_0, + const TensorType &in_0, const TensorType1 &seq_pos, - const Tensor& hidden_dim); + const Tensor &hidden_dim); // forward declaration of sample cost function static float kvcacheCostFunc(const Op *op); @@ -63,11 +61,11 @@ DEF_PACKAGE_OP((kvcacheImpl), "KVCache") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -80,12 +78,11 @@ DEF_PACKAGE_OP((kvcacheImpl), "KVCache") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("KVCache", +DEF_PACKAGE_PARAM_ORDER("KVCache", "hidden_dim", true, nullptr) - /* execute functions for ops */ // #ifndef REFERENCE_OP @@ -100,7 +97,6 @@ DEF_PACKAGE_PARAM_ORDER("KVCache", // #define ONE 0x3F800000 // #define M_ONE 0xAF800000 - // int32_t hvx_memcpy_af(float *restrict input, float *restrict output, uint32_t size) // { // HVX_Vector *input_v_ptr; @@ -113,7 +109,6 @@ DEF_PACKAGE_PARAM_ORDER("KVCache", // int32_t vectors_in_rounddown = size / 32; // int32_t leftover_size = leftover * sizeof(float); - // /* Check input arguments. Return error status if some argument has invalid value */ // if ((input == 0) || (output == 0) || (size == 0)) // { @@ -187,7 +182,6 @@ DEF_PACKAGE_PARAM_ORDER("KVCache", // return 0; // } - // template // GraphStatus kvcacheImpl(TensorType& out_0, // const TensorType& in_0, @@ -207,10 +201,10 @@ DEF_PACKAGE_PARAM_ORDER("KVCache", // * // * Please check in SDK documentation for more information. // */ - + // out_0.set_dims(in_0); // auto [b_in, h_in, w_in, d_in] = in_0.dims(); - + // uint32_t seq_pos_ = seq_pos(0,0,0,0); // // uint32_t hidden_dim_ = hidden_dim(0,0,0,0); @@ -226,91 +220,76 @@ DEF_PACKAGE_PARAM_ORDER("KVCache", // hvx_memcpy_af(out_ptr, in_ptr, h_in * w_in * d_in); - // return GraphStatus::Success; // } - // #else -template -GraphStatus kvcacheImpl(TensorType& out_0, - const TensorType& in_0, +template +GraphStatus kvcacheImpl(TensorType &out_0, + const TensorType &in_0, const TensorType1 &seq_pos, - const Tensor& hidden_dim) + const Tensor &hidden_dim) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - uint32_t seq_pos_ = seq_pos(0,0,0,0); - const size_t dims[] = {b_in, h_in + seq_pos_, w_in, d_in}; - - out_0.set_dims(dims); - - // uint32_t hidden_dim_ = hidden_dim(0,0,0,0); - - // // const size_t dims[] = {b_in, h_in, seq_pos_+1, hidden_dim_}; - // // out_0.set_dims(dims); - - // NSHD + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t seq_pos_ = seq_pos(0, 0, 0, 0); + const size_t dims[] = {b_in, h_in + seq_pos_, w_in, d_in}; + + out_0.set_dims(dims); + + // uint32_t hidden_dim_ = hidden_dim(0,0,0,0); + + // // const size_t dims[] = {b_in, h_in, seq_pos_+1, hidden_dim_}; + // // out_0.set_dims(dims); + + // NSHD DType dtype = in_0.get_dtype(); - const uint8_t *in_ptr = (uint8_t*)in_0.raw_data_const(); - uint8_t *out_ptr = (uint8_t*)out_0.raw_data(); + const uint8_t *in_ptr = (uint8_t *)in_0.raw_data_const(); + uint8_t *out_ptr = (uint8_t *)out_0.raw_data(); if (dtype == DType::QUInt8) { - out_ptr += seq_pos_ * w_in * d_in; memcpy(out_ptr, in_ptr, h_in * w_in * d_in * sizeof(uint8_t)); } else if (dtype == DType::Float16) { - out_ptr += seq_pos_ * w_in * d_in * sizeof(float) / 2; memcpy(out_ptr, in_ptr, h_in * w_in * d_in * sizeof(float) / 2); } else if (dtype == DType::Float32) { - out_ptr += seq_pos_ * w_in * d_in * sizeof(float); memcpy(out_ptr, in_ptr, h_in * w_in * d_in * sizeof(float)); } - - - - return GraphStatus::Success; + return GraphStatus::Success; } - // #endif +__attribute__((unused)) static float kvcacheCostFunc(const Op *op) { + /* + * add code here + * */ -__attribute__((unused)) static float kvcacheCostFunc(const Op *op) -{ - /* - * add code here - * */ - - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAAdd.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAAdd.cpp index f883d3cb3..99576ab8b 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAAdd.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAAdd.cpp @@ -9,15 +9,13 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMAAdd); - // op execute function declarations -template -GraphStatus llamaaddImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1); +template +GraphStatus llamaaddImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1); // forward declaration of sample cost function static float llamaaddCostFunc(const Op *op); @@ -62,11 +60,11 @@ DEF_PACKAGE_OP((llamaaddImpl), "LLaMAAdd") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -80,7 +78,6 @@ DEF_PACKAGE_OP((llamaaddImpl), "LLaMAAdd") * Qnn_addNode */ - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -90,17 +87,15 @@ DEF_PACKAGE_OP((llamaaddImpl), "LLaMAAdd") #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) int32_t hvx_add_af( float *restrict input, float *restrict input2, float *restrict output, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -120,25 +115,21 @@ int32_t hvx_add_af( sline1p = *iptr++; sline2p = *iptr2++; - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - // Our add consider uint8->int8 bugs from QNN. // sline2 = Q6_Vb_vsub_VbVb(sline2, v128); *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sline1, sline2)); @@ -149,134 +140,116 @@ int32_t hvx_add_af( } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input2); - - // sline2 = Q6_Vb_vsub_VbVb(sline2, v128); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sline1, sline2)); + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + // sline2 = Q6_Vb_vsub_VbVb(sline2, v128); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sline1, sline2)); } // Handle leftover elements. if (leftover_size > 0) { - sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) - ? sline1p - : *iptr++); - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - + sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) ? sline1p : *iptr++); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) - ? sline2p - : *iptr2++); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) ? sline2p : *iptr2++); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - // sline2 = Q6_Vb_vsub_VbVb(sline2, v128); - vstu_variable(optr, leftover_size, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sline1, sline2))); + // sline2 = Q6_Vb_vsub_VbVb(sline2, v128); + vstu_variable(optr, leftover_size, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sline1, sline2))); } return 0; } -template -GraphStatus llamaaddImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1) +template +GraphStatus llamaaddImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ out_0.set_dims(in_0); - - auto in_ptr = (float*)in_0.raw_data_const(); - auto in2_ptr = (float*)in_1.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); + auto in_ptr = (float *)in_0.raw_data_const(); + auto in2_ptr = (float *)in_1.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); auto [b_in, h_in, w_in, d_in] = in_0.dims(); - size_t size = b_in*h_in*w_in*d_in; + size_t size = b_in * h_in * w_in * d_in; hvx_add_af(in_ptr, in2_ptr, out_ptr, size); - return GraphStatus::Success; + return GraphStatus::Success; } #else - -template -GraphStatus llamaaddImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1) +template +GraphStatus llamaaddImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // mul - for (Idx d = 0; d < d_in; d++) { - float inval = in_0(b, h, w, d); - float inval2 = in_1(b, h, w, d); - float outval = inval + inval2; - - out_0(b, h, w, d) = outval; - - } + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // mul + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + float inval2 = in_1(b, h, w, d); + float outval = inval + inval2; + + out_0(b, h, w, d) = outval; + } + } } - } } - - return GraphStatus::Success; + return GraphStatus::Success; } - - #endif -__attribute__((unused)) static float llamaaddCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float llamaaddCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantize.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantize.cpp index 6afb884f2..c8b03b53b 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantize.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantize.cpp @@ -9,15 +9,13 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMADequantize); - // op execute function declarations -template +template GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale); + const PlainFloatTensor &scale); // forward declaration of sample cost function static float llamadequantizeCostFunc(const Op *op); @@ -62,11 +60,11 @@ DEF_PACKAGE_OP((llamadequantizeImpl), "LLaMADequantize") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -79,7 +77,7 @@ DEF_PACKAGE_OP((llamadequantizeImpl), "LLaMADequantize") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("LLaMADequantize", +DEF_PACKAGE_PARAM_ORDER("LLaMADequantize", "scale", true, nullptr) @@ -91,11 +89,10 @@ DEF_PACKAGE_PARAM_ORDER("LLaMADequantize", #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -103,35 +100,33 @@ static inline int32_t float_to_fp16s(float input) return fp32.i; } -static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) -{ - union { float f; uint32_t i; } fp32 = { .f = x }; +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; return fp32.i; } - - /* execute functions for ops */ int32_t qhmath_hvx_dequantize_ahf( int8_t *restrict input, int8_t *restrict output, uint32_t size, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector scale_vec; int32_t block, l2fetch_block; - int32_t leftover = size & 63; - int32_t vectors_in_rounddown = size / 128; // element number! + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; // element number! // int32_t leftover_size = leftover * sizeof(float); sline1p = *iptr++; @@ -139,42 +134,36 @@ int32_t qhmath_hvx_dequantize_ahf( uint32_t convert = 0x00800080; HVX_Vector convert_vector = Q6_V_vsplat_R(convert); - scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); HVX_Vector zero_v_sf = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - - *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); - *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); @@ -182,30 +171,85 @@ int32_t qhmath_hvx_dequantize_ahf( HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); + } - *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); - *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); + return 0; +} +int32_t qhmath_hvx_dequantize_ui16_ahf( + int8_t *restrict input, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; } + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 64; // element number! + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x80008000; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + HVX_Vector temp = sline1; + + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(temp, convert_vector); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector temp = sline1; + + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(temp, convert_vector); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + } return 0; } -// Only support 128x dimension +// Only support 128x dimension int32_t qhmath_hvx_dequantize_af( int8_t *restrict input, int8_t *restrict output, uint32_t size, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector scale_vec; @@ -221,26 +265,22 @@ int32_t qhmath_hvx_dequantize_af( uint32_t convert = 0x00800080; HVX_Vector convert_vector = Q6_V_vsplat_R(convert); - scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); HVX_Vector zero_v_sf = Q6_V_vzero(); scale_vec = Q6_Vqf32_vadd_VsfVsf(scale_vec, Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); @@ -253,19 +293,18 @@ int32_t qhmath_hvx_dequantize_af( HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); @@ -279,141 +318,222 @@ int32_t qhmath_hvx_dequantize_af( HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); + } + + return 0; +} + +int32_t qhmath_hvx_dequantize_ui16_af( + int8_t *restrict input, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + HVX_Vector one_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 64; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x80008000; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); + scale_vec = Q6_Vqf32_vadd_VsfVsf(scale_vec, Q6_V_vzero()); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector temp = Q6_Vh_vsub_VhVh(sline1, convert_vector); + HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(temp), one_vec); + result = Q6_W_vshuff_VVR(Q6_V_hi_W(result), Q6_V_lo_W(result), -4); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result), scale_vec)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector temp = Q6_Vh_vsub_VhVh(sline1, convert_vector); + HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(temp), one_vec); + result = Q6_W_vshuff_VVR(Q6_V_hi_W(result), Q6_V_lo_W(result), -4); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result), scale_vec)); } return 0; } -template +template GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale) + const PlainFloatTensor &scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - // HVX Method -- FP32 Version + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version out_0.set_dims(in_0); - + // NHWC - auto in_ptr = (int8_t*)in_0.raw_data_const(); - auto out_ptr = (int8_t*)out_0.raw_data(); + auto in_ptr = (int8_t *)in_0.raw_data_const(); + auto out_ptr = (int8_t *)out_0.raw_data(); auto [b_in, h_in, w_in, d_in] = in_0.dims(); - float scale_ = scale(0,0,0,0); - + float scale_ = scale(0, 0, 0, 0); - size_t size = b_in*h_in*w_in*d_in; + size_t size = b_in * h_in * w_in * d_in; if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float16) { qhmath_hvx_dequantize_ahf(in_ptr, out_ptr, size, scale_); - } - else { + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float16) { + qhmath_hvx_dequantize_ui16_ahf(in_ptr, out_ptr, size, scale_); + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float32) { + qhmath_hvx_dequantize_ui16_af(in_ptr, out_ptr, size, scale_); + } else { qhmath_hvx_dequantize_af(in_ptr, out_ptr, size, scale_); } - - - return GraphStatus::Success; + return GraphStatus::Success; } #else -template +template GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale) + const PlainFloatTensor &scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - // HVX Method -- FP32 Version + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version out_0.set_dims(in_0); - float scale_ = scale(0,0,0,0); - - auto in_ptr = (uint8_t*)in_0.raw_data_const(); + float scale_ = scale(0, 0, 0, 0); auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - if (out_0.get_dtype() == DType::Float32) { - auto out_ptr = (float*)out_0.raw_data(); + if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + auto in_ptr = (uint8_t *)in_0.raw_data_const(); for (Idx b = 0; b < b_in; b++) { for (Idx h = 0; h < h_in; h++) { for (Idx w = 0; w < w_in; w++) { - for (Idx d = 0; d < d_in; d++) { - - int32_t inval = static_cast(*in_ptr++); - *out_ptr++ = (inval-128) * scale_; - + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (inval - 128) * scale_; + } } + } + } + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + auto in_ptr = (uint16_t *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (inval - 32768) * scale_; + } } } } - } else if (out_0.get_dtype() == DType::Float16) { + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float16) { + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (uint16_t *)in_0.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (__fp16)((inval - 32768) * scale_); + } + } + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float16) { + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (uint8_t *)in_0.raw_data_const(); for (Idx b = 0; b < b_in; b++) { for (Idx h = 0; h < h_in; h++) { for (Idx w = 0; w < w_in; w++) { - for (Idx d = 0; d < d_in; d++) { - - int32_t inval = static_cast(*in_ptr++); - *out_ptr++ = (__fp16)((inval-128) * scale_); - + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (__fp16)((inval - 128) * scale_); } } } } } - return GraphStatus::Success; + return GraphStatus::Success; } #endif +__attribute__((unused)) static float llamadequantizeCostFunc(const Op *op) { + /* + * add code here + * */ -__attribute__((unused)) static float llamadequantizeCostFunc(const Op *op) -{ - /* - * add code here - * */ - - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantizeAdd.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantizeAdd.cpp new file mode 100755 index 000000000..86158e228 --- /dev/null +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMADequantizeAdd.cpp @@ -0,0 +1,694 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "QnnOpPackage.h" +#include "HTP/core/simple_reg.h" + +BEGIN_PKG_OP_DEFINITION(PKG_LLaMADequantizeAdd); + +// op execute function declarations +template +GraphStatus llamadequantizeaddImpl(TensorType1 &out_0, + const TensorType1 &in_0, + const TensorType &in_1, + const PlainFloatTensor &scale); + +// forward declaration of sample cost function +static float llamadequantizeaddCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default flag (Flags::RESOURCE_HVX) + * syntax: DEF_PACKAGE_OP(F,OP) + * e.g. DEF_PACKAGE_OP((llamadequantizeaddImpl), "LLaMADequantizeAdd") + */ +DEF_PACKAGE_OP((llamadequantizeaddImpl), "LLaMADequantizeAdd") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, FAST, FREE) + * and provided flags + * syntax: DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) + * can use zero or more flags, FLAG options are IS_CONST, INHIBIT_CONST_PROP, + * RESOURCE_HVX, RESOURCE_HMX(not supported in external op packages) + * e.g. DEF_PACKAGE_OP_AND_COST_AND_FLAGS((llamadequantizeaddImpl), "LLaMADequantizeAdd", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((llamadequantizeaddImpl), + * "LLaMADequantizeAdd", llamadequantizeaddCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) + * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op execution functions + * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode + * will be passed into op execution functions + * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted + * name will be abandoned + * if two or more op packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at Qnn_addNode + * DEFAULT is used when MANDATORY is false + * if provided as Qnn_Param_t*, + * DEFAULT will be used for graph construction when this parameter is not provided at + * Qnn_addNode + * if provided as nullptr, + * graph construction will skip this parameter when this parameter is not provided at + * Qnn_addNode + */ +DEF_PACKAGE_PARAM_ORDER("LLaMADequantizeAdd", + "scale", + true, + nullptr) + +/* execute functions for ops */ +#ifndef REFERENCE_OP +#include "qhmath_hvx.h" +#include "hvx_internal.h" +#include +#include + +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) + +static inline int32_t float_to_fp16s(float input) { + union { + int32_t i; + __fp16 f[2]; + } fp32 = {.f = {(__fp16)input, (__fp16)input}}; + return fp32.i; +} + +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; + return fp32.i; +} + +/* execute functions for ops */ +int32_t qhmath_hvx_dequantize_add_ahf( + int8_t *restrict input, + float_t *restrict bias, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (bias == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *bptr = (HVX_Vector *)bias; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; // element number! + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + HVX_Vector zero_v_sf = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + // see HVX documention for Vector shuffle and deal cross-lane + // Q6_Vhf_equals_Wqf32 will use elements in the lower vector as the odd elements, so need to transpose here + HVX_VectorPair bias_pair = Q6_W_vdeal_VVR(Q6_Vqf32_equals_Vsf(bvec2), Q6_Vqf32_equals_Vsf(bvec1), -4); // make a fp32 pair + HVX_Vector hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf( + Q6_Vqf16_vmpy_VhfVhf( + Q6_Vhf_equals_Vh(sout1), scale_vec), + hf16_bias)); + + bvec1 = *bptr++; // load 32 float elements + bvec2 = *bptr++; + // see HVX documention for Vector shuffle and deal cross-lane + // Q6_Vhf_equals_Wqf32 will use elements in the lower vector as the odd elements, so need to transpose here + bias_pair = Q6_W_vdeal_VVR(Q6_Vqf32_equals_Vsf(bvec2), Q6_Vqf32_equals_Vsf(bvec1), -4); // make a fp32 pair + hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf( + Q6_Vqf16_vmpy_VhfVhf( + Q6_Vhf_equals_Vh(sout2), scale_vec), + hf16_bias)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + HVX_VectorPair bias_pair = Q6_W_vshuff_VVR(Q6_Vqf32_equals_Vsf(bvec1), Q6_Vqf32_equals_Vsf(bvec2), -4); // make a fp32 pair + HVX_Vector hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf( + Q6_Vqf16_vmpy_VhfVhf( + Q6_Vhf_equals_Vh(sout1), scale_vec), + hf16_bias)); + + bvec1 = *bptr++; // load 32 float elements + bvec2 = *bptr++; + bias_pair = Q6_W_vshuff_VVR(Q6_Vqf32_equals_Vsf(bvec1), Q6_Vqf32_equals_Vsf(bvec2), -4); // make a fp32 pair + hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf( + Q6_Vqf16_vmpy_VhfVhf( + Q6_Vhf_equals_Vh(sout2), scale_vec), + hf16_bias)); + } + + return 0; +} + +int32_t qhmath_hvx_dequantize_add_ui16_ahf( + int8_t *restrict input, + float_t *restrict bias, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (bias == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *bptr = (HVX_Vector *)bias; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 64; // element number! + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x80008000; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + // see HVX documention for Vector shuffle and deal cross-lane + // Q6_Vhf_equals_Wqf32 will use elements in the lower vector as the odd elements, so need to transpose here + HVX_VectorPair bias_pair = Q6_W_vdeal_VVR(Q6_Vqf32_equals_Vsf(bvec2), Q6_Vqf32_equals_Vsf(bvec1), -4); // make a fp32 pair + HVX_Vector hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + HVX_Vector temp = sline1; + + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(temp, convert_vector); + HVX_Vector qf16_val = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec); + + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vadd_Vqf16Vhf(qf16_val, hf16_bias)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + + // see HVX documention for Vector shuffle and deal cross-lane + // Q6_Vhf_equals_Wqf32 will use elements in the lower vector as the odd elements, so need to transpose here + HVX_VectorPair bias_pair = Q6_W_vdeal_VVR(Q6_Vqf32_equals_Vsf(bvec2), Q6_Vqf32_equals_Vsf(bvec1), -4); // make a fp32 pair + HVX_Vector hf16_bias = Q6_Vhf_equals_Wqf32(bias_pair); + + HVX_Vector temp = sline1; + + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(temp, convert_vector); + HVX_Vector qf16_val = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec); + + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vadd_Vqf16Vhf(qf16_val, hf16_bias)); + } + + return 0; +} + +// Only support 128x dimension +int32_t qhmath_hvx_dequantize_add_af( + int8_t *restrict input, + float_t *restrict bias, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (bias == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *bptr = (HVX_Vector *)bias; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + HVX_Vector one_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); + HVX_Vector zero_v_sf = Q6_V_vzero(); + scale_vec = Q6_Vqf32_vadd_VsfVsf(scale_vec, Q6_V_vzero()); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + HVX_Vector bvec3 = *bptr++; + HVX_Vector bvec4 = *bptr++; + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec), + Q6_Vqf32_equals_Vsf(bvec1))); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec), + Q6_Vqf32_equals_Vsf(bvec2))); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec), + Q6_Vqf32_equals_Vsf(bvec3))); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec), + Q6_Vqf32_equals_Vsf(bvec4))); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + // NOTE: assume bias size is multiple of 128 + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + HVX_Vector bvec3 = *bptr++; + HVX_Vector bvec4 = *bptr++; + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec), + Q6_Vqf32_equals_Vsf(bvec1))); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec), + Q6_Vqf32_equals_Vsf(bvec2))); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec), + Q6_Vqf32_equals_Vsf(bvec3))); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec), + Q6_Vqf32_equals_Vsf(bvec4))); + } + + return 0; +} + +int32_t qhmath_hvx_dequantize_add_ui16_af( + int8_t *restrict input, + float_t *restrict bias, + int8_t *restrict output, + uint32_t size, + float scale) { + if ((input == NULL) || (bias == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *bptr = (HVX_Vector *)bias; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + HVX_Vector one_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 64; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x80008000; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); + scale_vec = Q6_Vqf32_vadd_VsfVsf(scale_vec, Q6_V_vzero()); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + + HVX_Vector temp = Q6_Vh_vsub_VhVh(sline1, convert_vector); + HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(temp), one_vec); + result = Q6_W_vshuff_VVR(Q6_V_hi_W(result), Q6_V_lo_W(result), -4); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result), scale_vec), Q6_Vqf32_equals_Vsf(bvec1))); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result), scale_vec), Q6_Vqf32_equals_Vsf(bvec2))); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_Vector bvec1 = *bptr++; // load 32 float elements + HVX_Vector bvec2 = *bptr++; + + HVX_Vector temp = Q6_Vh_vsub_VhVh(sline1, convert_vector); + HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(temp), one_vec); + result = Q6_W_vshuff_VVR(Q6_V_hi_W(result), Q6_V_lo_W(result), -4); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result), scale_vec), Q6_Vqf32_equals_Vsf(bvec1))); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vadd_Vqf32Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result), scale_vec), Q6_Vqf32_equals_Vsf(bvec2))); + } + + return 0; +} + +template +GraphStatus llamadequantizeaddImpl(TensorType1 &out_0, + const TensorType1 &in_0, + const TensorType &in_1, + const PlainFloatTensor &scale) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version + out_0.set_dims(in_0); + + // NHWC + auto bias_ptr = (float *)in_1.raw_data_const(); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + float scale_ = scale(0, 0, 0, 0); + + if (d_in % 128 != 0) { + return GraphStatus::ErrorDimensions; + } + + // call the kernel function for every dim() (assume total_size == bias_length) + // NOTE: in modeling, the dequantize add can appear after linear multihead attention, so w_in * d_in == bias_length + // in other positions, the w_in will be 1 + if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float16) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + auto in_ptr = (int8_t *)in_0.raw_data_const() + (((b * h_in) + h) * w_in * d_in); + auto out_ptr = (int8_t *)((int16_t *)out_0.raw_data() + (((b * h_in) + h) * w_in * d_in)); + qhmath_hvx_dequantize_add_ahf(in_ptr, bias_ptr, out_ptr, w_in * d_in, scale_); + } + } + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float16) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + auto in_ptr = (int8_t *)((int16_t *)in_0.raw_data_const() + (((b * h_in) + h) * w_in * d_in)); + auto out_ptr = (int8_t *)((int16_t *)out_0.raw_data() + (((b * h_in) + h) * w_in * d_in)); + qhmath_hvx_dequantize_add_ui16_ahf(in_ptr, bias_ptr, out_ptr, w_in * d_in, scale_); + } + } + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float32) { + // NOTE: correct + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + auto in_ptr = (int8_t *)((int16_t *)in_0.raw_data_const() + (((b * h_in) + h) * w_in * d_in)); + auto out_ptr = (int8_t *)((float_t *)out_0.raw_data() + (((b * h_in) + h) * w_in * d_in)); + qhmath_hvx_dequantize_add_ui16_af(in_ptr, bias_ptr, out_ptr, w_in * d_in, scale_); + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float32) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + auto in_ptr = (int8_t *)in_0.raw_data_const() + (((b * h_in) + h) * w_in * d_in); + auto out_ptr = (int8_t *)((float_t *)out_0.raw_data() + ((((b * h_in) + h) * w_in * d_in))); + qhmath_hvx_dequantize_add_af(in_ptr, bias_ptr, out_ptr, w_in * d_in, scale_); + } + } + } else { + return GraphStatus::GraphErrorCode::ErrorUnsupported; + } + + return GraphStatus::Success; +} +#else + +template +GraphStatus llamadequantizeaddImpl(TensorType1 &out_0, + const TensorType1 &in_0, + const TensorType &in_1, + const PlainFloatTensor &scale) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + out_0.set_dims(in_0); + + float scale_ = scale(0, 0, 0, 0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (inval - 128) * scale_ + in_1(0, 0, 0, w * d_in + d); + } + } + } + } + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + auto in_ptr = (uint16_t *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (inval - 32768) * scale_ + in_1(0, 0, 0, w * d_in + d); + } + } + } + } + } else if (in_0.get_dtype() == DType::QUInt16 && out_0.get_dtype() == DType::Float16) { + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (uint16_t *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (__fp16)((inval - 32768) * scale_ + in_1(0, 0, 0, w * d_in + d)); + } + } + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && out_0.get_dtype() == DType::Float16) { + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (__fp16)((inval - 128) * scale_ + in_1(0, 0, 0, w * d_in + d)); + } + } + } + } + } + return GraphStatus::Success; +} + +#endif + +__attribute__((unused)) static float llamadequantizeaddCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_LLaMADequantizeAdd); \ No newline at end of file diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMALinear.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMALinear.cpp index 5c3358dac..44487d245 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMALinear.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMALinear.cpp @@ -9,20 +9,18 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMALinear); - // op execute function declarations -template -GraphStatus llamalinearImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const TensorType& in_2, - const PlainFloatTensor& in_scale, - const PlainFloatTensor& weight_scale, - const PlainFloatTensor& bias_scale, - const PlainFloatTensor& output_scale); +template +GraphStatus llamalinearImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const TensorType &in_2, + const PlainFloatTensor &in_scale, + const PlainFloatTensor &weight_scale, + const PlainFloatTensor &bias_scale, + const PlainFloatTensor &output_scale); // forward declaration of sample cost function static float llamalinearCostFunc(const Op *op); @@ -67,11 +65,11 @@ DEF_PACKAGE_OP((llamalinearImpl), "LLaMALinear") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -84,7 +82,7 @@ DEF_PACKAGE_OP((llamalinearImpl), "LLaMALinear") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("LLaMALinear", +DEF_PACKAGE_PARAM_ORDER("LLaMALinear", "in_scale", true, nullptr, @@ -98,13 +96,12 @@ DEF_PACKAGE_PARAM_ORDER("LLaMALinear", true, nullptr) - /* execute functions for ops */ float Round(float num) { float floor_num = floor(num); float ceil_num = ceil(num); - + if (num - floor_num < ceil_num - num) { return floor_num; } else { @@ -112,34 +109,34 @@ float Round(float num) { } } -template -GraphStatus llamalinearImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const TensorType& in_2, - const PlainFloatTensor& in_scale, - const PlainFloatTensor& weight_scale, - const PlainFloatTensor& bias_scale, - const PlainFloatTensor& output_scale) +template +GraphStatus llamalinearImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const TensorType &in_2, + const PlainFloatTensor &in_scale, + const PlainFloatTensor &weight_scale, + const PlainFloatTensor &bias_scale, + const PlainFloatTensor &output_scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - // 假设输入张量是4维的,NHWC格式 + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + // 假设输入张量是4维的,NHWC格式 int batch_size = in_0.dims()[0]; int height = in_0.dims()[1]; int width = in_0.dims()[2]; - int in_features = in_0.dims()[3]; // 输入的通道数 + int in_features = in_0.dims()[3]; // 输入的通道数 int out_features = in_1.dims()[3]; // 输出的特征数(即输出通道数) // 检查输入张量的形状是否匹配 @@ -147,24 +144,23 @@ GraphStatus llamalinearImpl(TensorType& out_0, return GraphStatus::ErrorFatal; } - // 获取量化比例 - float w_scale = weight_scale(0,0,0,0); - float i_scale = in_scale(0,0,0,0); - float b_scale = bias_scale(0,0,0,0); - float o_scale = output_scale(0,0,0,0); - + float w_scale = weight_scale(0, 0, 0, 0); + float i_scale = in_scale(0, 0, 0, 0); + float b_scale = bias_scale(0, 0, 0, 0); + float o_scale = output_scale(0, 0, 0, 0); + // 初始化输出张量 size_t dims[] = {static_cast(batch_size), static_cast(height), static_cast(width), static_cast(out_features)}; out_0.set_dims(dims); // only support float bias now. - auto in0_ptr = (uint8_t*)in_0.raw_data_const(); - auto in1_ptr = (uint8_t*)in_1.raw_data_const(); - auto in2_ptr = (uint8_t*)in_2.raw_data_const(); - auto out_ptr = (int8_t*)out_0.raw_data(); - + auto in0_ptr = (uint8_t *)in_0.raw_data_const(); + auto in1_ptr = (uint8_t *)in_1.raw_data_const(); + auto in2_ptr = (uint8_t *)in_2.raw_data_const(); + auto out_ptr = (int8_t *)out_0.raw_data(); + // 进行量化Linear乘法 for (int b = 0; b < batch_size; ++b) { for (int h = 0; h < height; ++h) { @@ -174,24 +170,24 @@ GraphStatus llamalinearImpl(TensorType& out_0, for (int k = 0; k < in_features; ++k) { int in_index = b * height * width * in_features + h * width * in_features + w * in_features + k; int weight_index = k * out_features + n; - acc += ((static_cast(in0_ptr[in_index])-128) * i_scale) * ((static_cast(in1_ptr[weight_index])-128) * w_scale); + acc += ((static_cast(in0_ptr[in_index]) - 128) * i_scale) * ((static_cast(in1_ptr[weight_index]) - 128) * w_scale); } // 加上偏置并进行反量化 float result = acc; - result += (static_cast(in2_ptr[n])-128) * b_scale; + result += (static_cast(in2_ptr[n]) - 128) * b_scale; // 将结果限制在uint8范围内 int out_index = b * height * width * out_features + h * width * out_features + w * out_features + n; result = Round(result / o_scale); long v = lroundf(result); - + if (v > 127) v = 127; - + if (v < -128) v = -128; - + if (out_0.get_dtype() == DType::QUInt8) v += 128; @@ -201,23 +197,18 @@ GraphStatus llamalinearImpl(TensorType& out_0, } } - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float llamalinearCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float llamalinearCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp index 36c614ea8..802acbacf 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp @@ -9,15 +9,13 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMAMul); - // op execute function declarations -template -GraphStatus llamamulImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1); +template +GraphStatus llamamulImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1); // forward declaration of sample cost function static float llamamulCostFunc(const Op *op); @@ -62,11 +60,11 @@ DEF_PACKAGE_OP((llamamulImpl), "LLaMAMul") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -80,7 +78,6 @@ DEF_PACKAGE_OP((llamamulImpl), "LLaMAMul") * Qnn_addNode */ - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -89,17 +86,15 @@ DEF_PACKAGE_OP((llamamulImpl), "LLaMAMul") #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) int32_t hvx_mul_af( float *restrict input, float *restrict input2, float *restrict output, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -117,19 +112,16 @@ int32_t hvx_mul_af( sline1p = *iptr++; sline2p = *iptr2++; - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); @@ -143,31 +135,24 @@ int32_t hvx_mul_af( } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input2); - - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline1, sline2)); + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline1, sline2)); } // Handle leftover elements. if (leftover_size > 0) { - sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) - ? sline1p - : *iptr++); - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) ? sline1p : *iptr++); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) ? sline2p : *iptr2++); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) - ? sline2p - : *iptr2++); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - - vstu_variable(optr, leftover_size, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline1, sline2))); + vstu_variable(optr, leftover_size, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline1, sline2))); } return 0; @@ -177,10 +162,8 @@ int32_t hvx_mul_ahf( __fp16 *restrict input, __fp16 *restrict input2, __fp16 *restrict output, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -198,19 +181,16 @@ int32_t hvx_mul_ahf( sline1p = *iptr++; sline2p = *iptr2++; - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); @@ -224,160 +204,140 @@ int32_t hvx_mul_ahf( } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input2); - - *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(sline1, sline2)); + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + *optr++ = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(sline1, sline2)); } // Handle leftover elements. if (leftover_size > 0) { - sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) - ? sline1p - : *iptr++); - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - + sline1c = (is_in_one_chunk(iptr, leftover_size, VLEN) ? sline1p : *iptr++); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) - ? sline2p - : *iptr2++); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) ? sline2p : *iptr2++); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - vstu_variable(optr, leftover_size, Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(sline1, sline2))); + vstu_variable(optr, leftover_size, Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(sline1, sline2))); } return 0; } -template -GraphStatus llamamulImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1) +template +GraphStatus llamamulImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - size_t size = b_in*h_in*w_in*d_in; - - DType dtype = in_0.get_dtype(); - - if (dtype == DType::Float16) { - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto in2_ptr = (__fp16*)in_1.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); - - hvx_mul_ahf(in_ptr, in2_ptr, out_ptr, size); - - } else { - auto in_ptr = (float*)in_0.raw_data_const(); - auto in2_ptr = (float*)in_1.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - - hvx_mul_af(in_ptr, in2_ptr, out_ptr, size); - } - - return GraphStatus::Success; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + size_t size = b_in * h_in * w_in * d_in; + + DType dtype = in_0.get_dtype(); + + if (dtype == DType::Float16) { + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto in2_ptr = (__fp16 *)in_1.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + hvx_mul_ahf(in_ptr, in2_ptr, out_ptr, size); + + } else { + auto in_ptr = (float *)in_0.raw_data_const(); + auto in2_ptr = (float *)in_1.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + + hvx_mul_af(in_ptr, in2_ptr, out_ptr, size); + } + + return GraphStatus::Success; } #else - -template -GraphStatus llamamulImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1) +template +GraphStatus llamamulImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); - - DType dtype = in_0.get_dtype(); - - - auto out_ptr = (__fp16*)out_0.raw_data(); - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto in_ptr2 = (__fp16*)in_1.raw_data_const(); - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // mul - for (Idx d = 0; d < d_in; d++) { - - if (dtype == DType::Float16) { - - __fp16 inval = *in_ptr++; - __fp16 inval2 = *in_ptr2++; - __fp16 outval = inval * inval2; - - *out_ptr++ = outval; - } - - if (dtype == DType::Float32) { - float inval = in_0(b, h, w, d); - float inval2 = in_1(b, h, w, d); - float outval = inval * inval2; - - out_0(b, h, w, d) = outval; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); + + DType dtype = in_0.get_dtype(); + + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto in_ptr2 = (__fp16 *)in_1.raw_data_const(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // mul + for (Idx d = 0; d < d_in; d++) { + if (dtype == DType::Float16) { + __fp16 inval = *in_ptr++; + __fp16 inval2 = *in_ptr2++; + __fp16 outval = inval * inval2; + + *out_ptr++ = outval; + } + + if (dtype == DType::Float32) { + float inval = in_0(b, h, w, d); + float inval2 = in_1(b, h, w, d); + float outval = inval * inval2; + + out_0(b, h, w, d) = outval; + } + } } - - } } - } } - - return GraphStatus::Success; + return GraphStatus::Success; } - - #endif -__attribute__((unused)) static float llamamulCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float llamamulCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAQuantize.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAQuantize.cpp index 23b357b51..8c90ef05a 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAQuantize.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAQuantize.cpp @@ -9,15 +9,13 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMAQuantize); - // op execute function declarations -template +template GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale); + const PlainFloatTensor &scale); // forward declaration of sample cost function static float llamaquantizeCostFunc(const Op *op); @@ -62,11 +60,11 @@ DEF_PACKAGE_OP((llamaquantizeImpl), "LLaMAQuantize") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -79,29 +77,29 @@ DEF_PACKAGE_OP((llamaquantizeImpl), "LLaMAQuantize") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("LLaMAQuantize", +DEF_PACKAGE_PARAM_ORDER("LLaMAQuantize", "scale", true, nullptr) #ifndef REFERENCE_OP - #include "qhmath_hvx.h" #include "hvx_internal.h" #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) -static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) -{ - union { float f; uint32_t i; } fp32 = { .f = x }; +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; return fp32.i; } -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -116,7 +114,6 @@ static inline int32_t float_to_fp16s(float input) #define FP16_SIGN 15 #define FP16_NEG_1 0xbc00 - /* execute functions for ops */ int32_t qhmath_hvx_quantize_ahf( __fp16 *restrict input, @@ -124,15 +121,13 @@ int32_t qhmath_hvx_quantize_ahf( uint32_t size, float low_level, float high_level, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector sline2p, sline2c, sline2; @@ -153,7 +148,7 @@ int32_t qhmath_hvx_quantize_ahf( HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); - float es = 0.5; + float es = 0.5; low_level_vec = Q6_V_vsplat_R(float_to_fp16s(low_level)); high_level_vec = Q6_V_vsplat_R(float_to_fp16s(high_level)); scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); @@ -170,22 +165,19 @@ int32_t qhmath_hvx_quantize_ahf( HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { + for (int32_t j = 0; j < block; j += 4) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1,scale_vec); + sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1, scale_vec); sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); sout1 = Q6_Vhf_equals_Vqf16(sout1); sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); @@ -228,9 +220,9 @@ int32_t qhmath_hvx_quantize_ahf( sout1 = Q6_Vh_equals_Vhf(sout1); sline2c = *iptr++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2,scale_vec); + sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2, scale_vec); sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); sout2 = Q6_Vhf_equals_Vqf16(sout2); sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); @@ -273,9 +265,9 @@ int32_t qhmath_hvx_quantize_ahf( sout2 = Q6_Vh_equals_Vhf(sout2); sline3c = *iptr++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); - sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3,scale_vec); + sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3, scale_vec); sout3 = Q6_Vqf16_vadd_Vqf16Vqf16(sout3, es_vec); sout3 = Q6_Vhf_equals_Vqf16(sout3); sout3 = Q6_Vhf_vmin_VhfVhf(sout3, high_level_vec); @@ -318,9 +310,9 @@ int32_t qhmath_hvx_quantize_ahf( sout3 = Q6_Vh_equals_Vhf(sout3); sline4c = *iptr++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); - sout4 = Q6_Vqf16_vmpy_VhfVhf(sline4,scale_vec); + sout4 = Q6_Vqf16_vmpy_VhfVhf(sline4, scale_vec); sout4 = Q6_Vqf16_vadd_Vqf16Vqf16(sout4, es_vec); sout4 = Q6_Vhf_equals_Vqf16(sout4); sout4 = Q6_Vhf_vmin_VhfVhf(sout4, high_level_vec); @@ -362,15 +354,266 @@ int32_t qhmath_hvx_quantize_ahf( sout4 = Q6_Vh_equals_Vhf(sout4); - HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); *optr++ = Q6_Vb_vadd_VbVb(reql_h, uintconvert); HVX_Vector reqh_h = Q6_Vb_vpack_VhVh_sat(sout4, sout3); *optr++ = Q6_Vb_vadd_VbVb(reqh_h, uintconvert); + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + +int32_t qhmath_hvx_quantize_ui16_ahf( + __fp16 *restrict input, + __fp16 *restrict output, + uint32_t size, + float low_level, + float high_level, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 64; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + HVX_Vector uintconvert = Q6_V_vsplat_R(0x80008000); + + float es = 0.5; + low_level_vec = Q6_V_vsplat_R(float_to_fp16s(low_level)); + high_level_vec = Q6_V_vsplat_R(float_to_fp16s(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + es_vec = Q6_V_vsplat_R(float_to_fp16s(es)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf16_vadd_VhfVhf(es_vec, zero_v_sf); + + HVX_Vector expmask = Q6_Vh_vsplat_R(FP16_EXPONENT_MASK); + HVX_Vector expbias = Q6_Vh_vsplat_R(FP16_EXPONENT_BIAS); + HVX_Vector manmask = Q6_Vh_vsplat_R(FP16_MANTISA_MASK); + HVX_Vector exp23 = Q6_Vh_vsplat_R(23 - 1); + HVX_Vector exp0 = Q6_Vh_vsplat_R(0 - 1); + HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); + HVX_Vector zero = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1, scale_vec); + sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); + sout1 = Q6_Vhf_equals_Vqf16(sout1); + sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); + sout1 = Q6_Vhf_vmax_VhfVhf(sout1, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout1, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout1, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout1, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + } + + sout1 = Q6_Vh_equals_Vhf(sout1); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2, scale_vec); + sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); + sout2 = Q6_Vhf_equals_Vqf16(sout2); + sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); + sout2 = Q6_Vhf_vmax_VhfVhf(sout2, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout2, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout2, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout2, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout2, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + } + + sout2 = Q6_Vh_equals_Vhf(sout2); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3, scale_vec); + sout3 = Q6_Vqf16_vadd_Vqf16Vqf16(sout3, es_vec); + sout3 = Q6_Vhf_equals_Vqf16(sout3); + sout3 = Q6_Vhf_vmin_VhfVhf(sout3, high_level_vec); + sout3 = Q6_Vhf_vmax_VhfVhf(sout3, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout3, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout3, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout3, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout3, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout3, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout3, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout3, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout3, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); + } + + sout3 = Q6_Vh_equals_Vhf(sout3); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + + sout4 = Q6_Vqf16_vmpy_VhfVhf(sline4, scale_vec); + sout4 = Q6_Vqf16_vadd_Vqf16Vqf16(sout4, es_vec); + sout4 = Q6_Vhf_equals_Vqf16(sout4); + sout4 = Q6_Vhf_vmin_VhfVhf(sout4, high_level_vec); + sout4 = Q6_Vhf_vmax_VhfVhf(sout4, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout4, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout4, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout4, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout4, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout4, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout4, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout4, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout4, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout4 = Q6_V_vmux_QVV(expgte23, sout4, tsout1); + } + + sout4 = Q6_Vh_equals_Vhf(sout4); + + *optr++ = Q6_Vh_vadd_VhVh(sout1, uintconvert); + *optr++ = Q6_Vh_vadd_VhVh(sout2, uintconvert); + *optr++ = Q6_Vh_vadd_VhVh(sout3, uintconvert); + *optr++ = Q6_Vh_vadd_VhVh(sout4, uintconvert); - sline1p = sline1c; sline2p = sline2c; sline3p = sline3c; @@ -387,15 +630,13 @@ int32_t qhmath_hvx_quantize_ahf_int8( uint32_t size, float low_level, float high_level, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector sline2p, sline2c, sline2; @@ -414,7 +655,7 @@ int32_t qhmath_hvx_quantize_ahf_int8( sline3p = *iptr++; sline4p = *iptr++; - float es = 0.5; + float es = 0.5; low_level_vec = Q6_V_vsplat_R(float_to_fp16s(low_level)); high_level_vec = Q6_V_vsplat_R(float_to_fp16s(high_level)); scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); @@ -431,22 +672,19 @@ int32_t qhmath_hvx_quantize_ahf_int8( HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { + for (int32_t j = 0; j < block; j += 4) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1,scale_vec); + sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1, scale_vec); sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); sout1 = Q6_Vhf_equals_Vqf16(sout1); sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); @@ -489,9 +727,9 @@ int32_t qhmath_hvx_quantize_ahf_int8( sout1 = Q6_Vh_equals_Vhf(sout1); sline2c = *iptr++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2,scale_vec); + sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2, scale_vec); sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); sout2 = Q6_Vhf_equals_Vqf16(sout2); sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); @@ -534,9 +772,9 @@ int32_t qhmath_hvx_quantize_ahf_int8( sout2 = Q6_Vh_equals_Vhf(sout2); sline3c = *iptr++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); - sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3,scale_vec); + sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3, scale_vec); sout3 = Q6_Vqf16_vadd_Vqf16Vqf16(sout3, es_vec); sout3 = Q6_Vhf_equals_Vqf16(sout3); sout3 = Q6_Vhf_vmin_VhfVhf(sout3, high_level_vec); @@ -579,9 +817,9 @@ int32_t qhmath_hvx_quantize_ahf_int8( sout3 = Q6_Vh_equals_Vhf(sout3); sline4c = *iptr++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); - sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4,scale_vec); + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); sout4 = Q6_Vqf16_vadd_Vqf16Vqf16(sout4, es_vec); sout4 = Q6_Vhf_equals_Vqf16(sout4); sout4 = Q6_Vhf_vmin_VhfVhf(sout4, high_level_vec); @@ -623,15 +861,12 @@ int32_t qhmath_hvx_quantize_ahf_int8( sout4 = Q6_Vh_equals_Vhf(sout4); - HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); *optr++ = reql_h; HVX_Vector reqh_h = Q6_Vb_vpack_VhVh_sat(sout4, sout3); *optr++ = reqh_h; - - sline1p = sline1c; sline2p = sline2c; sline3p = sline3c; @@ -657,15 +892,13 @@ int32_t qhmath_hvx_quantize_af( uint32_t size, float low_level, float high_level, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector sline2p, sline2c, sline2; @@ -684,7 +917,7 @@ int32_t qhmath_hvx_quantize_af( sline3p = *iptr++; sline4p = *iptr++; - float es = 0.5f; + float es = 0.5f; low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); @@ -696,7 +929,6 @@ int32_t qhmath_hvx_quantize_af( HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); - // HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); // HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); // HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); @@ -705,22 +937,19 @@ int32_t qhmath_hvx_quantize_af( // HVX_Vector negone = Q6_V_vsplat_R(FLOAT_NEG_1); // HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { + for (int32_t j = 0; j < block; j += 4) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1,scale_vec); + sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1, scale_vec); sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); sout1 = Q6_Vsf_equals_Vqf32(sout1); sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); @@ -767,9 +996,9 @@ int32_t qhmath_hvx_quantize_af( // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, Q6_V_vzero()), 0); sline2c = *iptr++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2,scale_vec); + sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2, scale_vec); sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); sout2 = Q6_Vsf_equals_Vqf32(sout2); sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); @@ -816,9 +1045,9 @@ int32_t qhmath_hvx_quantize_af( // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, Q6_V_vzero()), 0); sline3c = *iptr++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); - sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3,scale_vec); + sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3, scale_vec); sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); sout3 = Q6_Vsf_equals_Vqf32(sout3); sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); @@ -860,22 +1089,20 @@ int32_t qhmath_hvx_quantize_af( // sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); // } - sout3 = Q6_Vw_equals_Vsf(sout3); sout3 = Q6_Vw_vasr_VwR(sout3, ROUND_2_SCALE); // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, Q6_V_vzero()), 0); sline4c = *iptr++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); - sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4,scale_vec); + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); sout4 = Q6_Vsf_equals_Vqf32(sout4); sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); sout4 = Q6_Vsf_vmax_VsfVsf(sout4, low_level_vec); sout4 = Q6_Vqf32_vmpy_VsfVsf(sout4, round_scale_vec); sout4 = Q6_Vsf_equals_Vqf32(sout4); - // { // HVX_Vector exp = Q6_Vw_vasr_VwR(sout4, FLOAT_MANTISA); @@ -915,7 +1142,6 @@ int32_t qhmath_hvx_quantize_af( sout4 = Q6_Vw_vasr_VwR(sout4, ROUND_2_SCALE); // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, Q6_V_vzero()), 0); - HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); HVX_Vector req_b = Q6_Vb_vpack_VhVh_sat(reqh_h, reql_h); @@ -932,21 +1158,150 @@ int32_t qhmath_hvx_quantize_af( return 0; } +#define INT16_ROUND_2_SCALE 15 +#define INT16_ROUND_SCALSE ((1 << INT16_ROUND_2_SCALE) * 1.0f) + +int32_t qhmath_hvx_quantize_ui16_af( + float *restrict input, + float *restrict output, + uint32_t size, + float low_level, + float high_level, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec, round_scale_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 32; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + float es = 0.5f; + low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); + high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + es_vec = Q6_V_vsplat_R(float_to_bits(es)); + round_scale_vec = Q6_V_vsplat_R(float_to_bits(INT16_ROUND_SCALSE)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf32_vadd_VsfVsf(es_vec, zero_v_sf); + + HVX_Vector uintconvert = Q6_V_vsplat_R(0x80008000); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1, scale_vec); + sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); + sout1 = Q6_Vsf_equals_Vqf32(sout1); + sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); + sout1 = Q6_Vsf_vmax_VsfVsf(sout1, low_level_vec); + sout1 = Q6_Vqf32_vmpy_VsfVsf(sout1, round_scale_vec); + sout1 = Q6_Vsf_equals_Vqf32(sout1); + + sout1 = Q6_Vw_equals_Vsf(sout1); + sout1 = Q6_Vw_vasr_VwR(sout1, INT16_ROUND_2_SCALE); + // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, Q6_V_vzero()), 0); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2, scale_vec); + sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); + sout2 = Q6_Vsf_equals_Vqf32(sout2); + sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); + sout2 = Q6_Vsf_vmax_VsfVsf(sout2, low_level_vec); + sout2 = Q6_Vqf32_vmpy_VsfVsf(sout2, round_scale_vec); + sout2 = Q6_Vsf_equals_Vqf32(sout2); + + sout2 = Q6_Vw_equals_Vsf(sout2); + sout2 = Q6_Vw_vasr_VwR(sout2, INT16_ROUND_2_SCALE); + // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, Q6_V_vzero()), 0); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3, scale_vec); + sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); + sout3 = Q6_Vsf_equals_Vqf32(sout3); + sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); + sout3 = Q6_Vsf_vmax_VsfVsf(sout3, low_level_vec); + sout3 = Q6_Vqf32_vmpy_VsfVsf(sout3, round_scale_vec); + sout3 = Q6_Vsf_equals_Vqf32(sout3); + + sout3 = Q6_Vw_equals_Vsf(sout3); + sout3 = Q6_Vw_vasr_VwR(sout3, INT16_ROUND_2_SCALE); + // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, Q6_V_vzero()), 0); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); + sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); + sout4 = Q6_Vsf_equals_Vqf32(sout4); + sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); + sout4 = Q6_Vsf_vmax_VsfVsf(sout4, low_level_vec); + sout4 = Q6_Vqf32_vmpy_VsfVsf(sout4, round_scale_vec); + sout4 = Q6_Vsf_equals_Vqf32(sout4); + + sout4 = Q6_Vw_equals_Vsf(sout4); + sout4 = Q6_Vw_vasr_VwR(sout4, INT16_ROUND_2_SCALE); + // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, Q6_V_vzero()), 0); + + HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); + HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); + + *optr++ = Q6_Vh_vadd_VhVh(reql_h, uintconvert); + *optr++ = Q6_Vh_vadd_VhVh(reqh_h, uintconvert); + + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + int32_t qhmath_hvx_quantize_af_out_int8( float *restrict input, float *restrict output, uint32_t size, float low_level, float high_level, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } - HVX_Vector *iptr = (HVX_Vector *) input; - HVX_UVector *optr = (HVX_UVector *) output; + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; HVX_Vector sline1p, sline1c, sline1; HVX_Vector sline2p, sline2c, sline2; @@ -965,7 +1320,7 @@ int32_t qhmath_hvx_quantize_af_out_int8( sline3p = *iptr++; sline4p = *iptr++; - float es = 0.5f; + float es = 0.5f; low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); @@ -974,7 +1329,6 @@ int32_t qhmath_hvx_quantize_af_out_int8( HVX_Vector zero_v_sf = Q6_V_vzero(); es_vec = Q6_Vqf32_vadd_VsfVsf(es_vec, zero_v_sf); - HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); @@ -983,22 +1337,19 @@ int32_t qhmath_hvx_quantize_af_out_int8( HVX_Vector negone = Q6_V_vsplat_R(FLOAT_NEG_1); HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { + for (int32_t j = 0; j < block; j += 4) { sline1c = *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1,scale_vec); + sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1, scale_vec); sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); sout1 = Q6_Vsf_equals_Vqf32(sout1); sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); @@ -1042,9 +1393,9 @@ int32_t qhmath_hvx_quantize_af_out_int8( // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, Q6_V_vzero()), 0); sline2c = *iptr++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input); + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2,scale_vec); + sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2, scale_vec); sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); sout2 = Q6_Vsf_equals_Vqf32(sout2); sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); @@ -1088,9 +1439,9 @@ int32_t qhmath_hvx_quantize_af_out_int8( // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, Q6_V_vzero()), 0); sline3c = *iptr++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); - sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3,scale_vec); + sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3, scale_vec); sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); sout3 = Q6_Vsf_equals_Vqf32(sout3); sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); @@ -1130,14 +1481,13 @@ int32_t qhmath_hvx_quantize_af_out_int8( sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); } - sout3 = Q6_Vw_equals_Vsf(sout3); // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, Q6_V_vzero()), 0); sline4c = *iptr++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); - sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4,scale_vec); + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); sout4 = Q6_Vsf_equals_Vqf32(sout4); sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); @@ -1180,7 +1530,6 @@ int32_t qhmath_hvx_quantize_af_out_int8( sout4 = Q6_Vw_equals_Vsf(sout4); // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, Q6_V_vzero()), 0); - HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); HVX_Vector req_b = Q6_Vb_vpack_VhVh_sat(reqh_h, reql_h); @@ -1197,145 +1546,172 @@ int32_t qhmath_hvx_quantize_af_out_int8( return 0; } - -template +template GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale) + const PlainFloatTensor &scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - // HVX Method -- FP32 Version - out_0.set_dims(in_0); - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - float scale_ = scale(0,0,0,0); - - scale_ = 1.0f/scale_; - - size_t size = b_in*h_in*w_in*d_in; - DType dtype = in_0.get_dtype(); - - if (dtype == DType::Float16 && out_0.get_dtype() == DType::QUInt8) { - // NHWC - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); - - qhmath_hvx_quantize_ahf(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); - - } - if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt8) { - - // NHWC - auto in_ptr = (float*)in_0.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - qhmath_hvx_quantize_af(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); - - } - - if (dtype == DType::Float16 && out_0.get_dtype() == DType::QInt8) { - // NHWC - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); - - qhmath_hvx_quantize_ahf_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); - - } - - if (dtype == DType::Float32 && out_0.get_dtype() == DType::QInt8) { - - // NHWC - auto in_ptr = (float*)in_0.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - qhmath_hvx_quantize_af_out_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); - - - } - -// auto out_ptr = (int8_t*)out_0.raw_data(); - -// out_ptr[0] = (int)dtype; -// out_ptr[1] = (int)out_0.get_dtype(); - - return GraphStatus::Success; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + float scale_ = scale(0, 0, 0, 0); + + scale_ = 1.0f / scale_; + + size_t size = b_in * h_in * w_in * d_in; + DType dtype = in_0.get_dtype(); + + if (dtype == DType::Float16 && out_0.get_dtype() == DType::QUInt8) { + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + qhmath_hvx_quantize_ahf(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt8) { + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + qhmath_hvx_quantize_af(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + + if (dtype == DType::Float16 && out_0.get_dtype() == DType::QUInt16) { + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + qhmath_hvx_quantize_ui16_ahf(in_ptr, out_ptr, size, -32768.0f, 32767.0f, scale_); + } + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt16) { + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + qhmath_hvx_quantize_ui16_af(in_ptr, out_ptr, size, -32768.0f, 32767.0f, scale_); + } + + if (dtype == DType::Float16 && out_0.get_dtype() == DType::QInt8) { + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + qhmath_hvx_quantize_ahf_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QInt8) { + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + qhmath_hvx_quantize_af_out_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + + // auto out_ptr = (int8_t*)out_0.raw_data(); + + // out_ptr[0] = (int)dtype; + // out_ptr[1] = (int)out_0.get_dtype(); + + return GraphStatus::Success; } #else extern float Round(float num); -template +template GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, - const PlainFloatTensor& scale) + const PlainFloatTensor &scale) { out_0.set_dims(in_0); - float scale_ = scale(0,0,0,0); - - auto out_ptr = (int8_t*)out_0.raw_data(); + float scale_ = scale(0, 0, 0, 0); auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - - for (Idx w = 0; w < w_in; w++) { - - for (Idx d = 0; d < d_in; d++) { - - float inval = in_0(b, h, w, d); - - // float result = Round(inval / scale_); - - - long v = lroundf(inval / scale_); - - if (v > 127) - v = 127; - - if (v < -128) - v = -128; - - if (out_0.get_dtype() == DType::QUInt8) - v += 128; - - *out_ptr++ = static_cast(v); - } + DType dtype = in_0.get_dtype(); + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt8) { + auto out_ptr = (int8_t *)out_0.raw_data(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + + // float result = Round(inval / scale_); + + long v = lroundf(inval / scale_); + + if (v > 127) + v = 127; + + if (v < -128) + v = -128; + + v += 128; + + *out_ptr++ = static_cast(v); + } + } } } } - return GraphStatus::Success; -} -#endif + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt16) { + auto out_ptr = (int16_t *)out_0.raw_data(); -__attribute__((unused)) static float llamaquantizeCostFunc(const Op *op) -{ - /* - * add code here - * */ + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); - float cost = 0.0; // add cost computation here - return cost; -} + // float result = Round(inval / scale_); + + long v = lroundf(inval / scale_); + + if (v > 32767) + v = 32767; + + if (v < -32768) + v = -32768; + + v += 32768; + *out_ptr++ = static_cast(v); + } + } + } + } + } + return GraphStatus::Success; +} +#endif +__attribute__((unused)) static float llamaquantizeCostFunc(const Op *op) { + /* + * add code here + * */ + float cost = 0.0; // add cost computation here + return cost; +} /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAReLU.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAReLU.cpp index 1ef2c0c93..c56bb8719 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAReLU.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAReLU.cpp @@ -9,14 +9,12 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMAReLU); - // op execute function declarations -template -GraphStatus llamareluImpl(TensorType& out_0, - const TensorType& in_0); +template +GraphStatus llamareluImpl(TensorType &out_0, + const TensorType &in_0); // forward declaration of sample cost function static float llamareluCostFunc(const Op *op); @@ -61,11 +59,11 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -79,7 +77,6 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") * Qnn_addNode */ - /* execute functions for ops */ // #ifndef REFERENCE_OP @@ -94,7 +91,6 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") // #define ONE 0x3F800000 // #define M_ONE 0xAF800000 - // int32_t hvx_relu_au8(uint8_t *restrict input, uint8_t *restrict output, uint32_t size) // { // HVX_Vector *input_v_ptr; @@ -107,7 +103,6 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") // int32_t vectors_in_rounddown = size / 128; // int32_t leftover_size = leftover * sizeof(uint8_t); - // /* Check input arguments. Return error status if some argument has invalid value */ // if ((input == 0) || (output == 0) || (size == 0)) // { @@ -200,7 +195,6 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") // * // * Please check in SDK documentation for more information. // */ - // out_0.set_dims(in_0); @@ -214,91 +208,80 @@ DEF_PACKAGE_OP((llamareluImpl), "LLaMAReLU") // return GraphStatus::Success; // } // #else -template -GraphStatus llamareluImpl(TensorType& out_0, - const TensorType& in_0) +template +GraphStatus llamareluImpl(TensorType &out_0, + const TensorType &in_0) { - out_0.set_dims(in_0); + out_0.set_dims(in_0); // NHWC - if (in_0.get_dtype() == DType::QUInt8) { - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // SiLU - for (Idx d = 0; d < d_in; d++) { - uint8_t inval = in_0(b, h, w, d); - if (inval < 0) - inval = 0; - - out_0(b, h, w, d) = inval; - - } + if (in_0.get_dtype() == DType::QUInt8) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // SiLU + for (Idx d = 0; d < d_in; d++) { + uint8_t inval = in_0(b, h, w, d); + if (inval < 0) + inval = 0; + + out_0(b, h, w, d) = inval; + } + } + } } - } - } - } else if (in_0.get_dtype() == DType::Float16) { - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - auto out_ptr = (__fp16*)out_0.raw_data(); - auto in_ptr = (__fp16*)in_0.raw_data_const(); - - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - - for (Idx d = 0; d < d_in; d++) { - __fp16 inval = *in_ptr++; - if (inval < 0) - inval = 0; - - *out_ptr++ = inval; - - } + } else if (in_0.get_dtype() == DType::Float16) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + auto out_ptr = (__fp16 *)out_0.raw_data(); + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + __fp16 inval = *in_ptr++; + if (inval < 0) + inval = 0; + + *out_ptr++ = inval; + } + } + } } - } - } - } else if(in_0.get_dtype() == DType::Float32) { - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - for (Idx d = 0; d < d_in; d++) { - float inval = in_0(b, h, w, d); - if (inval < 0) - inval = 0; - - out_0(b, h, w, d) = inval; - - } + } else if (in_0.get_dtype() == DType::Float32) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + if (inval < 0) + inval = 0; + + out_0(b, h, w, d) = inval; + } + } + } } - } } - } - - return GraphStatus::Success; + return GraphStatus::Success; } // #endif +__attribute__((unused)) static float llamareluCostFunc(const Op *op) { + /* + * add code here + * */ -__attribute__((unused)) static float llamareluCostFunc(const Op *op) -{ - /* - * add code here - * */ - - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMASuperSiLU.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMASuperSiLU.cpp index 0a849ca11..3976f60ba 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMASuperSiLU.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMASuperSiLU.cpp @@ -9,18 +9,16 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_LLaMASuperSiLU); - // op execute function declarations -template -GraphStatus llamasupersiluImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const PlainFloatTensor& a_scale, - const PlainFloatTensor& b_scale, - const PlainFloatTensor& o_scale); +template +GraphStatus llamasupersiluImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const PlainFloatTensor &a_scale, + const PlainFloatTensor &b_scale, + const PlainFloatTensor &o_scale); // forward declaration of sample cost function static float llamasupersiluCostFunc(const Op *op); @@ -65,11 +63,11 @@ DEF_PACKAGE_OP((llamasupersiluImpl), "LLaMASuperSiLU") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -82,7 +80,7 @@ DEF_PACKAGE_OP((llamasupersiluImpl), "LLaMASuperSiLU") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("LLaMASuperSiLU", +DEF_PACKAGE_PARAM_ORDER("LLaMASuperSiLU", "a_scale", true, nullptr, @@ -93,7 +91,6 @@ DEF_PACKAGE_PARAM_ORDER("LLaMASuperSiLU", true, nullptr) - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -103,8 +100,8 @@ DEF_PACKAGE_PARAM_ORDER("LLaMASuperSiLU", #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) #define FP16_MANTISA 10 #define FP16_EXPONENT_MASK 0x1f @@ -115,8 +112,7 @@ DEF_PACKAGE_PARAM_ORDER("LLaMASuperSiLU", #define ROUND_2_SCALE 22 #define ROUND_SCALSE ((1 << ROUND_2_SCALE) * 1.0f) -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -124,47 +120,188 @@ static inline int32_t float_to_fp16s(float input) return fp32.i; } -static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) -{ - union { float f; uint32_t i; } fp32 = { .f = x }; +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; return fp32.i; } - static const float fp16_c0_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.13239719960243818,0.2216255210749415,0.3447664743728659,0.48137452032585476,0.5716299228719798,0.5547323231605259,0.5046287748870234,0.4999985574626892, -0.5000036514755082,0.49475652448004626,0.4441393352532763,0.428500379952032,0.5173297285470642,0.6541461039833616,0.7783931007462818,0.8678015179911097, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.13239719960243818, + 0.2216255210749415, + 0.3447664743728659, + 0.48137452032585476, + 0.5716299228719798, + 0.5547323231605259, + 0.5046287748870234, + 0.4999985574626892, + 0.5000036514755082, + 0.49475652448004626, + 0.4441393352532763, + 0.428500379952032, + 0.5173297285470642, + 0.6541461039833616, + 0.7783931007462818, + 0.8678015179911097, }; static const float fp16_c1_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.05928005756790343,0.11063222460270064,0.1932879057003057,0.30302440212086995,0.3922924462181049,0.36546332659415875,0.2644148210990377,0.24989020912329707, -0.2498532691910313,0.2661055781198988,0.36728015359480604,0.39215270010450015,0.3041825601732039,0.1940762094668647,0.11061794856987572,0.059174800917353595, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.05928005756790343, + 0.11063222460270064, + 0.1932879057003057, + 0.30302440212086995, + 0.3922924462181049, + 0.36546332659415875, + 0.2644148210990377, + 0.24989020912329707, + 0.2498532691910313, + 0.2661055781198988, + 0.36728015359480604, + 0.39215270010450015, + 0.3041825601732039, + 0.1940762094668647, + 0.11061794856987572, + 0.059174800917353595, }; static const float fp16_c2_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.010145494303219278,0.02123968384425681,0.04207468332514667,0.07519946712591977,0.10840620196267145,0.09270738184406795,0.015322371881818012,-0.0009948273994921822, -0.0011544907060402412,-0.017040517565094934,-0.09379878876657094,-0.10835043868732394,-0.07558705272699548,-0.04228875316413285,-0.021235740718738055,-0.010124599879590107, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.010145494303219278, + 0.02123968384425681, + 0.04207468332514667, + 0.07519946712591977, + 0.10840620196267145, + 0.09270738184406795, + 0.015322371881818012, + -0.0009948273994921822, + 0.0011544907060402412, + -0.017040517565094934, + -0.09379878876657094, + -0.10835043868732394, + -0.07558705272699548, + -0.04228875316413285, + -0.021235740718738055, + -0.010124599879590107, }; static const float fp16_c3_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0007841223015974933,0.001850453397354219,0.004187899308371771,0.008640952434084206,0.01414741414964877,0.010117749275618,-0.01654848996354919,-0.02395108399453624, --0.024199111971064446,-0.015783556879607072,0.010407672131558174,0.014137608186323335,0.008698510795258909,0.004213708431213342,0.0018499827774393985,0.0007822799742289481, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0007841223015974933, + 0.001850453397354219, + 0.004187899308371771, + 0.008640952434084206, + 0.01414741414964877, + 0.010117749275618, + -0.01654848996354919, + -0.02395108399453624, + -0.024199111971064446, + -0.015783556879607072, + 0.010407672131558174, + 0.014137608186323335, + 0.008698510795258909, + 0.004213708431213342, + 0.0018499827774393985, + 0.0007822799742289481, }; static const float fp16_c4_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.3031641204975905e-05,6.150442488966733e-05,0.00015997783736818624,0.00038491646239693526,0.0007283649599237781,0.00034439150914392054,-0.003142246198646662,-0.004120389580321761, -0.004246050162553198,0.0030162727520777893,-0.00037312974308425725,-0.0007277242855014247,-0.00038811687679772674,-0.0001611434776868886,-6.14837984586862e-05,-2.297076123375133e-05, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 2.3031641204975905e-05, + 6.150442488966733e-05, + 0.00015997783736818624, + 0.00038491646239693526, + 0.0007283649599237781, + 0.00034439150914392054, + -0.003142246198646662, + -0.004120389580321761, + 0.004246050162553198, + 0.0030162727520777893, + -0.00037312974308425725, + -0.0007277242855014247, + -0.00038811687679772674, + -0.0001611434776868886, + -6.14837984586862e-05, + -2.297076123375133e-05, }; int32_t hvx_supersilu_ahf( @@ -174,10 +311,8 @@ int32_t hvx_supersilu_ahf( float a_scale, float b_scale, float o_scale, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -195,18 +330,15 @@ int32_t hvx_supersilu_ahf( sline1p = *iptr++; sline2p = *iptr2++; - - // dequantize + // dequantize uint32_t convert = 0x00800080; HVX_Vector convert_vector = Q6_V_vsplat_R(convert); - HVX_Vector a_scale_vec = Q6_V_vsplat_R(float_to_fp16s(a_scale)); HVX_Vector b_scale_vec = Q6_V_vsplat_R(float_to_fp16s(b_scale)); HVX_Vector zero_v_sf = Q6_V_vzero(); - - //silu + // silu HVX_Vector input_min_v_hf; HVX_Vector input_shifted_v_hf; HVX_Vector input_scaled_v; @@ -286,27 +418,25 @@ int32_t hvx_supersilu_ahf( c3_coeff_dv.VV = Q6_Wuw_vzxt_Vuh(c3_coeff_v); c4_coeff_dv.VV = Q6_Wuw_vzxt_Vuh(c4_coeff_v); - // quantize HVX_Vector low_level_vec, high_level_vec, o_scale_vec, es_vec, round_scale_vec; HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); HVX_Vector vmb = Q6_V_vsplat_R(0x40004000); - float post_scale_flt = a_scale * b_scale * o_scale; - int scexp = flt_getexp( post_scale_flt); - int rsh = min_i32( -scexp,7); // e.g. 0.11 -> 0.88, rsh = 3 + int scexp = flt_getexp(post_scale_flt); + int rsh = min_i32(-scexp, 7); // e.g. 0.11 -> 0.88, rsh = 3 float rsh_fac = flt_power2(rsh); int adj_bias = roundf_i32(128 * rsh_fac); - adj_bias = Q6_R_combine_RlRl( adj_bias, adj_bias); + adj_bias = Q6_R_combine_RlRl(adj_bias, adj_bias); HVX_Vector vadj = Q6_V_vsplat_R(adj_bias); - float es = 0.5; + float es = 0.5; low_level_vec = Q6_V_vsplat_R(float_to_fp16s(-128.0f)); high_level_vec = Q6_V_vsplat_R(float_to_fp16s(127.0f)); - o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(post_scale_flt * rsh_fac * (1<<15))); + o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(post_scale_flt * rsh_fac * (1 << 15))); // one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0f)); // o_scale_vec = Q6_Vqf16_vadd_VhfVhf(o_scale_vec, zero_v_hf); es_vec = Q6_V_vsplat_R(float_to_fp16s(es)); @@ -323,45 +453,40 @@ int32_t hvx_supersilu_ahf( HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); - HVX_Vector sline1_high; HVX_Vector sline1_low; // HVX_Vector sline2_high; // HVX_Vector sline2_low; { - // dequantize sline1 qf16 - HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + // dequantize sline1 qf16 + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); - HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); - HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - sline1_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), a_scale_vec); - sline1_low = Q6_Vhf_equals_Vqf16(sline1_low); - sline1_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), a_scale_vec); - sline1_high = Q6_Vhf_equals_Vqf16(sline1_high); + sline1_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), a_scale_vec); + sline1_low = Q6_Vhf_equals_Vqf16(sline1_low); + sline1_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), a_scale_vec); + sline1_high = Q6_Vhf_equals_Vqf16(sline1_high); } - // { // // dequantize sline2 qf16 // HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline2, zero_v_sf); @@ -377,164 +502,162 @@ int32_t hvx_supersilu_ahf( // } { - // silu sline1_low - tmp_v = Q6_Vh_vdeal_Vh(sline1_low); - - /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ - input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); - - /* - * Scale shifted input range from [0, input_max - input_min] to [0,16.0) - * in order to get corresponding coefficient indexes - */ - input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); - - /* - * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) - * to [16.0,32.0) in order to convert float indexes to integer values. - * Float values, represented in IEEE 754, in range [16.0,32.0] have the - * same exponent, which means 4 MSB of mantissa carry information about - * integer index. - * Use the same input_scaled_v vector for hf and qf16 representation - */ - input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); - - /* Convert back from qf16 to hf in order to extract integer index */ - tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); - - /* Only 4 MSB bits of mantissa represent segment index */ - idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); - - /* Ensure only 4 MSB bits of mantissa are used as indexes */ - idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); - - idx1_v = Q6_Vb_vshuff_Vb(idx1_v); - idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); - idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); - - /* Obtain the polynomial coefficients from lookup table */ - c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); - c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); - - /* Convert input from hf vector to qf32 vector pair for Horner's method*/ - input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_low, one_v_hf); - - /* Perform evaluation of polynomial using Horner's method */ - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); - - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); - - // x * sigmod - // output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); - // output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); - - sline1_low = Q6_Vhf_equals_Wqf32(output_dv.VV); + // silu sline1_low + tmp_v = Q6_Vh_vdeal_Vh(sline1_low); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. + * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. + * Use the same input_scaled_v vector for hf and qf16 representation + */ + input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); + + /* Convert back from qf16 to hf in order to extract integer index */ + tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); + + /* Ensure only 4 MSB bits of mantissa are used as indexes */ + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + + idx1_v = Q6_Vb_vshuff_Vb(idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); + + /* Convert input from hf vector to qf32 vector pair for Horner's method*/ + input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_low, one_v_hf); + + /* Perform evaluation of polynomial using Horner's method */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); + + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); + + // x * sigmod + // output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); + // output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); + + sline1_low = Q6_Vhf_equals_Wqf32(output_dv.VV); } - { - // silu sline1_high - tmp_v = Q6_Vh_vdeal_Vh(sline1_high); - - /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ - input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); - - /* - * Scale shifted input range from [0, input_max - input_min] to [0,16.0) - * in order to get corresponding coefficient indexes - */ - input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); - - /* - * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) - * to [16.0,32.0) in order to convert float indexes to integer values. - * Float values, represented in IEEE 754, in range [16.0,32.0] have the - * same exponent, which means 4 MSB of mantissa carry information about - * integer index. - * Use the same input_scaled_v vector for hf and qf16 representation - */ - input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); - - /* Convert back from qf16 to hf in order to extract integer index */ - tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); - - /* Only 4 MSB bits of mantissa represent segment index */ - idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); - - /* Ensure only 4 MSB bits of mantissa are used as indexes */ - idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); - - idx1_v = Q6_Vb_vshuff_Vb(idx1_v); - idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); - idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); - - /* Obtain the polynomial coefficients from lookup table */ - c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); - c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); - - /* Convert input from hf vector to qf32 vector pair for Horner's method*/ - input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_high, one_v_hf); - - /* Perform evaluation of polynomial using Horner's method */ - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); - - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); - - // x * sigmod - // output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); - // output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); - - sline1_high = Q6_Vhf_equals_Wqf32(output_dv.VV); + // silu sline1_high + tmp_v = Q6_Vh_vdeal_Vh(sline1_high); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. + * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. + * Use the same input_scaled_v vector for hf and qf16 representation + */ + input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); + + /* Convert back from qf16 to hf in order to extract integer index */ + tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); + + /* Ensure only 4 MSB bits of mantissa are used as indexes */ + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + + idx1_v = Q6_Vb_vshuff_Vb(idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); + + /* Convert input from hf vector to qf32 vector pair for Horner's method*/ + input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_high, one_v_hf); + + /* Perform evaluation of polynomial using Horner's method */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); + + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); + + // x * sigmod + // output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); + // output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); + + sline1_high = Q6_Vhf_equals_Wqf32(output_dv.VV); } - HVX_Vector sline_high; HVX_Vector sline_low; - + // { // // mul // sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, sline2_high); @@ -543,35 +666,34 @@ int32_t hvx_supersilu_ahf( // sline_high = Q6_Vhf_equals_Vqf16(sline_high); // sline_low = Q6_Vhf_equals_Vqf16(sline_low); // } - + HVX_VectorPair mul_output; { - // uint8 mul - // (a-128)*(b-128) = a*b - 128 (a+b) + 128*128 - HVX_VectorPair prod1 = Q6_Wuh_vmpyacc_WuhVubVub(Q6_W_vcombine_VV(vmb,vmb), sline1, sline2); - HVX_VectorPair prod2 = Q6_Wh_vmpa_WubRub( Q6_W_vcombine_VV(sline2, sline1), 0x80808080); - mul_output = Q6_Wh_vsub_WhWh(prod1, prod2); + // uint8 mul + // (a-128)*(b-128) = a*b - 128 (a+b) + 128*128 + HVX_VectorPair prod1 = Q6_Wuh_vmpyacc_WuhVubVub(Q6_W_vcombine_VV(vmb, vmb), sline1, sline2); + HVX_VectorPair prod2 = Q6_Wh_vmpa_WubRub(Q6_W_vcombine_VV(sline2, sline1), 0x80808080); + mul_output = Q6_Wh_vsub_WhWh(prod1, prod2); - mul_output = Q6_W_vshuff_VVR(Q6_V_hi_W(mul_output), Q6_V_lo_W(mul_output), -2); - - // sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, Q6_Vhf_equals_Vh(Q6_V_lo_W(mul_output))); - // sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, Q6_Vhf_equals_Vh(Q6_V_hi_W(mul_output))); + mul_output = Q6_W_vshuff_VVR(Q6_V_hi_W(mul_output), Q6_V_lo_W(mul_output), -2); + // sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, Q6_Vhf_equals_Vh(Q6_V_lo_W(mul_output))); + // sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, Q6_Vhf_equals_Vh(Q6_V_hi_W(mul_output))); } { - // scaling quantize - sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, o_scale_vec); - sline_low = Q6_Vh_equals_Vhf(Q6_Vhf_equals_Vqf16(sline_low)); - sline_low = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vmpy_VhVh_s1_rnd_sat(Q6_V_lo_W(mul_output), sline_low), vadj); - - sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, o_scale_vec); - sline_high = Q6_Vh_equals_Vhf(Q6_Vhf_equals_Vqf16(sline_high)); - sline_high = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vmpy_VhVh_s1_rnd_sat(sline_high, Q6_V_hi_W(mul_output)), vadj); - - HVX_Vector sout = Q6_Vub_vasr_VhVhR_rnd_sat( sline_high, sline_low, rsh); - sout = Q6_Vb_vdeal_Vb(sout); - *optr++ = sout; + // scaling quantize + sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, o_scale_vec); + sline_low = Q6_Vh_equals_Vhf(Q6_Vhf_equals_Vqf16(sline_low)); + sline_low = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vmpy_VhVh_s1_rnd_sat(Q6_V_lo_W(mul_output), sline_low), vadj); + + sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, o_scale_vec); + sline_high = Q6_Vh_equals_Vhf(Q6_Vhf_equals_Vqf16(sline_high)); + sline_high = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vmpy_VhVh_s1_rnd_sat(sline_high, Q6_V_hi_W(mul_output)), vadj); + + HVX_Vector sout = Q6_Vub_vasr_VhVhR_rnd_sat(sline_high, sline_low, rsh); + sout = Q6_Vb_vdeal_Vb(sout); + *optr++ = sout; } // { @@ -589,7 +711,6 @@ int32_t hvx_supersilu_ahf( // sout1_low = Q6_V_lo_W(sout1_pair); // sout1_high = Q6_V_hi_W(sout1_pair); - // // { // // HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); // // exp = Q6_V_vand_VV(exp, expmask); @@ -629,7 +750,6 @@ int32_t hvx_supersilu_ahf( // sout1_high = Q6_Vw_equals_Vsf(sout1_high); // sout1_high = Q6_Vw_vasr_VwR(sout1_high, ROUND_2_SCALE); - // HVX_Vector sout2 = Q6_Vqf16_vmpy_Vqf16Vhf(sline_high, o_scale_vec); // sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); // sout2 = Q6_Vhf_equals_Vqf16(sout2); @@ -689,321 +809,308 @@ int32_t hvx_supersilu_ahf( // *optr++ = Q6_Vb_vadd_VbVb(req_b, uintconvert); // } - - - - sline1p = sline1c; sline2p = sline2c; } } if (vectors_in_rounddown > 0) { + o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(o_scale)); - o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(o_scale)); - - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) input2); - - - HVX_Vector sline1_high; - HVX_Vector sline1_low; - HVX_Vector sline2_high; - HVX_Vector sline2_low; - - { - // dequantize sline1 qf16 - HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - - temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); - HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); - HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - - sline1_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), a_scale_vec); - sline1_low = Q6_Vhf_equals_Vqf16(sline1_low); - sline1_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), a_scale_vec); - sline1_high = Q6_Vhf_equals_Vqf16(sline1_high); - } - - - { - // dequantize sline2 qf16 - HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline2, zero_v_sf); - - temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); - HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); - HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - - sline2_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), b_scale_vec); - sline2_low = Q6_Vhf_equals_Vqf16(sline2_low); - sline2_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), b_scale_vec); - sline2_high = Q6_Vhf_equals_Vqf16(sline2_high); - } - - { - // silu sline1_low - tmp_v = Q6_Vh_vdeal_Vh(sline1_low); - - /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ - input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); - - /* - * Scale shifted input range from [0, input_max - input_min] to [0,16.0) - * in order to get corresponding coefficient indexes - */ - input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); - - /* - * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) - * to [16.0,32.0) in order to convert float indexes to integer values. - * Float values, represented in IEEE 754, in range [16.0,32.0] have the - * same exponent, which means 4 MSB of mantissa carry information about - * integer index. - * Use the same input_scaled_v vector for hf and qf16 representation - */ - input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); - - /* Convert back from qf16 to hf in order to extract integer index */ - tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); - - /* Only 4 MSB bits of mantissa represent segment index */ - idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); - - /* Ensure only 4 MSB bits of mantissa are used as indexes */ - idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); - - idx1_v = Q6_Vb_vshuff_Vb(idx1_v); - idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); - idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); - - /* Obtain the polynomial coefficients from lookup table */ - c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); - c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); - - /* Convert input from hf vector to qf32 vector pair for Horner's method*/ - input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_low, one_v_hf); - - /* Perform evaluation of polynomial using Horner's method */ - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); - - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); - - // x * sigmod - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); - - sline1_low = Q6_Vhf_equals_Wqf32(output_dv.VV); - } - - - { - // silu sline1_high - tmp_v = Q6_Vh_vdeal_Vh(sline1_high); - - /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ - input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); - - /* - * Scale shifted input range from [0, input_max - input_min] to [0,16.0) - * in order to get corresponding coefficient indexes - */ - input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); - - /* - * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) - * to [16.0,32.0) in order to convert float indexes to integer values. - * Float values, represented in IEEE 754, in range [16.0,32.0] have the - * same exponent, which means 4 MSB of mantissa carry information about - * integer index. - * Use the same input_scaled_v vector for hf and qf16 representation - */ - input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); - - /* Convert back from qf16 to hf in order to extract integer index */ - tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); - - /* Only 4 MSB bits of mantissa represent segment index */ - idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); - - /* Ensure only 4 MSB bits of mantissa are used as indexes */ - idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); - - idx1_v = Q6_Vb_vshuff_Vb(idx1_v); - idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); - idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); - - /* Obtain the polynomial coefficients from lookup table */ - c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); - c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); - c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); - c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); - c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); - c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); - - /* Convert input from hf vector to qf32 vector pair for Horner's method*/ - input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_high, one_v_hf); - - /* Perform evaluation of polynomial using Horner's method */ - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); - output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); - - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); - output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); - - // x * sigmod - output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); - output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); - - sline1_high = Q6_Vhf_equals_Wqf32(output_dv.VV); - } - - - HVX_Vector sline_high; - HVX_Vector sline_low; - - { - // mul - sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, sline2_high); - sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, sline2_low); - - sline_high = Q6_Vhf_equals_Vqf16(sline_high); - sline_low = Q6_Vhf_equals_Vqf16(sline_low); - } - - - { - // quantize - HVX_Vector sout1 = Q6_Vqf16_vmpy_VhfVhf(sline_low, o_scale_vec); - sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); - sout1 = Q6_Vhf_equals_Vqf16(sout1); - sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); - sout1 = Q6_Vhf_vmax_VhfVhf(sout1, low_level_vec); + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input2); + + HVX_Vector sline1_high; + HVX_Vector sline1_low; + HVX_Vector sline2_high; + HVX_Vector sline2_low; { - HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); - exp = Q6_V_vand_VV(exp, expmask); - exp = Q6_Vh_vsub_VhVh(exp, expbias); + // dequantize sline1 qf16 + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); - HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - HVX_Vector sign = Q6_Vh_vasr_VhR(sout1, FP16_SIGN); - HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + sline1_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), a_scale_vec); + sline1_low = Q6_Vhf_equals_Vqf16(sline1_low); + sline1_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), a_scale_vec); + sline1_high = Q6_Vhf_equals_Vqf16(sline1_high); + } - HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); - HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); - HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + { + // dequantize sline2 qf16 + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline2, zero_v_sf); - HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout1, man); - man = Q6_V_vnot_V(man); - HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); - exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); - HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout1, 1); - HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); - // exp >= 0 - HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); - tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + sline2_low = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), b_scale_vec); + sline2_low = Q6_Vhf_equals_Vqf16(sline2_low); + sline2_high = Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), b_scale_vec); + sline2_high = Q6_Vhf_equals_Vqf16(sline2_high); + } - // exp < 0 (-1, 1) - HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); - tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + { + // silu sline1_low + tmp_v = Q6_Vh_vdeal_Vh(sline1_low); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. + * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. + * Use the same input_scaled_v vector for hf and qf16 representation + */ + input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); + + /* Convert back from qf16 to hf in order to extract integer index */ + tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); + + /* Ensure only 4 MSB bits of mantissa are used as indexes */ + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + + idx1_v = Q6_Vb_vshuff_Vb(idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); + + /* Convert input from hf vector to qf32 vector pair for Horner's method*/ + input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_low, one_v_hf); + + /* Perform evaluation of polynomial using Horner's method */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); + + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); + + // x * sigmod + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); + + sline1_low = Q6_Vhf_equals_Wqf32(output_dv.VV); + } - tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); - sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + { + // silu sline1_high + tmp_v = Q6_Vh_vdeal_Vh(sline1_high); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. + * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. + * Use the same input_scaled_v vector for hf and qf16 representation + */ + input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); + + /* Convert back from qf16 to hf in order to extract integer index */ + tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); + + /* Ensure only 4 MSB bits of mantissa are used as indexes */ + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + + idx1_v = Q6_Vb_vshuff_Vb(idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c0_coeff_dv.VV), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(c0_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c1_coeff_dv.VV), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(c1_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c2_coeff_dv.VV), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(c2_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c3_coeff_dv.VV), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(c3_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(c4_coeff_dv.VV), 1); + c4_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c4_coeff_vp, idx2_v, Q6_V_hi_W(c4_coeff_dv.VV), 1); + + /* Convert input from hf vector to qf32 vector pair for Horner's method*/ + input_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(sline1_high, one_v_hf); + + /* Perform evaluation of polynomial using Horner's method */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c4_coeff_vp), Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c3_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); + + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c4_coeff_vp), Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c3_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); + + // x * sigmod + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(input_vp_qf32), output_dv.V.lo); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(input_vp_qf32), output_dv.V.hi); + + sline1_high = Q6_Vhf_equals_Wqf32(output_dv.VV); } - sout1 = Q6_Vh_equals_Vhf(sout1); + HVX_Vector sline_high; + HVX_Vector sline_low; + { + // mul + sline_high = Q6_Vqf16_vmpy_VhfVhf(sline1_high, sline2_high); + sline_low = Q6_Vqf16_vmpy_VhfVhf(sline1_low, sline2_low); - HVX_Vector sout2 = Q6_Vqf16_vmpy_VhfVhf(sline_high, o_scale_vec); - sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); - sout2 = Q6_Vhf_equals_Vqf16(sout2); - sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); - sout2 = Q6_Vhf_vmax_VhfVhf(sout2, low_level_vec); + sline_high = Q6_Vhf_equals_Vqf16(sline_high); + sline_low = Q6_Vhf_equals_Vqf16(sline_low); + } { - HVX_Vector exp = Q6_Vh_vasr_VhR(sout2, FP16_MANTISA); - exp = Q6_V_vand_VV(exp, expmask); - exp = Q6_Vh_vsub_VhVh(exp, expbias); + // quantize + HVX_Vector sout1 = Q6_Vqf16_vmpy_VhfVhf(sline_low, o_scale_vec); + sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); + sout1 = Q6_Vhf_equals_Vqf16(sout1); + sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); + sout1 = Q6_Vhf_vmax_VhfVhf(sout1, low_level_vec); - HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); - HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); - HVX_Vector sign = Q6_Vh_vasr_VhR(sout2, FP16_SIGN); - HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout1, man); - HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); - HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); - HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + HVX_Vector sign = Q6_Vh_vasr_VhR(sout1, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); - HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout2, man); - man = Q6_V_vnot_V(man); - HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); - exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); - HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout2, 1); - HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); - // exp >= 0 - HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); - tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout1, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout1, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); - // exp < 0 (-1, 1) - HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); - tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); - tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); - sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); - } + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + } - sout2 = Q6_Vh_equals_Vhf(sout2); + sout1 = Q6_Vh_equals_Vhf(sout1); - HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); - *optr++ = Q6_Vb_vadd_VbVb(reql_h, uintconvert); + HVX_Vector sout2 = Q6_Vqf16_vmpy_VhfVhf(sline_high, o_scale_vec); + sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); + sout2 = Q6_Vhf_equals_Vqf16(sout2); + sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); + sout2 = Q6_Vhf_vmax_VhfVhf(sout2, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout2, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); - } + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + HVX_Vector sign = Q6_Vh_vasr_VhR(sout2, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout2, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout2, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + } + + sout2 = Q6_Vh_equals_Vhf(sout2); + + HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); + *optr++ = Q6_Vb_vadd_VbVb(reql_h, uintconvert); + } } // // Handle leftover elements. @@ -1013,7 +1120,6 @@ int32_t hvx_supersilu_ahf( // : *iptr++); // sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - // sline2c = (is_in_one_chunk(iptr2, leftover_size, VLEN) // ? sline2p // : *iptr2++); @@ -1025,146 +1131,131 @@ int32_t hvx_supersilu_ahf( return 0; } - -template -GraphStatus llamasupersiluImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const PlainFloatTensor& a_scale, - const PlainFloatTensor& b_scale, - const PlainFloatTensor& o_scale) +template +GraphStatus llamasupersiluImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const PlainFloatTensor &a_scale, + const PlainFloatTensor &b_scale, + const PlainFloatTensor &o_scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - size_t size = b_in*h_in*w_in*d_in; - - - float a_scale_ = a_scale(0,0,0,0); - float b_scale_ = b_scale(0,0,0,0); - float o_scale_ = o_scale(0,0,0,0); - - auto in_ptr = (uint8_t*)in_0.raw_data_const(); - auto in_ptr2 = (uint8_t*)in_1.raw_data_const(); - - auto out_ptr = (uint8_t*)out_0.raw_data(); + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + size_t size = b_in * h_in * w_in * d_in; + float a_scale_ = a_scale(0, 0, 0, 0); + float b_scale_ = b_scale(0, 0, 0, 0); + float o_scale_ = o_scale(0, 0, 0, 0); + + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto in_ptr2 = (uint8_t *)in_1.raw_data_const(); + + auto out_ptr = (uint8_t *)out_0.raw_data(); DType dtype = in_0.get_dtype(); - if (dtype == DType::QUInt8 && out_0.get_dtype() == DType::QUInt8) { - hvx_supersilu_ahf(in_ptr, in_ptr2, out_ptr, a_scale_, b_scale_, 1.0f/o_scale_, size); - } + if (dtype == DType::QUInt8 && out_0.get_dtype() == DType::QUInt8) { + hvx_supersilu_ahf(in_ptr, in_ptr2, out_ptr, a_scale_, b_scale_, 1.0f / o_scale_, size); + } - return GraphStatus::Success; + return GraphStatus::Success; } #else -template -GraphStatus llamasupersiluImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const PlainFloatTensor& a_scale, - const PlainFloatTensor& b_scale, - const PlainFloatTensor& o_scale) +template +GraphStatus llamasupersiluImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const PlainFloatTensor &a_scale, + const PlainFloatTensor &b_scale, + const PlainFloatTensor &o_scale) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - out_0.set_dims(in_0); + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ - float a_scale_ = a_scale(0,0,0,0); - float b_scale_ = b_scale(0,0,0,0); - float o_scale_ = o_scale(0,0,0,0); + out_0.set_dims(in_0); - auto in_ptr = (uint8_t*)in_0.raw_data_const(); - auto in_ptr2 = (uint8_t*)in_1.raw_data_const(); + float a_scale_ = a_scale(0, 0, 0, 0); + float b_scale_ = b_scale(0, 0, 0, 0); + float o_scale_ = o_scale(0, 0, 0, 0); - auto out_ptr = (uint8_t*)out_0.raw_data(); + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto in_ptr2 = (uint8_t *)in_1.raw_data_const(); + auto out_ptr = (uint8_t *)out_0.raw_data(); - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // mul - for (Idx d = 0; d < d_in; d++) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // mul + for (Idx d = 0; d < d_in; d++) { + int32_t a_inval = static_cast(*in_ptr++); + float a_inval_fp16 = (a_inval - 128) * a_scale_; + int32_t b_inval = static_cast(*in_ptr2++); + float b_inval_fp16 = (b_inval - 128) * b_scale_; - int32_t a_inval = static_cast(*in_ptr++); - float a_inval_fp16 = (a_inval-128) * a_scale_; + a_inval_fp16 = a_inval_fp16 * (1 / (1 + expf(-a_inval_fp16))); + float inval = a_inval_fp16 * b_inval_fp16; - int32_t b_inval = static_cast(*in_ptr2++); - float b_inval_fp16 = (b_inval-128) * b_scale_; + long v = lroundf(inval / o_scale_); - - a_inval_fp16 = a_inval_fp16 * (1 / (1 + expf(-a_inval_fp16))); + if (v > 127) + v = 127; - float inval = a_inval_fp16 * b_inval_fp16; - - long v = lroundf(inval / o_scale_); - - if (v > 127) - v = 127; - - if (v < -128) - v = -128; - - v += 128; + if (v < -128) + v = -128; - *out_ptr++ = static_cast(v); + v += 128; + *out_ptr++ = static_cast(v); + } + } } - } } - } - - return GraphStatus::Success; + return GraphStatus::Success; } #endif -__attribute__((unused)) static float llamasupersiluCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float llamasupersiluCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/MergeOutput.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/MergeOutput.cpp index 6c573ab2d..e001c4e3e 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/MergeOutput.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/MergeOutput.cpp @@ -9,18 +9,16 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_MergeOutput); - // op execute function declarations -template -GraphStatus mergeoutputImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const TensorType& in_2, - const TensorType& in_3, - const Tensor& num); +template +GraphStatus mergeoutputImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const TensorType &in_2, + const TensorType &in_3, + const Tensor &num); // forward declaration of sample cost function static float mergeoutputCostFunc(const Op *op); @@ -65,11 +63,11 @@ DEF_PACKAGE_OP((mergeoutputImpl), "MergeOutput") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -83,90 +81,81 @@ DEF_PACKAGE_OP((mergeoutputImpl), "MergeOutput") * Qnn_addNode */ - /* execute functions for ops */ -template -GraphStatus mergeoutputImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& in_1, - const TensorType& in_2, - const TensorType& in_3, - const Tensor& num) +template +GraphStatus mergeoutputImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &in_1, + const TensorType &in_2, + const TensorType &in_3, + const Tensor &num) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - auto [b_in_0, h_in_0, w_in_0, d_in_0] = in_0.dims(); - auto [b_in_1, h_in_1, w_in_1, d_in_1] = in_1.dims(); - auto [b_in_2, h_in_2, w_in_2, d_in_2] = in_2.dims(); - auto [b_in_3, h_in_3, w_in_3, d_in_3] = in_3.dims(); - - const size_t dims[] = {b_in_0, h_in_0 + h_in_1 + h_in_2 + h_in_3 * 4, w_in_0, d_in_0}; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ - out_0.set_dims(dims); + auto [b_in_0, h_in_0, w_in_0, d_in_0] = in_0.dims(); + auto [b_in_1, h_in_1, w_in_1, d_in_1] = in_1.dims(); + auto [b_in_2, h_in_2, w_in_2, d_in_2] = in_2.dims(); + auto [b_in_3, h_in_3, w_in_3, d_in_3] = in_3.dims(); - DType dtype = in_0.get_dtype(); - uint32_t bitwidth = 4; + const size_t dims[] = {b_in_0, h_in_0 + h_in_1 + h_in_2 + h_in_3 * 4, w_in_0, d_in_0}; - if (dtype == DType::QUInt8 || dtype == DType::QInt8) { + out_0.set_dims(dims); - bitwidth = 1; + DType dtype = in_0.get_dtype(); + uint32_t bitwidth = 4; - } else if (dtype == DType::Float16) { + if (dtype == DType::QUInt8 || dtype == DType::QInt8) { + bitwidth = 1; - bitwidth = 2; - } else if (dtype == DType::Float32) { + } else if (dtype == DType::Float16) { + bitwidth = 2; + } else if (dtype == DType::Float32) { + bitwidth = 4; + } - bitwidth = 4; - } + const uint8_t *in_ptr_0 = (uint8_t *)in_0.raw_data_const(); + const uint8_t *in_ptr_1 = (uint8_t *)in_1.raw_data_const(); + const uint8_t *in_ptr_2 = (uint8_t *)in_2.raw_data_const(); + // const uint8_t *in_ptr_3 = (uint8_t*)in_3.raw_data_const(); - const uint8_t *in_ptr_0 = (uint8_t*)in_0.raw_data_const(); - const uint8_t *in_ptr_1 = (uint8_t*)in_1.raw_data_const(); - const uint8_t *in_ptr_2 = (uint8_t*)in_2.raw_data_const(); -// const uint8_t *in_ptr_3 = (uint8_t*)in_3.raw_data_const(); - - uint8_t *out_ptr = (uint8_t*)out_0.raw_data(); + uint8_t *out_ptr = (uint8_t *)out_0.raw_data(); - memcpy(out_ptr, in_ptr_0, b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth); - out_ptr += b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth; + memcpy(out_ptr, in_ptr_0, b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth); + out_ptr += b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth; - memcpy(out_ptr, in_ptr_1, b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth); - out_ptr += b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth; + memcpy(out_ptr, in_ptr_1, b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth); + out_ptr += b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth; - memcpy(out_ptr, in_ptr_2, b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth); - out_ptr += b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth; + memcpy(out_ptr, in_ptr_2, b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth); + out_ptr += b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth; -// memcpy(out_ptr, in_ptr_3, b_in_3 * h_in_3 * w_in_3 * d_in_3 * bitwidth * 4); + // memcpy(out_ptr, in_ptr_3, b_in_3 * h_in_3 * w_in_3 * d_in_3 * bitwidth * 4); - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float mergeoutputCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float mergeoutputCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/QLayerNorm.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/QLayerNorm.cpp index 0bb733e4f..be61c7286 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/QLayerNorm.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/QLayerNorm.cpp @@ -9,16 +9,14 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_QLayerNorm); - // op execute function declarations -template -GraphStatus qlayernormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights, - const TensorType& bias); +template +GraphStatus qlayernormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights, + const TensorType &bias); // forward declaration of sample cost function static float qlayernormCostFunc(const Op *op); @@ -63,11 +61,11 @@ DEF_PACKAGE_OP((qlayernormImpl), "QLayerNorm") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -81,7 +79,6 @@ DEF_PACKAGE_OP((qlayernormImpl), "QLayerNorm") * Qnn_addNode */ - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -90,18 +87,16 @@ DEF_PACKAGE_OP((qlayernormImpl), "QLayerNorm") #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) int32_t hvx_qlayernorm_af( float *restrict input, float *restrict weights, float *restrict bias, float *restrict output, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -125,58 +120,47 @@ int32_t hvx_qlayernorm_af( // sline1p = *iptr++; - // x sum HVX_Vector xsum = Q6_Vqf32_vadd_VsfVsf(Q6_V_vzero(), Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - xsum = Q6_Vqf32_vadd_Vqf32Vqf32(xsum, sline1); - + xsum = Q6_Vqf32_vadd_Vqf32Vqf32(xsum, sline1); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - xsum = Q6_Vqf32_vadd_Vqf32Vqf32(xsum, sline1); - + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + xsum = Q6_Vqf32_vadd_Vqf32Vqf32(xsum, sline1); } union { - float f; - uint32_t ui; + float f; + uint32_t ui; } mean_value; mean_value.f = 0.0f; - - - for (int32_t i = 64; i >= 4; i >>= 1) - { + for (int32_t i = 64; i >= 4; i >>= 1) { xsum = Q6_Vqf32_vadd_Vqf32Vqf32(xsum, Q6_V_vlalign_VVR(xsum, zero, i)); } xsum = Q6_Vsf_equals_Vqf32(xsum); - *(HVX_Vector *) tmp_buf = xsum; + *(HVX_Vector *)tmp_buf = xsum; mean_value.f = xsum[31] / size; - // x-e^2 sum iptr = (HVX_Vector *)input; sline1p = *iptr++; @@ -185,57 +169,49 @@ int32_t hvx_qlayernorm_af( HVX_Vector mean_vsf = Q6_V_vsplat_R(mean_value.ui); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); sline1 = Q6_Vqf32_vsub_Vqf32Vqf32(sline1, mean_vsf); - x2sum = Q6_Vqf32_vadd_Vqf32Vqf32(x2sum, Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline1)); - + x2sum = Q6_Vqf32_vadd_Vqf32Vqf32(x2sum, Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline1)); + sline1p = sline1c; } } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline1 = Q6_Vqf32_vsub_Vqf32Vqf32(sline1, mean_vsf); - x2sum = Q6_Vqf32_vadd_Vqf32Vqf32(x2sum, Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline1)); + sline1 = Q6_Vqf32_vsub_Vqf32Vqf32(sline1, mean_vsf); + x2sum = Q6_Vqf32_vadd_Vqf32Vqf32(x2sum, Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline1)); } float epsilon_ = 1e-5; union { - float f; - uint32_t ui; + float f; + uint32_t ui; } sum_value; sum_value.f = 0.0f; - - - for (int32_t i = 64; i >= 4; i >>= 1) - { + for (int32_t i = 64; i >= 4; i >>= 1) { x2sum = Q6_Vqf32_vadd_Vqf32Vqf32(x2sum, Q6_V_vlalign_VVR(x2sum, zero, i)); } x2sum = Q6_Vsf_equals_Vqf32(x2sum); - *(HVX_Vector *) tmp_buf = x2sum; + *(HVX_Vector *)tmp_buf = x2sum; sum_value.f = 1.0f / sqrtf(x2sum[31] / size + epsilon_); - // x * 1/rsqrt(sum) iptr = (HVX_Vector *)input; sline1p = *iptr++; @@ -245,20 +221,17 @@ int32_t hvx_qlayernorm_af( HVX_Vector irsqrt_vsf = Q6_V_vsplat_R(sum_value.ui); HVX_Vector irsqrt_vqf32 = Q6_Vqf32_vadd_VsfVsf(irsqrt_vsf, Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr3 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline3c = *iptr3++; @@ -281,107 +254,97 @@ int32_t hvx_qlayernorm_af( } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) weights); + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)weights); - sline3c = is_aligned(iptr3, VLEN) && leftover == 0 ? sline3p : *iptr3++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) weights); + sline3c = is_aligned(iptr3, VLEN) && leftover == 0 ? sline3p : *iptr3++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)weights); - sline1 = Q6_Vqf32_vsub_VsfVsf(sline1, mean_vsf); + sline1 = Q6_Vqf32_vsub_VsfVsf(sline1, mean_vsf); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline2); - middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - middle_value_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(middle_value_qf32, sline3); - - *optr++ = Q6_Vsf_equals_Vqf32(middle_value_qf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, sline2); + middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + middle_value_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(middle_value_qf32, sline3); + *optr++ = Q6_Vsf_equals_Vqf32(middle_value_qf32); } - if (leftover_size > 0) - return -1; + return -1; return 0; } -template -GraphStatus qlayernormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights, - const TensorType& bias) +template +GraphStatus qlayernormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights, + const TensorType &bias) { - out_0.set_dims(in_0); - - // NHWC - - auto in_ptr = (float*)in_0.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - auto weights_ptr = (float*)weights.raw_data_const(); - auto bias_ptr = (float*)bias.raw_data_const(); - - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // RMS - hvx_qlayernorm_af(in_ptr, weights_ptr, bias_ptr, out_ptr, d_in); - - in_ptr += d_in; - out_ptr += d_in; - } + out_0.set_dims(in_0); + + // NHWC + + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + auto weights_ptr = (float *)weights.raw_data_const(); + auto bias_ptr = (float *)bias.raw_data_const(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // RMS + hvx_qlayernorm_af(in_ptr, weights_ptr, bias_ptr, out_ptr, d_in); + + in_ptr += d_in; + out_ptr += d_in; + } + } } - } - return GraphStatus::Success; + return GraphStatus::Success; } #else -template -GraphStatus qlayernormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights, - const TensorType& bias) +template +GraphStatus qlayernormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights, + const TensorType &bias) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - return GraphStatus::Success; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + return GraphStatus::Success; } #endif -__attribute__((unused)) static float qlayernormCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float qlayernormCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RMSNorm.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RMSNorm.cpp index bd079a2c9..1197eab19 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RMSNorm.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RMSNorm.cpp @@ -9,15 +9,13 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_RMSNorm); - // op execute function declarations -template -GraphStatus rmsnormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights); +template +GraphStatus rmsnormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights); // forward declaration of sample cost function static float rmsnormCostFunc(const Op *op); @@ -62,11 +60,11 @@ DEF_PACKAGE_OP((rmsnormImpl), "RMSNorm") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -80,7 +78,6 @@ DEF_PACKAGE_OP((rmsnormImpl), "RMSNorm") * Qnn_addNode */ - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -90,17 +87,15 @@ DEF_PACKAGE_OP((rmsnormImpl), "RMSNorm") #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) int32_t hvx_rmsnorm_af( float *restrict input, float *restrict weights, float *restrict output, - uint32_t size) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t size) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -117,56 +112,47 @@ int32_t hvx_rmsnorm_af( sline1p = *iptr++; - // ^2 sum HVX_Vector sum = Q6_Vqf32_vadd_VsfVsf(Q6_V_vzero(), Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); } float epsilon_ = 1e-6; union { - float f; - uint32_t ui; + float f; + uint32_t ui; } sum_value; sum_value.f = 0.0f; - HVX_Vector zero = Q6_V_vzero(); - for (int32_t i = 64; i >= 4; i >>= 1) - { + for (int32_t i = 64; i >= 4; i >>= 1) { sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_V_vlalign_VVR(sum, zero, i)); } sum = Q6_Vsf_equals_Vqf32(sum); - sum_value.f = 1.0f / sqrtf(*((float*)&sum + 31) / size + epsilon_); + sum_value.f = 1.0f / sqrtf(*((float *)&sum + 31) / size + epsilon_); // x * 1/rsqrt(sum) iptr = (HVX_Vector *)input; @@ -176,19 +162,16 @@ int32_t hvx_rmsnorm_af( HVX_Vector irsqrt_vsf = Q6_V_vsplat_R(sum_value.ui); HVX_Vector irsqrt_vqf32 = Q6_Vqf32_vadd_VsfVsf(irsqrt_vsf, Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline2c = *iptr2++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); @@ -203,33 +186,31 @@ int32_t hvx_rmsnorm_af( } if (vectors_in_rounddown > 0) { + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - - sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) weights); - - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, sline2); - *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32)); + sline2c = is_aligned(iptr2, VLEN) && leftover == 0 ? sline2p : *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)weights); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, sline2); + *optr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32)); } - if (leftover_size > 0) - return -1; + return -1; return 0; } -static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) -{ - union { float f; uint32_t i; } fp32 = { .f = x }; +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; return fp32.i; } -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -237,7 +218,6 @@ static inline int32_t float_to_fp16s(float input) return fp32.i; } - #define FLOAT_MANTISA 23 #define FLOAT_EXPONENT_MASK 0xff #define FLOAT_EXPONENT_BIAS 0x7f @@ -252,10 +232,8 @@ int32_t hvx_rmsnorm_auint8( float *restrict weights, uint8_t *restrict output, uint32_t size, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -274,7 +252,7 @@ int32_t hvx_rmsnorm_auint8( float low_level = -128.0f; float high_level = 127.0f; - float es = 0.5f; + float es = 0.5f; low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); @@ -287,7 +265,6 @@ int32_t hvx_rmsnorm_auint8( HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); - // HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); // HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); // HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); @@ -303,53 +280,45 @@ int32_t hvx_rmsnorm_auint8( sline1p = *iptr++; - // ^2 sum HVX_Vector sum = Q6_Vqf32_vadd_VsfVsf(Q6_V_vzero(), Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); } float epsilon_ = 1e-6; union { - float f; - uint32_t ui; + float f; + uint32_t ui; } sum_value; sum_value.f = 0.0f; - for (int32_t i = 64; i >= 4; i >>= 1) - { + for (int32_t i = 64; i >= 4; i >>= 1) { sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_V_vlalign_VVR(sum, zero, i)); } sum = Q6_Vsf_equals_Vqf32(sum); - sum_value.f = 1.0f / sqrtf(*((float*)&sum + 31) / size + epsilon_); + sum_value.f = 1.0f / sqrtf(*((float *)&sum + 31) / size + epsilon_); // x * 1/rsqrt(sum) iptr = (HVX_Vector *)input; @@ -361,37 +330,32 @@ int32_t hvx_rmsnorm_auint8( slinewp = *iptr2++; - HVX_Vector irsqrt_vsf = Q6_V_vsplat_R(sum_value.ui); HVX_Vector irsqrt_vqf32 = Q6_Vqf32_vadd_VsfVsf(irsqrt_vsf, Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { - + for (int32_t j = 0; j < block; j += 4) { { - sline1c = *iptr++; - slinewc = *iptr2++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline1c = *iptr++; + slinewc = *iptr2++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, slinew); - sline1 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, slinew); + sline1 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - - sout1 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline1,scale_vec); + + sout1 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline1, scale_vec); sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); sout1 = Q6_Vsf_equals_Vqf32(sout1); sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); @@ -438,19 +402,18 @@ int32_t hvx_rmsnorm_auint8( // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, Q6_V_vzero()), 0); { - sline2c = *iptr++; - slinewc = *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline2c = *iptr++; + slinewc = *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline2, slinew); - sline2 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline2, slinew); + sline2 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - - sout2 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline2,scale_vec); + sout2 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline2, scale_vec); sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); sout2 = Q6_Vsf_equals_Vqf32(sout2); sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); @@ -497,19 +460,18 @@ int32_t hvx_rmsnorm_auint8( // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, Q6_V_vzero()), 0); { - sline3c = *iptr++; - slinewc = *iptr2++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline3c = *iptr++; + slinewc = *iptr2++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline3, slinew); - sline3 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline3, slinew); + sline3 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - - sout3 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline3,scale_vec); + sout3 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline3, scale_vec); sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); sout3 = Q6_Vsf_equals_Vqf32(sout3); sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); @@ -551,25 +513,23 @@ int32_t hvx_rmsnorm_auint8( // sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); // } - sout3 = Q6_Vw_equals_Vsf(sout3); sout3 = Q6_Vw_vasr_VwR(sout3, ROUND_2_SCALE); // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, Q6_V_vzero()), 0); { - sline4c = *iptr++; - slinewc = *iptr2++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline4c = *iptr++; + slinewc = *iptr2++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline4, slinew); - sline4 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline4, slinew); + sline4 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - - sout4 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline4,scale_vec); + sout4 = Q6_Vqf32_vmpy_Vqf32Vqf32(sline4, scale_vec); sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); sout4 = Q6_Vsf_equals_Vqf32(sout4); sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); @@ -615,7 +575,6 @@ int32_t hvx_rmsnorm_auint8( sout4 = Q6_Vw_vasr_VwR(sout4, ROUND_2_SCALE); // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, Q6_V_vzero()), 0); - HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); HVX_Vector req_b = Q6_Vb_vpack_VhVh_sat(reqh_h, reql_h); @@ -627,9 +586,7 @@ int32_t hvx_rmsnorm_auint8( sline3p = sline3c; sline4p = sline4c; - slinewp = slinewc; - } } @@ -641,10 +598,8 @@ int32_t hvx_rmsnorm_auint8_opt( float *restrict weights, uint8_t *restrict output, uint32_t size, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -663,7 +618,7 @@ int32_t hvx_rmsnorm_auint8_opt( // float low_level = -128.0f; // float high_level = 127.0f; - // float es = 0.5f; + // float es = 0.5f; // low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); // high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); // scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); @@ -676,7 +631,6 @@ int32_t hvx_rmsnorm_auint8_opt( // HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); - // HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); // HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); // HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); @@ -692,53 +646,45 @@ int32_t hvx_rmsnorm_auint8_opt( sline1p = *iptr++; - // ^2 sum HVX_Vector sum = Q6_Vqf32_vadd_VsfVsf(Q6_V_vzero(), Q6_V_vzero()); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { sline1c = *iptr++; sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); sline1p = sline1c; } } if (vectors_in_rounddown > 0) { - - sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); - + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, sline1)); } float epsilon_ = 1e-6; union { - float f; - uint32_t ui; + float f; + uint32_t ui; } sum_value; sum_value.f = 0.0f; - for (int32_t i = 64; i >= 4; i >>= 1) - { + for (int32_t i = 64; i >= 4; i >>= 1) { sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_V_vlalign_VVR(sum, zero, i)); } sum = Q6_Vsf_equals_Vqf32(sum); - sum_value.f = 1.0f / sqrtf(*((float*)&sum + 31) / size + epsilon_); + sum_value.f = 1.0f / sqrtf(*((float *)&sum + 31) / size + epsilon_); // x * 1/rsqrt(sum) iptr = (HVX_Vector *)input; @@ -750,66 +696,58 @@ int32_t hvx_rmsnorm_auint8_opt( slinewp = *iptr2++; - HVX_Vector irsqrt_vsf = Q6_V_vsplat_R(sum_value.ui); HVX_Vector irsqrt_vqf32 = Q6_Vqf32_vadd_VsfVsf(irsqrt_vsf, Q6_V_vzero()); - float post_scale_flt = scale / 64.0f; - int scexp = flt_getexp( post_scale_flt); - int rsh = min_i32( -scexp,7); // e.g. 0.11 -> 0.88, rsh = 3 + int scexp = flt_getexp(post_scale_flt); + int rsh = min_i32(-scexp, 7); // e.g. 0.11 -> 0.88, rsh = 3 float rsh_fac = flt_power2(rsh); int adj_bias = roundf_i32(128 * rsh_fac); - adj_bias = Q6_R_combine_RlRl( adj_bias, adj_bias); - + adj_bias = Q6_R_combine_RlRl(adj_bias, adj_bias); HVX_Vector zero_v_sf = Q6_V_vzero(); - float es = 0.5f; + float es = 0.5f; HVX_Vector es_vec = Q6_V_vsplat_R(float_to_fp16s(es)); es_vec = Q6_Vqf16_vadd_VhfVhf(es_vec, zero_v_sf); HVX_Vector vadj = Q6_V_vsplat_R(adj_bias); - HVX_Vector o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(post_scale_flt * rsh_fac * (1<<15))); + HVX_Vector o_scale_vec = Q6_V_vsplat_R(float_to_fp16s(post_scale_flt * rsh_fac * (1 << 15))); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t j = 0; j < block; j+=4) - { - + for (int32_t j = 0; j < block; j += 4) { { - sline1c = *iptr++; - slinewc = *iptr2++; - sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline1c = *iptr++; + slinewc = *iptr2++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, slinew); - sline1 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, slinew); + sline1 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } { - sline2c = *iptr++; - slinewc = *iptr2++; - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline2c = *iptr++; + slinewc = *iptr2++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline2, slinew); - sline2 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline2, slinew); + sline2 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - HVX_Vector sline_low = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(sline2, sline1)); sline_low = Q6_Vqf16_vadd_Vqf16Vqf16(sline_low, es_vec); @@ -820,29 +758,28 @@ int32_t hvx_rmsnorm_auint8_opt( sline_low = Q6_Vh_vdeal_Vh(sline_low); { - sline3c = *iptr++; - slinewc = *iptr2++; - sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t) input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline3c = *iptr++; + slinewc = *iptr2++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline3, slinew); - sline3 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline3, slinew); + sline3 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } { - sline4c = *iptr++; - slinewc = *iptr2++; - sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t) input); - slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); + sline4c = *iptr++; + slinewc = *iptr2++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + slinew = Q6_V_valign_VVR(slinewc, slinewp, (size_t)weights); - HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline4, slinew); - sline4 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); + HVX_Vector middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline4, slinew); + sline4 = Q6_Vqf32_vmpy_Vqf32Vqf32(middle_value_qf32, irsqrt_vqf32); - slinewp = slinewc; + slinewp = slinewc; } - HVX_Vector sline_high = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(sline4, sline3)); sline_high = Q6_Vqf16_vadd_Vqf16Vqf16(sline_high, es_vec); @@ -852,7 +789,7 @@ int32_t hvx_rmsnorm_auint8_opt( sline_high = Q6_Vh_vdeal_Vh(sline_high); - HVX_Vector sout = Q6_Vub_vasr_VhVhR_rnd_sat( sline_high, sline_low, rsh); + HVX_Vector sout = Q6_Vub_vasr_VhVhR_rnd_sat(sline_high, sline_low, rsh); sout = Q6_Vb_vdeal_Vb(sout); *optr++ = sout; @@ -861,149 +798,132 @@ int32_t hvx_rmsnorm_auint8_opt( sline3p = sline3c; sline4p = sline4c; - slinewp = slinewc; - } } return 0; } -template -GraphStatus rmsnormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights) +template +GraphStatus rmsnormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights) { - out_0.set_dims(in_0); - - // NHWC + out_0.set_dims(in_0); - auto in_ptr = (float*)in_0.raw_data_const(); - auto weights_ptr = (float*)weights.raw_data_const(); + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto weights_ptr = (float *)weights.raw_data_const(); - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - DType dtype = out_0.get_dtype(); + DType dtype = out_0.get_dtype(); - if (dtype == DType::Float32) { + if (dtype == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); - auto out_ptr = (float*)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // RMS + hvx_rmsnorm_af(in_ptr, weights_ptr, out_ptr, d_in); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // RMS - hvx_rmsnorm_af(in_ptr, weights_ptr, out_ptr, d_in); - - in_ptr += d_in; - out_ptr += d_in; + in_ptr += d_in; + out_ptr += d_in; + } + } } - } - } - } else if (dtype == DType::QUInt8) { + } else if (dtype == DType::QUInt8) { + auto out_ptr = (uint8_t *)out_0.raw_data(); + float scale_ = out_0.get_interface_scale(); - auto out_ptr = (uint8_t*)out_0.raw_data(); - float scale_ = out_0.get_interface_scale(); + scale_ = 1.0f / scale_; - scale_ = 1.0f/scale_; + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // RMS + hvx_rmsnorm_auint8(in_ptr, weights_ptr, out_ptr, d_in, scale_); - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // RMS - hvx_rmsnorm_auint8(in_ptr, weights_ptr, out_ptr, d_in, scale_); - - in_ptr += d_in; - out_ptr += d_in; + in_ptr += d_in; + out_ptr += d_in; + } + } } - } } - } - - return GraphStatus::Success; + return GraphStatus::Success; } #else -template -GraphStatus rmsnormImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& weights) +template +GraphStatus rmsnormImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &weights) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - out_0.set_dims(in_0); + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + out_0.set_dims(in_0); // NHWC float epsilon_ = 1e-6; auto [b_in, h_in, w_in, d_in] = in_0.dims(); for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // RMS - float sum_squares = 0.0f; - for (Idx d = 0; d < d_in; d++) { - float inval = in_0(b, h, w, d); - sum_squares += inval*inval; - } - - // debuglog("silu execute... sum_squares=(%f)", sum_squares); - - float rms = sqrtf(sum_squares / d_in + epsilon_); - debuglog("rms execute... sum_squares=(%f)", 1.0f / rms); - debuglog("rms execute... sum_squares=(%f)", sum_squares); - - for (Idx d = 0; d < d_in; d++) { - float inval = in_0(b, h, w, d); - float weight = weights(0, 0, 0, d); - - out_0(b, h, w, d) = inval * weight / rms; - - } - + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // RMS + float sum_squares = 0.0f; + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + sum_squares += inval * inval; + } + + // debuglog("silu execute... sum_squares=(%f)", sum_squares); + + float rms = sqrtf(sum_squares / d_in + epsilon_); + debuglog("rms execute... sum_squares=(%f)", 1.0f / rms); + debuglog("rms execute... sum_squares=(%f)", sum_squares); + + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + float weight = weights(0, 0, 0, d); + + out_0(b, h, w, d) = inval * weight / rms; + } + } } - } } - - - return GraphStatus::Success; + return GraphStatus::Success; } #endif +__attribute__((unused)) static float rmsnormCostFunc(const Op *op) { + /* + * add code here + * */ -__attribute__((unused)) static float rmsnormCostFunc(const Op *op) -{ - /* - * add code here - * */ - - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RoPE.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RoPE.cpp index 3aaeccf00..ac36798cc 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RoPE.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/RoPE.cpp @@ -9,19 +9,16 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_RoPE); - // op execute function declarations -template -GraphStatus ropeImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& sin, - const TensorType& cos, +template +GraphStatus ropeImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &sin, + const TensorType &cos, const TensorType1 &h_cnt, - const Tensor& pose_type); - + const Tensor &pose_type); // forward declaration of sample cost function static float ropeCostFunc(const Op *op); @@ -66,11 +63,11 @@ DEF_PACKAGE_OP((ropeImpl), "RoPE") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -83,12 +80,11 @@ DEF_PACKAGE_OP((ropeImpl), "RoPE") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("RoPE", +DEF_PACKAGE_PARAM_ORDER("RoPE", "pose_type", true, nullptr) - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -98,10 +94,10 @@ DEF_PACKAGE_PARAM_ORDER("RoPE", #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) -#define ONE 0x3F800000 -#define M_ONE 0xAF800000 +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) +#define ONE 0x3F800000 +#define M_ONE 0xAF800000 int32_t hvx_rope_af( float *restrict input, @@ -109,19 +105,18 @@ int32_t hvx_rope_af( float *restrict cos, float *restrict output, uint32_t size, - uint32_t partial_dimension) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } HVX_Vector *iptr = (HVX_Vector *)input; - HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension/2); + HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension / 2); HVX_Vector *iptr2 = (HVX_Vector *)sin; HVX_Vector *iptr3 = (HVX_Vector *)cos; HVX_UVector *optr = (HVX_UVector *)output; - HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension/2);; + HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension / 2); + ; HVX_Vector sline1; HVX_Vector sline1_half; HVX_Vector sinline1p, sinline1c, sinline1; @@ -135,63 +130,53 @@ int32_t hvx_rope_af( sinline1p = *iptr2++; cosline1p = *iptr3++; - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr3 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t d = 0; d < partial_dimension/2; d+=32) { - cosline1c = *iptr3++; - cosline1 = Q6_V_valign_VVR(cosline1c, cosline1p, (size_t)cos); - cosline1p = cosline1c; - - sinline1c = *iptr2++; - sinline1 = Q6_V_valign_VVR(sinline1c, sinline1p, (size_t)sin); - sinline1p = sinline1c; - - - HVX_Vector *jiptr = iptr + d/32; - HVX_Vector *jiptr_half = iptr_half + d/32; - HVX_Vector *joptr = optr + d/32; - HVX_Vector *joptr_half = optr_half + d/32; - - for (int32_t j = 0; j < size/partial_dimension; j++) - { - sline1 = *jiptr; - sline1_half = *jiptr_half; - - // auto value = in_value * cos_value - in_value_2 * sin_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, cosline1); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1_half, sinline1); - *joptr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } - - - - // auto value2 = in_value * sin_value + in_value_2 * cos_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1_half, cosline1); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, sinline1); - *joptr_half = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } - - jiptr += partial_dimension/32; - jiptr_half += partial_dimension/32; - joptr += partial_dimension/32; - joptr_half += partial_dimension/32; - - } - - + for (int32_t d = 0; d < partial_dimension / 2; d += 32) { + cosline1c = *iptr3++; + cosline1 = Q6_V_valign_VVR(cosline1c, cosline1p, (size_t)cos); + cosline1p = cosline1c; + + sinline1c = *iptr2++; + sinline1 = Q6_V_valign_VVR(sinline1c, sinline1p, (size_t)sin); + sinline1p = sinline1c; + + HVX_Vector *jiptr = iptr + d / 32; + HVX_Vector *jiptr_half = iptr_half + d / 32; + HVX_Vector *joptr = optr + d / 32; + HVX_Vector *joptr_half = optr_half + d / 32; + + for (int32_t j = 0; j < size / partial_dimension; j++) { + sline1 = *jiptr; + sline1_half = *jiptr_half; + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, cosline1); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1_half, sinline1); + *joptr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1_half, cosline1); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_VsfVsf(sline1, sinline1); + *joptr_half = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } + + jiptr += partial_dimension / 32; + jiptr_half += partial_dimension / 32; + joptr += partial_dimension / 32; + joptr_half += partial_dimension / 32; + } } - } // if (vectors_in_rounddown > 0) { @@ -202,15 +187,13 @@ int32_t hvx_rope_af( // } - if (leftover_size > 0) - return -1; + return -1; return 0; } -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -224,10 +207,8 @@ int32_t hvx_rope_uint8_af( float *restrict cos, float *restrict output, uint32_t size, - uint32_t partial_dimension) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -246,83 +227,70 @@ int32_t hvx_rope_uint8_af( HVX_Vector convert_vector = Q6_V_vsplat_R(convert); HVX_Vector one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); - // - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + // + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - // + // HVX_Vector sinline1_low = *iptr2; HVX_Vector cosline1_low = *iptr3; sinline1_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); cosline1_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); - - HVX_Vector sinline1_high = *(iptr2+1); - HVX_Vector cosline1_high = *(iptr3+1); + HVX_Vector sinline1_high = *(iptr2 + 1); + HVX_Vector cosline1_high = *(iptr3 + 1); sinline1_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); cosline1_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); - - for (int32_t j = 0; j < size/partial_dimension; j++) { - - HVX_Vector sline1 = *iptr++; - - HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); - HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); - HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + for (int32_t j = 0; j < size / partial_dimension; j++) { + HVX_Vector sline1 = *iptr++; - HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); - result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); - result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); - // auto value = in_value * cos_value - in_value_2 * sin_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); - *optr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } - - - - // auto value2 = in_value * sin_value + in_value_2 * cos_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); - *(optr+2) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } - - - // auto value = in_value * cos_value - in_value_2 * sin_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), sinline1_high); - *(optr+1) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } - + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); + *optr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); + *(optr + 2) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } - // auto value2 = in_value * sin_value + in_value_2 * cos_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); - *(optr+3) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); - } + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), sinline1_high); + *(optr + 1) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } - optr+=4; + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); + *(optr + 3) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32)); + } + optr += 4; } - } // if (vectors_in_rounddown > 0) { @@ -333,9 +301,8 @@ int32_t hvx_rope_uint8_af( // } - if (leftover_size > 0) - return -1; + return -1; return 0; } @@ -347,10 +314,8 @@ int32_t hvx_rope_uint8_ahf( __fp16 *restrict output, uint32_t size, uint32_t partial_dimension, - float scale) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } @@ -370,96 +335,85 @@ int32_t hvx_rope_uint8_ahf( HVX_Vector scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); - // - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + // + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - // + // HVX_Vector sinline1_low = *iptr2; HVX_Vector cosline1_low = *iptr3; sinline1_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); cosline1_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); - - HVX_Vector sinline1_high = *(iptr2+1); - HVX_Vector cosline1_high = *(iptr3+1); + HVX_Vector sinline1_high = *(iptr2 + 1); + HVX_Vector cosline1_high = *(iptr3 + 1); sinline1_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); cosline1_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); - - for (int32_t j = 0; j < size/partial_dimension; j++) { - - HVX_Vector sline1 = *iptr++; - - HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); - HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); - HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + for (int32_t j = 0; j < size / partial_dimension; j++) { + HVX_Vector sline1 = *iptr++; - HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec); - result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); - HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec); - result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + HVX_VectorPair result1 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + HVX_VectorPair result2 = Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); - - { - HVX_Vector first; - HVX_Vector second; - // auto value = in_value * cos_value - in_value_2 * sin_value; { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); - first = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + HVX_Vector first; + HVX_Vector second; + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); + first = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + } + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), sinline1_high); + second = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + } + + HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); + r = Q6_Vh_vdeal_Vh(r); + *optr = r; } - // auto value = in_value * cos_value - in_value_2 * sin_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), sinline1_high); - second = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); - } - - HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); - r = Q6_Vh_vdeal_Vh(r); - *optr = r; - } - - { - HVX_Vector first; - HVX_Vector second; - // auto value2 = in_value * sin_value + in_value_2 * cos_value; - { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); - first = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); - } - - - // auto value2 = in_value * sin_value + in_value_2 * cos_value; { - HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); - HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); - second = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + HVX_Vector first; + HVX_Vector second; + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); + first = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); + HVX_Vector sin_middle_value_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); + second = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, sin_middle_value_qf32); + } + HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); + r = Q6_Vh_vdeal_Vh(r); + *(optr + 1) = r; } - HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); - r = Q6_Vh_vdeal_Vh(r); - *(optr+1) = r; - } - - - optr+=2; + optr += 2; } - } // if (vectors_in_rounddown > 0) { @@ -470,33 +424,30 @@ int32_t hvx_rope_uint8_ahf( // } - if (leftover_size > 0) - return -1; + return -1; return 0; } - int32_t hvx_rope_ahf( __fp16 *restrict input, float *restrict sin, float *restrict cos, __fp16 *restrict output, uint32_t size, - uint32_t partial_dimension) -{ - if ((input == NULL) || (output == NULL) || (size == 0)) - { + uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { return -1; } HVX_Vector *iptr = (HVX_Vector *)input; - HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension/2); + HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension / 2); HVX_Vector *iptr2 = (HVX_Vector *)sin; HVX_Vector *iptr3 = (HVX_Vector *)cos; HVX_UVector *optr = (HVX_UVector *)output; - HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension/2);; + HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension / 2); + ; HVX_Vector sline1; HVX_Vector sline1_half; @@ -511,115 +462,104 @@ int32_t hvx_rope_ahf( HVX_Vector one_vhf = Q6_V_vsplat_R(float_to_fp16s(1.0)); // HVX_Vector m_one_vqf16 = Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), one_vhf); - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); l2fetch(iptr3 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } - for (int32_t d = 0; d < partial_dimension/2; d+=64) { + for (int32_t d = 0; d < partial_dimension / 2; d += 64) { + HVX_Vector sinline1_low = *iptr2++; + HVX_Vector cosline1_low = *iptr3++; - HVX_Vector sinline1_low = *iptr2++; - HVX_Vector cosline1_low = *iptr3++; + HVX_Vector sinline1_high = *iptr2++; + HVX_Vector cosline1_high = *iptr3++; - HVX_Vector sinline1_high = *iptr2++; - HVX_Vector cosline1_high = *iptr3++; + HVX_Vector *jiptr = iptr + d / 64; + HVX_Vector *jiptr_half = iptr_half + d / 64; + HVX_Vector *joptr = optr + d / 64; + HVX_Vector *joptr_half = optr_half + d / 64; + for (int32_t j = 0; j < size / partial_dimension; j++) { + sline1 = *jiptr; + sline1_half = *jiptr_half; - HVX_Vector *jiptr = iptr + d/64; - HVX_Vector *jiptr_half = iptr_half + d/64; - HVX_Vector *joptr = optr + d/64; - HVX_Vector *joptr_half = optr_half + d/64; + HVX_VectorPair sline1_half_pair = Q6_Wqf32_vmpy_VhfVhf(sline1_half, one_vhf); + HVX_VectorPair sline1_pair = Q6_Wqf32_vmpy_VhfVhf(sline1, one_vhf); - - for (int32_t j = 0; j < size/partial_dimension; j++) - { - sline1 = *jiptr; - sline1_half = *jiptr_half; + sline1_half_pair = Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_half_pair), Q6_V_lo_W(sline1_half_pair), -4); + sline1_pair = Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_pair), Q6_V_lo_W(sline1_pair), -4); - HVX_VectorPair sline1_half_pair = Q6_Wqf32_vmpy_VhfVhf(sline1_half, one_vhf); - HVX_VectorPair sline1_pair = Q6_Wqf32_vmpy_VhfVhf(sline1, one_vhf); + HVX_Vector m_sline1_half_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_half_pair), m_one_vqf32); + HVX_Vector m_sline1_half_hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_half_pair), m_one_vqf32); - sline1_half_pair = Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_half_pair), Q6_V_lo_W(sline1_half_pair), -4); - sline1_pair = Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_pair), Q6_V_lo_W(sline1_pair), -4); + // auto value = in_value * cos_value - in_value_2 * sin_value; + HVX_Vector middle_value_low; + { + HVX_Vector cosline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_pair), cosline1_vqf32_low); - HVX_Vector m_sline1_half_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_half_pair), m_one_vqf32); - HVX_Vector m_sline1_half_hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_half_pair), m_one_vqf32); + HVX_Vector sinline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + HVX_Vector sin_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_low, sinline1_vqf32_low); + middle_value_low = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_low, sin_middle_value_qf32_low); + } - // auto value = in_value * cos_value - in_value_2 * sin_value; - HVX_Vector middle_value_low; - { - HVX_Vector cosline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); - HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_pair), cosline1_vqf32_low); + // auto value2 = in_value * sin_value + in_value_2 * cos_value; - HVX_Vector sinline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); - - HVX_Vector sin_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_low, sinline1_vqf32_low); - middle_value_low = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_low, sin_middle_value_qf32_low); - } - + HVX_Vector middle_value_half_low; + { + HVX_Vector cosline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_half_pair), cosline1_vqf32_low); + HVX_Vector sinline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + HVX_Vector sin_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_pair), sinline1_vqf32_low); - // auto value2 = in_value * sin_value + in_value_2 * cos_value; + middle_value_half_low = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_low, sin_middle_value_qf32_low); + } - HVX_Vector middle_value_half_low; - { - HVX_Vector cosline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); - HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_half_pair), cosline1_vqf32_low); + // second qf16 vector + HVX_Vector middle_value_high; + { + HVX_Vector cosline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_pair), cosline1_vqf32_high); - HVX_Vector sinline1_vqf32_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); - HVX_Vector sin_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_pair), sinline1_vqf32_low); + HVX_Vector sinline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); - middle_value_half_low = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_low, sin_middle_value_qf32_low); - } + HVX_Vector sin_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_hi, sinline1_vqf32_high); + middle_value_high = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_high, sin_middle_value_qf32_high); + } - // second qf16 vector - HVX_Vector middle_value_high; - { - HVX_Vector cosline1_vqf32_high= Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); - HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_pair), cosline1_vqf32_high); + // auto value2 = in_value * sin_value + in_value_2 * cos_value; - HVX_Vector sinline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + HVX_Vector middle_value_half_high; + { + HVX_Vector cosline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_half_pair), cosline1_vqf32_high); - HVX_Vector sin_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_hi, sinline1_vqf32_high); - middle_value_high = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_high, sin_middle_value_qf32_high); - } - + HVX_Vector sinline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + HVX_Vector sin_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_pair), sinline1_vqf32_high); + middle_value_half_high = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_high, sin_middle_value_qf32_high); + } - // auto value2 = in_value * sin_value + in_value_2 * cos_value; + HVX_Vector sline = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(middle_value_high, middle_value_low)); + sline = Q6_Vh_vdeal_Vh(sline); - HVX_Vector middle_value_half_high; - { - HVX_Vector cosline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); - HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_half_pair), cosline1_vqf32_high); + HVX_Vector sline_half = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(middle_value_half_high, middle_value_half_low)); + sline_half = Q6_Vh_vdeal_Vh(sline_half); - HVX_Vector sinline1_vqf32_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); - HVX_Vector sin_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_pair), sinline1_vqf32_high); + *joptr = sline; + *joptr_half = sline_half; - middle_value_half_high = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32_high, sin_middle_value_qf32_high); + jiptr += partial_dimension / 64; + jiptr_half += partial_dimension / 64; + joptr += partial_dimension / 64; + joptr_half += partial_dimension / 64; } - - HVX_Vector sline = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(middle_value_high, middle_value_low)); - sline = Q6_Vh_vdeal_Vh(sline); - - HVX_Vector sline_half = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(middle_value_half_high, middle_value_half_low)); - sline_half = Q6_Vh_vdeal_Vh(sline_half); - - *joptr = sline; - *joptr_half = sline_half; - - jiptr += partial_dimension/64; - jiptr_half += partial_dimension/64; - joptr += partial_dimension/64; - joptr_half += partial_dimension/64; - } } } @@ -631,448 +571,410 @@ int32_t hvx_rope_ahf( // } - if (leftover_size > 0) - return -1; + return -1; return 0; } - -template -GraphStatus ropeImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& sin, - const TensorType& cos, +template +GraphStatus ropeImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &sin, + const TensorType &cos, const TensorType1 &h_cnt, - const Tensor& pose_type) -{ + const Tensor &pose_type) { + out_0.set_dims(in_0); - out_0.set_dims(in_0); + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); - auto pose_type_ = pose_type(0,0,0,0); - auto h_cnt_ = static_cast(h_cnt(0,0,0,0)); + if (pose_type_ == 4) { + DType dtype = out_0.get_dtype(); - if (pose_type_ == 4) { + if (in_0.get_dtype() == DType::Float32 && dtype == DType::Float32) { + auto in_ptr = (float *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); - DType dtype = out_0.get_dtype(); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - if (in_0.get_dtype() == DType::Float32 && dtype == DType::Float32) { - auto in_ptr = (float*)in_0.raw_data_const(); - auto sin_ptr = (float*)sin.raw_data_const(); - auto cos_ptr = (float*)cos.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + int partial_dimension = d_in; - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - uint32_t half_dimension = d_in / 2; - sin_ptr += half_dimension * h_cnt_; - cos_ptr += half_dimension * h_cnt_; - - int partial_dimension = d_in; + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - // NSHD - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - - // for (Idx w = 0; w < w_in; w++) { - hvx_rope_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - - in_ptr += w_in * d_in; - out_ptr += w_in * d_in; - // } + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } - sin_ptr += half_dimension; - cos_ptr += half_dimension; - } - } - } else if (in_0.get_dtype() == DType::Float16 && dtype == DType::Float16) { + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::Float16 && dtype == DType::Float16) { + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto sin_ptr = (float*)sin.raw_data_const(); - auto cos_ptr = (float*)cos.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; - auto [b_in, h_in, w_in, d_in] = in_0.dims(); + int partial_dimension = d_in; - uint32_t half_dimension = d_in / 2; - sin_ptr += half_dimension * h_cnt_; - cos_ptr += half_dimension * h_cnt_; + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - int partial_dimension = d_in; + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } - // NSHD - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - - // for (Idx w = 0; w < w_in; w++) { - hvx_rope_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - - in_ptr += w_in * d_in; - out_ptr += w_in * d_in; - // } + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float32) { + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); - sin_ptr += half_dimension; - cos_ptr += half_dimension; - } - } - } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float32) { - auto in_ptr = (uint8_t*)in_0.raw_data_const(); - auto sin_ptr = (float*)sin.raw_data_const(); - auto cos_ptr = (float*)cos.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - uint32_t half_dimension = d_in / 2; - sin_ptr += half_dimension * h_cnt_; - cos_ptr += half_dimension * h_cnt_; - - int partial_dimension = d_in; - - // NSHD - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - - // for (Idx w = 0; w < w_in; w++) { - hvx_rope_uint8_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - - in_ptr += w_in * d_in; - out_ptr += w_in * d_in; - // } - - sin_ptr += half_dimension; - cos_ptr += half_dimension; - } - } - } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float16) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - auto in_ptr = (uint8_t*)in_0.raw_data_const(); - auto sin_ptr = (float*)sin.raw_data_const(); - auto cos_ptr = (float*)cos.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; - float scale_ = in_0.get_interface_scale(); + int partial_dimension = d_in; - auto [b_in, h_in, w_in, d_in] = in_0.dims(); + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_uint8_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension); - uint32_t half_dimension = d_in / 2; - sin_ptr += half_dimension * h_cnt_; - cos_ptr += half_dimension * h_cnt_; + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } - int partial_dimension = d_in; + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float16) { + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); - // NSHD - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - - // for (Idx w = 0; w < w_in; w++) { - hvx_rope_uint8_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension, scale_); - - in_ptr += w_in * d_in; - out_ptr += w_in * d_in; - // } + float scale_ = in_0.get_interface_scale(); - sin_ptr += half_dimension; - cos_ptr += half_dimension; - } - } - } + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; - } else { + int partial_dimension = d_in; - // only support pose_type == 2 (LLaMA) now - return GraphStatus::ErrorFatal; + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_uint8_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, partial_dimension, scale_); - } + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } - - - - return GraphStatus::Success; + } else { + // only support pose_type == 2 (LLaMA) now + return GraphStatus::ErrorFatal; + } + return GraphStatus::Success; } - - #else - -template -GraphStatus ropeImpl(TensorType& out_0, - const TensorType& in_0, - const TensorType& sin, - const TensorType& cos, +template +GraphStatus ropeImpl(TensorType &out_0, + const TensorType &in_0, + const TensorType &sin, + const TensorType &cos, const TensorType1 &h_cnt, - const Tensor& pose_type) + const Tensor &pose_type) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", in_0.dim(0), in_0.dim(1), in_0.dim(2), in_0.dim(3)); - debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", sin.dim(0), sin.dim(1), sin.dim(2), sin.dim(3)); - debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", cos.dim(0), cos.dim(1), cos.dim(2), cos.dim(3)); - - // BSHD => NHWC - - // Todo: We need consider to store the sequence position if we have KV Cache - - auto pose_type_ = pose_type(0,0,0,0); - auto h_cnt_ = static_cast(h_cnt(0,0,0,0)); - - out_0.set_dims(in_0); - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - if (pose_type_ == 4) { - DType dtype = out_0.get_dtype(); - - if (dtype == DType::Float32) { - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - - int s = h; // BSHD order - int partial_dimension = d_in; - int half = (int)(partial_dimension / 2); - for (Idx d = 0; d < partial_dimension / 2; ++d) { - float in_value = in_0(b, h, w, d); - float in_value_2 = in_0(b, h, w, d + half); - float sin_value = sin(0, 0, s + h_cnt_, d); - float cos_value = cos(0, 0, s + h_cnt_, d); - auto value = in_value * cos_value - in_value_2 * sin_value; - auto value2 = in_value * sin_value + in_value_2 * cos_value; - out_0(b, h, w, d) = value; - out_0(b, h, w, d + half) = value2; + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", in_0.dim(0), in_0.dim(1), in_0.dim(2), in_0.dim(3)); + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", sin.dim(0), sin.dim(1), sin.dim(2), sin.dim(3)); + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", cos.dim(0), cos.dim(1), cos.dim(2), cos.dim(3)); + + // BSHD => NHWC + + // Todo: We need consider to store the sequence position if we have KV Cache + + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); + + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + if (pose_type_ == 4) { + DType dtype = out_0.get_dtype(); + + if (dtype == DType::Float32) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + int s = h; // BSHD order + int partial_dimension = d_in; + int half = (int)(partial_dimension / 2); + for (Idx d = 0; d < partial_dimension / 2; ++d) { + float in_value = in_0(b, h, w, d); + float in_value_2 = in_0(b, h, w, d + half); + float sin_value = sin(0, 0, s + h_cnt_, d); + float cos_value = cos(0, 0, s + h_cnt_, d); + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + out_0(b, h, w, d) = value; + out_0(b, h, w, d + half) = value2; + } + } + } + } + } else if (dtype == DType::Float16) { + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + // auto sin_ptr = (__fp16*)sin.raw_data_const(); + // auto cos_ptr = (__fp16*)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + int s = h; // BSHD order + int partial_dimension = d_in; + int half = (int)(partial_dimension / 2); + for (Idx d = 0; d < partial_dimension / 2; ++d) { + __fp16 in_value = *in_ptr; + __fp16 in_value_2 = *(in_ptr + half); + float sin_value = sin(0, 0, s + h_cnt_, d); + float cos_value = cos(0, 0, s + h_cnt_, d); + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + *out_ptr = static_cast<__fp16>(value); + *(out_ptr + half) = static_cast<__fp16>(value2); + + out_ptr++; + in_ptr++; + } + + out_ptr += half; + in_ptr += half; + } + } } - } } - } - } else if (dtype == DType::Float16) { - - auto in_ptr = (__fp16*)in_0.raw_data_const(); - // auto sin_ptr = (__fp16*)sin.raw_data_const(); - // auto cos_ptr = (__fp16*)cos.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); - - for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { + } - int s = h; // BSHD order - int partial_dimension = d_in; - int half = (int)(partial_dimension / 2); - for (Idx d = 0; d < partial_dimension / 2; ++d) { - __fp16 in_value = *in_ptr; - __fp16 in_value_2 = *(in_ptr + half); - float sin_value = sin(0, 0, s + h_cnt_, d); - float cos_value = cos(0, 0, s + h_cnt_, d); - auto value = in_value * cos_value - in_value_2 * sin_value; - auto value2 = in_value * sin_value + in_value_2 * cos_value; - *out_ptr = static_cast<__fp16>(value); - *(out_ptr + half) = static_cast<__fp16>(value2); - - out_ptr++; - in_ptr++; - } + // for (Idx b = 0; b < b_in; b++) { + // for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < w_in; w++) { + // // RoPE + // for (Idx d = 0; d < d_in; d++) { + + // int s = h; // BSHD order + // if (pose_type_ == 1) { + // float in_value = in_0(b, h, w, d); + // float in_value_2; + // if (d < d_in / 2) { // 偶數 0,2,4 + // in_value_2 = -in_0(b, h, w, d + d_in / 2); + // } else { + // in_value_2 = in_0(b, h, w, d - d_in / 2); + // } + // float sin_value = sin(0, 0, s +h_cnt_, d); + // float cos_value = cos(0, 0, s +h_cnt_, d); + // auto value = in_value * cos_value + in_value_2 * sin_value; + // out_0(b, h, w, d) = value; + // } + // else if (pose_type_ == 2) { + // float in_value = in_0(b, h, w, d); + // debuglog("rope execute... in_value=(%f)", in_value); + // float in_value_2; + // if (d % 2 == 0) { // 偶數 0,2,4 + // in_value_2 = -in_0(b, h, w, d + 1); + // } else { + // in_value_2 = in_0(b, h, w, d - 1); + // } + // debuglog("rope execute... in_value_2=(%f)", in_value_2); + // float sin_value = sin(0, 0, s +h_cnt_, d); + // float cos_value = cos(0, 0, s +h_cnt_, d); + // auto value = in_value * cos_value + in_value_2 * sin_value; + + // debuglog("rope execute... sin_value=(%f)", sin_value); + // debuglog("rope execute... cos_value=(%f)", cos_value); + + // debuglog("rope execute... value=(%f)", value); + // out_0(b, h, w, d) = value; + // } else if (pose_type_ == 4) { + // } else { + // float in_value = in_0(b, h, w, d); + // float in_value_2; + // float sin_value = sin(0, 0, s +h_cnt_, d); + // float cos_value = cos(0, 0, s +h_cnt_, d); + // if (d < d_in / 4) { + // in_value_2 = -in_0(b, h, w, d + d_in / 4); + // auto value = in_value * cos_value + in_value_2 * sin_value; + + // out_0(b ,h , w, d) = value; + // } else if(d < d_in / 2){ + // in_value_2 = in_0(b, h, w, d - d_in / 4); + // auto value = in_value * cos_value + in_value_2 * sin_value; + + // out_0(b ,h , w, d) = value; + // }else { + + // out_0(b ,h , w, d) = in_value; + // } + // } + + // } + // } + // } + // } - out_ptr += half; - in_ptr += half; - } - } - } - } - } - - // for (Idx b = 0; b < b_in; b++) { - // for (Idx h = 0; h < h_in; h++) { - // for (Idx w = 0; w < w_in; w++) { - // // RoPE - // for (Idx d = 0; d < d_in; d++) { - - - // int s = h; // BSHD order - // if (pose_type_ == 1) { - // float in_value = in_0(b, h, w, d); - // float in_value_2; - // if (d < d_in / 2) { // 偶數 0,2,4 - // in_value_2 = -in_0(b, h, w, d + d_in / 2); - // } else { - // in_value_2 = in_0(b, h, w, d - d_in / 2); - // } - // float sin_value = sin(0, 0, s +h_cnt_, d); - // float cos_value = cos(0, 0, s +h_cnt_, d); - // auto value = in_value * cos_value + in_value_2 * sin_value; - // out_0(b, h, w, d) = value; - // } - // else if (pose_type_ == 2) { - // float in_value = in_0(b, h, w, d); - // debuglog("rope execute... in_value=(%f)", in_value); - // float in_value_2; - // if (d % 2 == 0) { // 偶數 0,2,4 - // in_value_2 = -in_0(b, h, w, d + 1); - // } else { - // in_value_2 = in_0(b, h, w, d - 1); - // } - // debuglog("rope execute... in_value_2=(%f)", in_value_2); - // float sin_value = sin(0, 0, s +h_cnt_, d); - // float cos_value = cos(0, 0, s +h_cnt_, d); - // auto value = in_value * cos_value + in_value_2 * sin_value; - - // debuglog("rope execute... sin_value=(%f)", sin_value); - // debuglog("rope execute... cos_value=(%f)", cos_value); - - // debuglog("rope execute... value=(%f)", value); - // out_0(b, h, w, d) = value; - // } else if (pose_type_ == 4) { - // } else { - // float in_value = in_0(b, h, w, d); - // float in_value_2; - // float sin_value = sin(0, 0, s +h_cnt_, d); - // float cos_value = cos(0, 0, s +h_cnt_, d); - // if (d < d_in / 4) { - // in_value_2 = -in_0(b, h, w, d + d_in / 4); - // auto value = in_value * cos_value + in_value_2 * sin_value; - - // out_0(b ,h , w, d) = value; - // } else if(d < d_in / 2){ - // in_value_2 = in_0(b, h, w, d - d_in / 4); - // auto value = in_value * cos_value + in_value_2 * sin_value; - - // out_0(b ,h , w, d) = value; - // }else { - - // out_0(b ,h , w, d) = in_value; - // } - // } - - // } - // } - // } - // } - - -// auto &input = inputs[0]; -// auto &output = outputs[0]; -// for (int n = 0; n < input->batch(); ++n) { -// for (int h = 0; h < input->head(); ++h) { -// for (int s = 0; s < input->sequence(); ++s) {//sequance -// #pragma omp parallel for num_threads(4) -// for (int d = 0; d < input->dimension(); ++d) { -// if (pose_type_== 1) { -// float in_value = input->dataAt(n, h, s, d); -// float in_value_2; -// if (d < input->dimension() / 2) { // 偶數 0,2,4 -// in_value_2 = -input->dataAt(n, h, s, d + input->dimension() / 2); -// } else { -// in_value_2 = input->dataAt(n, h, s, d - input->dimension() / 2); -// } -// float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); -// float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); -// auto value = in_value * cos_value + in_value_2 * sin_value; -// if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { -// output->setDataAt(n, h, s, d, value); -// } -// else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { -// output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); -// } -// } -// else if (pose_type_== 2) { -// float in_value = input->dataAt(n, h, s, d); -// float in_value_2; -// if (d % 2 == 0) { // 偶數 0,2,4 -// in_value_2 = -input->dataAt(n, h, s, d + 1); -// } else { -// in_value_2 = input->dataAt(n, h, s, d - 1); -// } -// float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); -// float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); -// auto value = in_value * cos_value + in_value_2 * sin_value; -// if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { -// output->setDataAt(n, h, s, d, value); -// } -// else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { -// output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); -// } -// }else{ -// float in_value = input->dataAt(n, h, s, d); -// float in_value_2; -// float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); -// float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); -// if (d < input->dimension() / 4) { -// in_value_2 = - input->dataAt(n, h, s, d + input->dimension() / 4); -// auto value = in_value * cos_value + in_value_2 * sin_value; -// if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { -// output->setDataAt(n, h, s, d, value); -// } -// else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { -// output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); -// } -// } else if(d < input->dimension() / 2){ -// in_value_2 = input->dataAt(n, h, s, d - input->dimension() / 4); -// auto value = in_value * cos_value + in_value_2 * sin_value; -// if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { -// output->setDataAt(n, h, s, d, value); -// } -// else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { -// output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); -// } -// }else { -// if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { -// output->setDataAt(n, h, s, d, in_value); -// } -// else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { -// output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(in_value)); -// } -// } -// } -// } -// } -// } -// } - - -// Todo store history position -// h_cnt_ += input->sequence(); -// if(h_cnt_ >pos_max_){ -// h_cnt_ = 0; -// } - - - return GraphStatus::Success; + // auto &input = inputs[0]; + // auto &output = outputs[0]; + // for (int n = 0; n < input->batch(); ++n) { + // for (int h = 0; h < input->head(); ++h) { + // for (int s = 0; s < input->sequence(); ++s) {//sequance + // #pragma omp parallel for num_threads(4) + // for (int d = 0; d < input->dimension(); ++d) { + // if (pose_type_== 1) { + // float in_value = input->dataAt(n, h, s, d); + // float in_value_2; + // if (d < input->dimension() / 2) { // 偶數 0,2,4 + // in_value_2 = -input->dataAt(n, h, s, d + input->dimension() / 2); + // } else { + // in_value_2 = input->dataAt(n, h, s, d - input->dimension() / 2); + // } + // float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); + // float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); + // auto value = in_value * cos_value + in_value_2 * sin_value; + // if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { + // output->setDataAt(n, h, s, d, value); + // } + // else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { + // output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + // } + // } + // else if (pose_type_== 2) { + // float in_value = input->dataAt(n, h, s, d); + // float in_value_2; + // if (d % 2 == 0) { // 偶數 0,2,4 + // in_value_2 = -input->dataAt(n, h, s, d + 1); + // } else { + // in_value_2 = input->dataAt(n, h, s, d - 1); + // } + // float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); + // float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); + // auto value = in_value * cos_value + in_value_2 * sin_value; + // if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { + // output->setDataAt(n, h, s, d, value); + // } + // else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { + // output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + // } + // }else{ + // float in_value = input->dataAt(n, h, s, d); + // float in_value_2; + // float sin_value = sin_.dataAt(0, 0, s +h_cnt_, d); + // float cos_value = cos_.dataAt(0, 0, s +h_cnt_, d); + // if (d < input->dimension() / 4) { + // in_value_2 = - input->dataAt(n, h, s, d + input->dimension() / 4); + // auto value = in_value * cos_value + in_value_2 * sin_value; + // if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { + // output->setDataAt(n, h, s, d, value); + // } + // else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { + // output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + // } + // } else if(d < input->dimension() / 2){ + // in_value_2 = input->dataAt(n, h, s, d - input->dimension() / 4); + // auto value = in_value * cos_value + in_value_2 * sin_value; + // if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { + // output->setDataAt(n, h, s, d, value); + // } + // else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { + // output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + // } + // }else { + // if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F32) { + // output->setDataAt(n, h, s, d, in_value); + // } + // else if(output->dtypeAt(n,h,s, d) == MLLM_TYPE_F16) { + // output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(in_value)); + // } + // } + // } + // } + // } + // } + // } + + // Todo store history position + // h_cnt_ += input->sequence(); + // if(h_cnt_ >pos_max_){ + // h_cnt_ = 0; + // } + + return GraphStatus::Success; } #endif +__attribute__((unused)) static float ropeCostFunc(const Op *op) { + /* + * add code here + * */ -__attribute__((unused)) static float ropeCostFunc(const Op *op) -{ - /* - * add code here - * */ - - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SiLU.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SiLU.cpp index 28271772f..8b56e7e80 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SiLU.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SiLU.cpp @@ -9,14 +9,12 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_SiLU); - // op execute function declarations -template -GraphStatus siluImpl(TensorType& out_0, - const TensorType& in_0); +template +GraphStatus siluImpl(TensorType &out_0, + const TensorType &in_0); // forward declaration of sample cost function static float siluCostFunc(const Op *op); @@ -61,11 +59,11 @@ DEF_PACKAGE_OP((siluImpl), "SiLU") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -79,7 +77,6 @@ DEF_PACKAGE_OP((siluImpl), "SiLU") * Qnn_addNode */ - /* execute functions for ops */ #ifndef REFERENCE_OP @@ -88,11 +85,10 @@ DEF_PACKAGE_OP((siluImpl), "SiLU") #include #include -#define BLOCK_SIZE (8*1024/VLEN) /* vector chunks */ -#define L2FETCH_AHEAD (BLOCK_SIZE) +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) -static inline int32_t float_to_fp16s(float input) -{ +static inline int32_t float_to_fp16s(float input) { union { int32_t i; __fp16 f[2]; @@ -100,48 +96,189 @@ static inline int32_t float_to_fp16s(float input) return fp32.i; } -static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) -{ - union { float f; uint32_t i; } fp32 = { .f = x }; +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; return fp32.i; } - /* Polynomial coefficients */ static const float c0_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1329913082916337,0.22308514882873062,0.347752862580421,0.4845759228057826,0.5724725619240282,0.5532613332075828,0.5041402176920755,0.4999998945071365, -0.500005251569411,0.494975832882496,0.44426898861108216,0.42865769845972046,0.5186084804556764,0.6556781472810073,0.7780379623543565,0.8670752648575938, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.1329913082916337, + 0.22308514882873062, + 0.347752862580421, + 0.4845759228057826, + 0.5724725619240282, + 0.5532613332075828, + 0.5041402176920755, + 0.4999998945071365, + 0.500005251569411, + 0.494975832882496, + 0.44426898861108216, + 0.42865769845972046, + 0.5186084804556764, + 0.6556781472810073, + 0.7780379623543565, + 0.8670752648575938, }; static const float c1_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0595948414501292,0.11153317908159224,0.19545701719511055,0.3058925677063833,0.3932668307015573,0.3630691859433203,0.26302954631996744,0.2499155333713503, -0.24983690256810576,0.26551386754654915,0.3670764533308477,0.39196882072648825,0.3030372911476408,0.19296191313371913,0.11084562978488391,0.059559556604464964, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0595948414501292, + 0.11153317908159224, + 0.19545701719511055, + 0.3058925677063833, + 0.3932668307015573, + 0.3630691859433203, + 0.26302954631996744, + 0.2499155333713503, + 0.24983690256810576, + 0.26551386754654915, + 0.3670764533308477, + 0.39196882072648825, + 0.3030372911476408, + 0.19296191313371913, + 0.11084562978488391, + 0.059559556604464964, }; static const float c2_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.010207999856103376,0.02144807112969563,0.04266485934992188,0.07616157468726052,0.10882760873715347,0.09125379784995667,0.013872106909816257,-0.0008786208359828815, -0.0011993845621092196,-0.01645080326288375,-0.09367947263571219,-0.10827006684348266,-0.07520301291634655,-0.04198514892887826,-0.021290356584896874,-0.010200991240527542, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.010207999856103376, + 0.02144807112969563, + 0.04266485934992188, + 0.07616157468726052, + 0.10882760873715347, + 0.09125379784995667, + 0.013872106909816257, + -0.0008786208359828815, + 0.0011993845621092196, + -0.01645080326288375, + -0.09367947263571219, + -0.10827006684348266, + -0.07520301291634655, + -0.04198514892887826, + -0.021290356584896874, + -0.010200991240527542, }; static const float c3_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0007896351019423816,0.0018718593077865326,0.004259190313167949,0.008784166436796144,0.014228201960903939,0.009727536748893095,-0.01721317464724529,-0.023762851116001377, --0.02424226654277249,-0.01604104065157868,0.010376786273973133,0.014122038833203628,0.008641365746408176,0.004176981844803722,0.0018557930308154783,0.0007890167735032168, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0007896351019423816, + 0.0018718593077865326, + 0.004259190313167949, + 0.008784166436796144, + 0.014228201960903939, + 0.009727536748893095, + -0.01721317464724529, + -0.023762851116001377, + -0.02424226654277249, + -0.01604104065157868, + 0.010376786273973133, + 0.014122038833203628, + 0.008641365746408176, + 0.004176981844803722, + 0.0018557930308154783, + 0.0007890167735032168, }; static const float c4_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.3213858349988003e-05,6.232838199801025e-05,0.0001632037964535633,0.0003928983460811959,0.0007341577078787206,0.0003053082875419616,-0.003254838747910248,-0.004021655986643196, -0.004258314078650583,0.0030578644020607566,-0.00037014803880675387,-0.0007265964578827031,-0.0003849331969038772,-0.00015947916435728337,-6.171511304866758e-05,-2.319341439172678e-05, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 2.3213858349988003e-05, + 6.232838199801025e-05, + 0.0001632037964535633, + 0.0003928983460811959, + 0.0007341577078787206, + 0.0003053082875419616, + -0.003254838747910248, + -0.004021655986643196, + 0.004258314078650583, + 0.0030578644020607566, + -0.00037014803880675387, + -0.0007265964578827031, + -0.0003849331969038772, + -0.00015947916435728337, + -6.171511304866758e-05, + -2.319341439172678e-05, }; /** @@ -151,8 +288,7 @@ static const float c4_coeffs[32] __attribute__((aligned(VLEN))) = * @param[in] length Number of elements in input/output arrays. * @return Returns 0 on successful execution. Otherwise -1. */ -int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size) -{ +int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size) { HVX_Vector *input_v_ptr; HVX_UVector *output_v_ptr; HVX_Vector input_min_v_f; @@ -191,13 +327,12 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size HVX_Vector f8, f_8; /* Check input arguments. Return error status if some argument has invalid value */ - if ((input == 0) || (output == 0) || (size == 0)) - { + if ((input == 0) || (output == 0) || (size == 0)) { return -1; } - input_v_ptr = (HVX_Vector *) input; - output_v_ptr = (HVX_UVector *) output; + input_v_ptr = (HVX_Vector *)input; + output_v_ptr = (HVX_UVector *)output; f8 = Q6_V_vsplat_R(float_to_bits(8.0f)); f_8 = Q6_V_vsplat_R(float_to_bits(-8.0f)); @@ -267,23 +402,20 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size * Handle number of whole vectors in input data. * Don't process last vector in order to avoid out-of-boundary load. */ - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(input_v_ptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } /* Process one vector at a time */ - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { slinec = *input_v_ptr++; /* Compose vector of input data from slinec and slinep */ - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); + sline = Q6_V_valign_VVR(slinec, slinep, (size_t)input); /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ input_shifted_v_qf32 = Q6_Vqf32_vsub_VsfVsf(sline, input_min_v_f); @@ -340,7 +472,7 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size // x * sigmod output_v = Q6_Vqf32_vmpy_Vqf32Vqf32(input_v_qf32, output_v); - + HVX_Vector out_v = Q6_Vsf_equals_Vqf32(output_v); HVX_VectorPred islf8 = Q6_Q_vcmp_gt_VsfVsf(sline, f8); @@ -349,7 +481,6 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size HVX_VectorPred islf_8 = Q6_Q_vcmp_gt_VsfVsf(f_8, sline); out_v = Q6_V_vmux_QVV(islf_8, zero_v_sf, out_v); - /* Store results to the output buffer and convert from qf32 to sf */ *((HVX_UVector *)(output_v_ptr++)) = out_v; @@ -359,10 +490,9 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size } /* Handle last whole vector from input data */ - if (vectors_in_rounddown > 0) - { + if (vectors_in_rounddown > 0) { slinec = is_aligned(input_v_ptr, VLEN) && leftover == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); + sline = Q6_V_valign_VVR(slinec, slinep, (size_t)input); /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ input_shifted_v_qf32 = Q6_Vqf32_vsub_VsfVsf(sline, input_min_v_f); @@ -433,13 +563,10 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size } /* Handle leftover elements */ - if (leftover > 0) - { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) - ? slinep - : *input_v_ptr++); + if (leftover > 0) { + slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++); - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); + sline = Q6_V_valign_VVR(slinec, slinep, (size_t)input); /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ input_shifted_v_qf32 = Q6_Vqf32_vsub_VsfVsf(sline, input_min_v_f); @@ -510,41 +637,180 @@ int32_t hvx_silu_af(float *restrict input, float *restrict output, uint32_t size return 0; } - static const float fp16_c0_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.13239719960243818,0.2216255210749415,0.3447664743728659,0.48137452032585476,0.5716299228719798,0.5547323231605259,0.5046287748870234,0.4999985574626892, -0.5000036514755082,0.49475652448004626,0.4441393352532763,0.428500379952032,0.5173297285470642,0.6541461039833616,0.7783931007462818,0.8678015179911097, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.13239719960243818, + 0.2216255210749415, + 0.3447664743728659, + 0.48137452032585476, + 0.5716299228719798, + 0.5547323231605259, + 0.5046287748870234, + 0.4999985574626892, + 0.5000036514755082, + 0.49475652448004626, + 0.4441393352532763, + 0.428500379952032, + 0.5173297285470642, + 0.6541461039833616, + 0.7783931007462818, + 0.8678015179911097, }; static const float fp16_c1_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.05928005756790343,0.11063222460270064,0.1932879057003057,0.30302440212086995,0.3922924462181049,0.36546332659415875,0.2644148210990377,0.24989020912329707, -0.2498532691910313,0.2661055781198988,0.36728015359480604,0.39215270010450015,0.3041825601732039,0.1940762094668647,0.11061794856987572,0.059174800917353595, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.05928005756790343, + 0.11063222460270064, + 0.1932879057003057, + 0.30302440212086995, + 0.3922924462181049, + 0.36546332659415875, + 0.2644148210990377, + 0.24989020912329707, + 0.2498532691910313, + 0.2661055781198988, + 0.36728015359480604, + 0.39215270010450015, + 0.3041825601732039, + 0.1940762094668647, + 0.11061794856987572, + 0.059174800917353595, }; static const float fp16_c2_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.010145494303219278,0.02123968384425681,0.04207468332514667,0.07519946712591977,0.10840620196267145,0.09270738184406795,0.015322371881818012,-0.0009948273994921822, -0.0011544907060402412,-0.017040517565094934,-0.09379878876657094,-0.10835043868732394,-0.07558705272699548,-0.04228875316413285,-0.021235740718738055,-0.010124599879590107, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.010145494303219278, + 0.02123968384425681, + 0.04207468332514667, + 0.07519946712591977, + 0.10840620196267145, + 0.09270738184406795, + 0.015322371881818012, + -0.0009948273994921822, + 0.0011544907060402412, + -0.017040517565094934, + -0.09379878876657094, + -0.10835043868732394, + -0.07558705272699548, + -0.04228875316413285, + -0.021235740718738055, + -0.010124599879590107, }; static const float fp16_c3_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0007841223015974933,0.001850453397354219,0.004187899308371771,0.008640952434084206,0.01414741414964877,0.010117749275618,-0.01654848996354919,-0.02395108399453624, --0.024199111971064446,-0.015783556879607072,0.010407672131558174,0.014137608186323335,0.008698510795258909,0.004213708431213342,0.0018499827774393985,0.0007822799742289481, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0007841223015974933, + 0.001850453397354219, + 0.004187899308371771, + 0.008640952434084206, + 0.01414741414964877, + 0.010117749275618, + -0.01654848996354919, + -0.02395108399453624, + -0.024199111971064446, + -0.015783556879607072, + 0.010407672131558174, + 0.014137608186323335, + 0.008698510795258909, + 0.004213708431213342, + 0.0018499827774393985, + 0.0007822799742289481, }; static const float fp16_c4_coeffs[32] __attribute__((aligned(VLEN))) = -{ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -2.3031641204975905e-05,6.150442488966733e-05,0.00015997783736818624,0.00038491646239693526,0.0007283649599237781,0.00034439150914392054,-0.003142246198646662,-0.004120389580321761, -0.004246050162553198,0.0030162727520777893,-0.00037312974308425725,-0.0007277242855014247,-0.00038811687679772674,-0.0001611434776868886,-6.14837984586862e-05,-2.297076123375133e-05, + { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 2.3031641204975905e-05, + 6.150442488966733e-05, + 0.00015997783736818624, + 0.00038491646239693526, + 0.0007283649599237781, + 0.00034439150914392054, + -0.003142246198646662, + -0.004120389580321761, + 0.004246050162553198, + 0.0030162727520777893, + -0.00037312974308425725, + -0.0007277242855014247, + -0.00038811687679772674, + -0.0001611434776868886, + -6.14837984586862e-05, + -2.297076123375133e-05, }; /** @@ -554,8 +820,7 @@ static const float fp16_c4_coeffs[32] __attribute__((aligned(VLEN))) = * @param[in] length Number of elements in input/output arrays. * @return Returns 0 on successful execution. Otherwise -1. */ -int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t size) -{ +int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t size) { HVX_Vector *input_v_ptr; HVX_UVector *output_v_ptr; HVX_Vector input_min_v_hf; @@ -594,13 +859,12 @@ int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t s HVX_Vector c4_coeff_v; /* Check input arguments. Return error status if some argument has invalid value */ - if ((input == 0) || (output == 0) || (size == 0)) - { + if ((input == 0) || (output == 0) || (size == 0)) { return -1; } - input_v_ptr = (HVX_Vector *) input; - output_v_ptr = (HVX_UVector *) output; + input_v_ptr = (HVX_Vector *)input; + output_v_ptr = (HVX_UVector *)output; /* * If input data is not aligned to HVX vector size, compose aligned vectors @@ -671,19 +935,16 @@ int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t s * Handle number of whole vectors in input data. * Don't process last vector in order to avoid out-of-boundary load. */ - for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) - { + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { block = Q6_R_min_RR(i, BLOCK_SIZE); l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); - if (l2fetch_block > 0) - { + if (l2fetch_block > 0) { l2fetch(input_v_ptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); } /* Process one vector at a time */ - for (int32_t j = 0; j < block; ++j) - { + for (int32_t j = 0; j < block; ++j) { slinec = *input_v_ptr++; /* Compose vector of input data from slinec and slinep */ @@ -776,17 +1037,15 @@ int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t s // output_v = Q6_Vqf16_vmpy_Vqf16Vqf16(output_v, input_v_qf16); // *output_v_ptr++ = Q6_Vhf_equals_Vqf16(output_v); - /* Prepare slinep for next iteration */ slinep = slinec; } } /* Handle last whole vector from input data */ - if (vectors_in_rounddown > 0) - { + if (vectors_in_rounddown > 0) { slinec = is_aligned(input_v_ptr, VLEN) && leftover == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); + sline = Q6_V_valign_VVR(slinec, slinep, (size_t)input); tmp_v = Q6_Vh_vdeal_Vh(sline); /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ input_shifted_v_hf = Q6_Vqf16_vsub_VhfVhf(tmp_v, input_min_v_hf); @@ -862,11 +1121,8 @@ int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t s } /* Handle leftover elements */ - if (leftover > 0) - { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) - ? slinep - : *input_v_ptr++); + if (leftover > 0) { + slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++); sline = Q6_V_valign_VVR(slinec, slinep, (size_t)input); tmp_v = Q6_Vh_vdeal_Vh(sline); @@ -949,124 +1205,110 @@ int32_t hvx_silu_ahf(__fp16 *restrict input, __fp16 *restrict output, uint32_t s #endif -template -GraphStatus siluImpl(TensorType& out_0, - const TensorType& in_0) +template +GraphStatus siluImpl(TensorType &out_0, + const TensorType &in_0) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ #ifdef REFERENCE_OP - debuglog("silu execute... inval=(%d)", in_0.get_dtype()); - debuglog("silu execute... inval=(%d)", out_0.get_dtype()); - + debuglog("silu execute... inval=(%d)", in_0.get_dtype()); + debuglog("silu execute... inval=(%d)", out_0.get_dtype()); + out_0.set_dims(in_0); // NHWC auto [b_in, h_in, w_in, d_in] = in_0.dims(); for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // SiLU - for (Idx d = 0; d < d_in; d++) { - float inval = in_0(b, h, w, d); - float outval = 1 / (1 + expf(-inval)); - - - debuglog("silu execute... inval=(%f)", inval); - debuglog("silu execute... outval=(%f)", outval); - - out_0(b, h, w, d) = inval * outval; - - } + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // SiLU + for (Idx d = 0; d < d_in; d++) { + float inval = in_0(b, h, w, d); + float outval = 1 / (1 + expf(-inval)); + + debuglog("silu execute... inval=(%f)", inval); + debuglog("silu execute... outval=(%f)", outval); + + out_0(b, h, w, d) = inval * outval; + } + } } - } } #else // HVX Method -- FP32 Version out_0.set_dims(in_0); - + DType dtype = in_0.get_dtype(); auto [b_in, h_in, w_in, d_in] = in_0.dims(); + size_t size = b_in * h_in * w_in * d_in; - size_t size = b_in*h_in*w_in*d_in; - // Noticable size >= 128 - + // SiLU inval / (1 + expf(-inval)); // sigmod 1.0/(exp(-x)+1.0) // SiLU inval * sigmod if (dtype == DType::Float16) { - - // NHWC - auto in_ptr = (__fp16*)in_0.raw_data_const(); - auto out_ptr = (__fp16*)out_0.raw_data(); - hvx_silu_ahf(in_ptr, out_ptr, size); + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + hvx_silu_ahf(in_ptr, out_ptr, size); } else { - // NHWC - auto in_ptr = (float*)in_0.raw_data_const(); - auto out_ptr = (float*)out_0.raw_data(); - hvx_silu_af(in_ptr, out_ptr, size); + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + hvx_silu_af(in_ptr, out_ptr, size); } return GraphStatus::Success; - - #endif #ifdef DEBUG for (Idx b = 0; b < b_in; b++) { - for (Idx h = 0; h < h_in; h++) { - for (Idx w = 0; w < w_in; w++) { - // SiLU - for (Idx d = 0; d < d_in; d++) { - float out_value = out_0(b, h, w, d); - debuglog("silu execute... outval=(%f)", out_value); - - } + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + // SiLU + for (Idx d = 0; d < d_in; d++) { + float out_value = out_0(b, h, w, d); + debuglog("silu execute... outval=(%f)", out_value); + } + } } - } } - -#endif - +#endif - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float siluCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float siluCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SplitInput.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SplitInput.cpp index f055afc51..b33decb01 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SplitInput.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/SplitInput.cpp @@ -9,17 +9,15 @@ #include "QnnOpPackage.h" #include "HTP/core/simple_reg.h" - BEGIN_PKG_OP_DEFINITION(PKG_SplitInput); - // op execute function declarations -template -GraphStatus splitinputImpl(TensorType& out_0, - TensorType& out_1, - const TensorType& in_0, +template +GraphStatus splitinputImpl(TensorType &out_0, + TensorType &out_1, + const TensorType &in_0, const TensorType1 &in_1, - const Tensor& num); + const Tensor &num); // forward declaration of sample cost function static float splitinputCostFunc(const Op *op); @@ -64,11 +62,11 @@ DEF_PACKAGE_OP((splitinputImpl), "SplitInput") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -82,84 +80,74 @@ DEF_PACKAGE_OP((splitinputImpl), "SplitInput") * Qnn_addNode */ - /* execute functions for ops */ -template -GraphStatus splitinputImpl(TensorType& out_0, - TensorType& out_1, - const TensorType& in_0, +template +GraphStatus splitinputImpl(TensorType &out_0, + TensorType &out_1, + const TensorType &in_0, const TensorType1 &in_1, - const Tensor& num) -{ - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - // default is two. - - size_t o_size = in_1(0,0,0,0); - size_t x_size = in_1(0,0,0,1); + const Tensor &num) { + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - const size_t dims_0[] = {b_in, o_size, w_in, d_in}; - const size_t dims_1[] = {b_in, x_size, w_in, d_in}; + // default is two. - out_0.set_dims(dims_0); - out_1.set_dims(dims_1); + size_t o_size = in_1(0, 0, 0, 0); + size_t x_size = in_1(0, 0, 0, 1); - DType dtype = in_0.get_dtype(); - uint32_t bitwidth = 4; + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - if (dtype == DType::QUInt8 || dtype == DType::QInt8) { + const size_t dims_0[] = {b_in, o_size, w_in, d_in}; + const size_t dims_1[] = {b_in, x_size, w_in, d_in}; - bitwidth = 1; + out_0.set_dims(dims_0); + out_1.set_dims(dims_1); - } else if (dtype == DType::Float16) { + DType dtype = in_0.get_dtype(); + uint32_t bitwidth = 4; - bitwidth = 2; - } else if (dtype == DType::Float32) { + if (dtype == DType::QUInt8 || dtype == DType::QInt8) { + bitwidth = 1; - bitwidth = 4; - } + } else if (dtype == DType::Float16) { + bitwidth = 2; + } else if (dtype == DType::Float32) { + bitwidth = 4; + } - const uint8_t *in_ptr = (uint8_t*)in_0.raw_data_const(); + const uint8_t *in_ptr = (uint8_t *)in_0.raw_data_const(); - uint8_t *out_ptr_0 = (uint8_t*)out_0.raw_data(); - uint8_t *out_ptr_1 = (uint8_t*)out_1.raw_data(); + uint8_t *out_ptr_0 = (uint8_t *)out_0.raw_data(); + uint8_t *out_ptr_1 = (uint8_t *)out_1.raw_data(); - memcpy(out_ptr_0, in_ptr, b_in * o_size * w_in * d_in * bitwidth); - in_ptr += b_in * o_size * w_in * d_in * bitwidth; + memcpy(out_ptr_0, in_ptr, b_in * o_size * w_in * d_in * bitwidth); + in_ptr += b_in * o_size * w_in * d_in * bitwidth; - memcpy(out_ptr_1, in_ptr, b_in * x_size * w_in * d_in * bitwidth * 4); + memcpy(out_ptr_1, in_ptr, b_in * x_size * w_in * d_in * bitwidth * 4); - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float splitinputCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float splitinputCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/WNop.cpp b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/WNop.cpp index 547e53589..2a7c1fd1a 100755 --- a/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/WNop.cpp +++ b/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/WNop.cpp @@ -12,17 +12,15 @@ #include #include - BEGIN_PKG_OP_DEFINITION(PKG_WNop); - // op execute function declarations -template -GraphStatus wnopImpl(TensorType& out_0, - TensorType1 &sync_var, - const TensorType& in_0, - const TensorType& in_1, - const Tensor& sync_type); +template +GraphStatus wnopImpl(TensorType &out_0, + TensorType1 &sync_var, + const TensorType &in_0, + const TensorType &in_1, + const Tensor &sync_type); // forward declaration of sample cost function static float wnopCostFunc(const Op *op); @@ -67,11 +65,11 @@ DEF_PACKAGE_OP((wnopImpl), "WNop") * one definition per op, and this is optional * syntax: DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) * one or more parameters can be specified for each op - * order of parameters listed determines the order of parameters passed into op execution functions + * order of parameters listed determines the order of parameters passed into op execution functions * if an op does not have a parameter order definition, parameter order passed into Qnn_addNode * will be passed into op execution functions * if an op has a parameter order definition, any parameter passed into Qnn_addNode with unlisted - * name will be abandoned + * name will be abandoned * if two or more op packages with the same package name will be registered, they cannot list * conflicting parameter orders * PARAM refers to parameter name as a string literal @@ -84,109 +82,88 @@ DEF_PACKAGE_OP((wnopImpl), "WNop") * graph construction will skip this parameter when this parameter is not provided at * Qnn_addNode */ -DEF_PACKAGE_PARAM_ORDER("WNop", +DEF_PACKAGE_PARAM_ORDER("WNop", "sync_type", true, nullptr) - /* execute functions for ops */ -template -GraphStatus wnopImpl(TensorType& out_0, - TensorType1 &sync_var, - const TensorType& in_0, - const TensorType& in_1, - const Tensor& sync_type) +template +GraphStatus wnopImpl(TensorType &out_0, + TensorType1 &sync_var, + const TensorType &in_0, + const TensorType &in_1, + const Tensor &sync_type) { - /* - * add code here - * */ - /* - * To have good performance and stability, it is required to avoid heap memory - * allocation in this function. The heap memory allocation includes but not - * limited to calling malloc, operator new, constructing STL container objects - * like std::vector with default allocator, and adding items like calling - * std::vector::push_back to STL container objects with default allocator. - * - * Please check in SDK documentation for more information. - */ - - - out_0.set_dims(in_0); - - auto sync_type_ = sync_type(0,0,0,0); - - - // sync_type == 0 sending signal to CPU - // sync_type == 1 waiting signal from CPU - - DType dtype = in_0.get_dtype(); - uint32_t bitwidth = 4; - - if (dtype == DType::QUInt8) { - + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + out_0.set_dims(in_0); + + auto sync_type_ = sync_type(0, 0, 0, 0); + + // sync_type == 0 sending signal to CPU + // sync_type == 1 waiting signal from CPU + + DType dtype = in_0.get_dtype(); + uint32_t bitwidth = 4; + + if (dtype == DType::QUInt8) { bitwidth = 1; } else if (dtype == DType::Float16) { - bitwidth = 2; } else if (dtype == DType::Float32) { - bitwidth = 4; } - if (sync_type_ == 0) { - - auto [b_in, h_in, w_in, d_in] = in_0.dims(); + if (sync_type_ == 0) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - auto in_ptr = (void*)in_0.raw_data_const(); - auto out_ptr = (void*)out_0.raw_data(); + auto in_ptr = (void *)in_0.raw_data_const(); + auto out_ptr = (void *)out_0.raw_data(); - memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth); + memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth); - sync_var(0,0,0,0) = 1; + sync_var(0, 0, 0, 0) = 1; - } else if (sync_type_ == 1) { + } else if (sync_type_ == 1) { + while (in_1(0, 0, 0, 0) == 0) { + Q6_V_vzero(); + } - while (in_1(0,0,0,0) == 0) { + auto [b_in, h_in, w_in, d_in] = in_0.dims(); - Q6_V_vzero(); + auto in_ptr = (void *)in_0.raw_data_const(); + auto out_ptr = (void *)out_0.raw_data(); + memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth); } - auto [b_in, h_in, w_in, d_in] = in_0.dims(); - - auto in_ptr = (void*)in_0.raw_data_const(); - auto out_ptr = (void*)out_0.raw_data(); - - memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth); - - } - - - - - - - return GraphStatus::Success; + return GraphStatus::Success; } -__attribute__((unused)) static float wnopCostFunc(const Op *op) -{ - /* - * add code here - * */ +__attribute__((unused)) static float wnopCostFunc(const Op *op) { + /* + * add code here + * */ - float cost = 0.0; // add cost computation here - return cost; + float cost = 0.0; // add cost computation here + return cost; } - - - - /* At the bottom of the op file, call END_PKG_OP_DEFINITION(), where is as BEGIN_PKG_OP_DEFINITION */ diff --git a/src/backends/qnn/Model/QnnModel.cpp b/src/backends/qnn/Model/QnnModel.cpp index 32adae185..7a2d91a68 100644 --- a/src/backends/qnn/Model/QnnModel.cpp +++ b/src/backends/qnn/Model/QnnModel.cpp @@ -14,6 +14,7 @@ #include "QnnModel.hpp" #include "QnnModelPal.hpp" #include "QnnTypeMacros.hpp" +#include "Utils/QnnSampleAppUtils.hpp" #define FREE_MEMORY(ptr1, ptr2, ptr3) \ do { \ @@ -72,6 +73,10 @@ ModelError_t QnnModel::initialize(const Qnn_BackendHandle_t &backendHandle, return MODEL_NO_ERROR; } +void QnnModel::setInitFromCache() { + isFromCache = true; +} + ModelError_t QnnModel::addTensor(const char *nodeName, Qnn_Tensor_t *tensor, bool saveTensor) { ModelError_t err; if (!tensor) { @@ -152,16 +157,20 @@ ModelError_t QnnModel::addTensor(const char *nodeName, Qnn_Tensor_t *tensor, boo QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_TYPE_APP_READ); } - if (m_qnnInterface.tensorCreateGraphTensor(m_graph, tensor) != QNN_TENSOR_NO_ERROR) { - PRINT_ERROR("QnnModel::addTensor() Creating tensor for node: %s, tensorName: %s.\n", - nodeName, - QNN_TENSOR_GET_NAME(tensor)); - return MODEL_TENSOR_ERROR; + if (!isFromCache) { + if (m_qnnInterface.tensorCreateGraphTensor(m_graph, tensor) != QNN_TENSOR_NO_ERROR) { + PRINT_ERROR("QnnModel::addTensor() Creating tensor for node: %s, tensorName: %s.\n", + nodeName, + QNN_TENSOR_GET_NAME(tensor)); + return MODEL_TENSOR_ERROR; + } } if (saveTensor) { Qnn_Tensor_t tensorCopy; - VALIDATE(deepCopyQnnTensors(*tensor, tensorCopy), err); + if (!qnn::tools::sample_app::deepCopyQnnTensorInfo(&tensorCopy, tensor)) { + return MODEL_TENSOR_ERROR; + } // save network input/outputs tensors to use for setting the Qnn graph's input and output // tensors for populating GraphInfo_t for caller @@ -509,6 +518,16 @@ ModelError_t QnnModel::finalize(Qnn_ProfileHandle_t profile, Qnn_SignalHandle_t return err; } +size_t memscpy(void *dst, size_t dstSize, const void *src, size_t copySize) { + if (!dst || !src || !dstSize || !copySize) return 0; + + size_t minSize = dstSize < copySize ? dstSize : copySize; + + memcpy(dst, src, minSize); + + return minSize; +} + ModelError_t getGraphInfoFromModels(QnnModel *models, uint32_t numModels, GraphInfoPtr_t **graphsInfo) { @@ -611,25 +630,6 @@ ModelError_t getSingleGraphInfoFromModel(QnnModel &model, GraphInfoPtr_t* graphI return err; } -ModelError_t freeGraphsInfo(GraphInfoPtr_t **graphsInfo, uint32_t numGraphs) { - if (graphsInfo == nullptr || *graphsInfo == nullptr) { - PRINT_ERROR("freeGraphsInfo() invalid graphsInfo."); - return MODEL_TENSOR_ERROR; - } - for (uint32_t i = 0; i < numGraphs; i++) { - PRINT_INFO("Freeing graph in freeGraphInfo"); - free((*graphsInfo)[i]->graphName); - freeQnnTensors((*graphsInfo)[i]->inputTensors, (*graphsInfo)[i]->numInputTensors); - freeQnnTensors((*graphsInfo)[i]->outputTensors, (*graphsInfo)[i]->numOutputTensors); - } - - free(**graphsInfo); - free(*graphsInfo); - *graphsInfo = nullptr; - - return MODEL_NO_ERROR; -} - ModelError_t QnnModel::freeTensors() { for (std::map::iterator tensorIt = m_modelTensorsMap.begin(); diff --git a/src/backends/qnn/Model/QnnModel.hpp b/src/backends/qnn/Model/QnnModel.hpp index 6521d76e5..1823c49c8 100644 --- a/src/backends/qnn/Model/QnnModel.hpp +++ b/src/backends/qnn/Model/QnnModel.hpp @@ -55,6 +55,8 @@ class QnnModel { uint8_t doNodeValidations = 1, const QnnGraph_Config_t** graphConfigs = nullptr); + void setInitFromCache(); + /** * @brief A wrapper function to create a tensor inside class's context graph. * @@ -226,7 +228,7 @@ class QnnModel { bool m_debug = false; // flag to indicate if requested graph is to be run in debug mode // (i.e. all intermediate tensors will be accessible to client) // flag to indicate whether all addNode calls need to be validated - bool m_doNodeValidations = true; + bool m_doNodeValidations = true, isFromCache = false; std::vector m_modelInputTensors; std::vector m_modelOutputTensors; @@ -266,5 +268,4 @@ ModelError_t getSingleGraphInfoFromModel(QnnModel &model, GraphInfoPtr_t* graphI * @return Error code * */ -ModelError_t freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs); } // namespace qnn_wrapper_api diff --git a/src/backends/qnn/QNNBackend.cpp b/src/backends/qnn/QNNBackend.cpp index 69b112831..793a21b06 100755 --- a/src/backends/qnn/QNNBackend.cpp +++ b/src/backends/qnn/QNNBackend.cpp @@ -20,6 +20,7 @@ #include "QnnTypes.h" #include "HTP/QnnHtpGraph.h" #include "Layer.hpp" +#include "Layer.hpp" #include "Types.hpp" #include "op/QNNAdd.hpp" @@ -34,6 +35,8 @@ #include "op/QNNScale.hpp" #include "op/QNNSiLU.hpp" #include "op/QNNSoftMax.hpp" +#include "op/QNNSubGraphFinalize.hpp" +#include "op/QNNSubGraphStart.hpp" #include "op/QNNView.hpp" #include "op/QNNReLU.hpp" #include "op/QNNQuantize.hpp" @@ -86,6 +89,8 @@ void QNNBackend::registerOps() { addCreator(SPLITINPUT, (QNNBackend::Creator *)(new QNNSplitInputCreator())); addCreator(TRANSPOSE, (QNNBackend::Creator *)(new QNNTransposeCreator())); addCreator(SUPERSILU, (QNNBackend::Creator *)(new QNNSuperSiLUCreator())); + addCreator(SUBGRAPHSTART, (QNNBackend::Creator *)(new QNNSubGraphStartCreator())); + addCreator(SUBGRAPHFINALIZE, (QNNBackend::Creator *)(new QNNSubGraphFinalizeCreator())); } QNNBackend::QNNBackend(shared_ptr mm) : @@ -185,6 +190,23 @@ QNNBackend::QNNBackend(shared_ptr mm) : // register ops this->registerOps(); + + // check if the qnn_context.bin file exists + if (!std::filesystem::exists("qnn_context.bin")) { + // create qnn context + if (StatusCode::SUCCESS != this->createContext()) { + this->reportError("Context Creation failure"); + } + } else { + if (StatusCode::SUCCESS != this->retrieveQNNContext()) { + this->reportError("Context Retieve failure"); + } + } + // assign context to qnn memory manager +#ifdef QNN_ARM + auto qnnMM = std::static_pointer_cast(mem_manager_); + qnnMM->setQnnInterfaceAndContext(m_context); +#endif } QNNBackend::~QNNBackend() { @@ -215,14 +237,6 @@ void QNNBackend::onSetUpStart(vector> &inputs, vectorcreateContext()) { - this->reportError("Context Creation failure"); - } -#ifdef QNN_ARM - auto qnnMM = std::static_pointer_cast(mem_manager_); - qnnMM->setQnnInterfaceAndContext(m_context); -#endif // initialize qnn graph info, set graph info, graph count // NOTE: currently not using it @@ -240,21 +254,22 @@ void QNNBackend::onSetUpStart(vector> &inputs, vectorreportError("Graph Config Info failure"); + if (!isFromCache) { + err = qnnModels_[qnnModelIndex_].initialize(m_backendHandle, + m_qnnFunctionPointers.qnnInterface, + m_context, + graphName.c_str(), + m_debug, + DO_GRAPH_NODE_VALIDATIONS, + graphConfigs); + } else { + // set init from cache, the input and output tensor info still needs the QnnModel to maintain + // setting this is to avoid the tensor creation in the qnn graph + qnnModels_[qnnModelIndex_].setInitFromCache(); } - err = qnnModels_[qnnModelIndex_].initialize(m_backendHandle, - m_qnnFunctionPointers.qnnInterface, - m_context, - graphName.c_str(), - m_debug, - DO_GRAPH_NODE_VALIDATIONS, - graphConfigs); if (err != qnn_wrapper_api::MODEL_NO_ERROR) { this->reportError("Graph Initialization failure: " + graphName); } @@ -281,6 +296,9 @@ void QNNBackend::onSetUpStart(vector> &inputs, vector> &inputs, vectorload(&scaleTensor); - scale = roundf(scaleTensor.hostPtr()[0] / 127.0 * 100000) / 100000; + // scale = roundf(scaleTensor.hostPtr()[0] / (pow(2, 7) - 1) * 100000) / 100000; + scale = scaleTensor.hostPtr()[0] / (pow(2, 7) - 1); + scaleTensor.free(); + + break; + } + case MLLM_TYPE_I16: { + data_type = QNN_DATATYPE_SFIXED_POINT_16; + quantizeDefined = QNN_DEFINITION_DEFINED; + quantizeType = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + + string scaleName = input->name(); + + std::string wordToRemove = "outtensor-"; + int pos = scaleName.find(wordToRemove); + if (pos != -1) { // old frontend merge/split generated tensor + scaleName = scaleName.substr(wordToRemove.length()); + wordToRemove = "or_split"; + if (scaleName.find(wordToRemove) != -1) { + pos = scaleName.find("or_split"); + // scaleName.erase(pos, wordToRemove.length()); + scaleName = scaleName.substr(0, pos); + // o + scaleName += "o_proj.input_scale"; + } else if (scaleName.find("ires_split") != -1) { + pos = scaleName.find("ires_split"); + wordToRemove = "ires_split"; + // scaleName.erase(pos, wordToRemove.length()); + scaleName = scaleName.substr(0, pos); + // q + scaleName += "q_proj.input_scale"; + } else if (scaleName.find("fres_split") != -1) { + pos = scaleName.find("fres_split"); + wordToRemove = "fres_split"; + // scaleName.erase(pos, wordToRemove.length()); + scaleName = scaleName.substr(0, pos); + // fc1 + scaleName += "up_proj.input_scale"; + } + } else { // new frontend no merge/split condition + std::string prefix = "out-", suffix = ".quantize"; + if (input->name().find(prefix) != std::string::npos) { + scaleName = input->name().substr(prefix.length()); + } + if (scaleName.find(suffix) != std::string::npos) { + scaleName = scaleName.substr(0, scaleName.length() - suffix.length()); + } + scaleName += ".input_scale"; + } + scaleTensor.setName(scaleName); + loader->load(&scaleTensor); + // scale = roundf(scaleTensor.hostPtr()[0] / (pow(2, 15) - 1) * 100000) / 100000; + scale = scaleTensor.hostPtr()[0] / (pow(2, 15) - 1); scaleTensor.free(); break; @@ -375,15 +445,37 @@ void QNNBackend::onSetUpStart(vector> &inputs, vectorgraph, m_profileBackendHandle, nullptr)) { + return qnn_wrapper_api::ModelError_t::MODEL_GRAPH_ERROR; + } + if (ProfilingLevel::OFF != m_profilingLevel) { + extractBackendProfilingInfo(m_profileBackendHandle); + } + graphsInfo_.push_back(graphInfo); + + return qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR; +} + void QNNBackend::onSetUpEnd(vector> &inputs, vector> &outputs, string graphName) { - currentInputBuffers = &inputBufferMap[graphName]; - currentOutputBuffers = &outputBufferMap[graphName]; - qnnModelIndex_ = qnnModelIndexMap_[graphName]; - PRINT_MEMORY_USAGE("before graph finilize") - auto status = graphFinilize(); - PRINT_MEMORY_USAGE("after graph finilize") - if (qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR != status) { - this->reportError("Graph Finalization failure"); + // currentInputBuffers = &inputBufferMap[graphName]; + // currentOutputBuffers = &outputBufferMap[graphName]; + // qnnModelIndex_ = qnnModelIndexMap_[graphName]; + + // online graph building, finalize graph + if (!isFromCache) { + PRINT_MEMORY_USAGE("before graph finilize") + auto status = graphFinilize(); + PRINT_MEMORY_USAGE("after graph finilize") + if (qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR != status) { + this->reportError("Graph Finalization failure"); + } } auto returnStatus = StatusCode::SUCCESS; @@ -391,7 +483,7 @@ void QNNBackend::onSetUpEnd(vector> &inputs, vector> &inputs, vectornumInputTensors; i++) { qnnMM->registerQnnTensor((*currentInputBuffers)[i], qnnInputs[i]); #ifdef DEBUGPRINT - if (i < inputs.size()) { - std::cout << "\nregistered input tensor: " << inputs[i]->hostPtr() << " backend staged ptr: " << (void *)(*currentInputBuffers)[i] << std::endl; - } else { - std::cout << "\n registered op added input" << std::endl; - } + std::cout << "\nregistered input tensor backend staged ptr: " << (void *)(*currentInputBuffers)[i] << std::endl; std::cout << "qnn input tensor name: " << qnnInputs[i].v1.name << std::endl; std::cout << "qnn input tensor scale: " << qnnInputs[i].v1.quantizeParams.scaleOffsetEncoding.scale << std::endl; #endif @@ -424,18 +512,14 @@ void QNNBackend::onSetUpEnd(vector> &inputs, vectornumOutputTensors; i++) { qnnMM->registerQnnTensor((*currentOutputBuffers)[i], qnnOutputs[i]); #ifdef DEBUGPRINT - if (i < outputs.size()) { - std::cout << "\nregistered output tensor: " << outputs[i]->hostPtr() << " backend staged ptr: " << (void *)(*currentOutputBuffers)[i] << std::endl; - } else { - std::cout << "\n registered op added output" << std::endl; - } + std::cout << "\nregistered output tensor backend staged ptr: " << (void *)(*currentOutputBuffers)[i] << std::endl; std::cout << "qnn output tensor name: " << qnnOutputs[i].v1.name << std::endl; std::cout << "qnn output tensor scale: " << qnnOutputs[i].v1.quantizeParams.scaleOffsetEncoding.scale << std::endl; #endif } - inputsMap_[qnnModelIndex_] = qnnInputs; - outputsMap_[qnnModelIndex_] = qnnOutputs; + graphInfo->inputTensors = qnnInputs; + graphInfo->outputTensors = qnnOutputs; } void QNNBackend::onExecuteStart(vector> &inputs, vector> &outputs, string graphName) { @@ -443,10 +527,12 @@ void QNNBackend::onExecuteStart(vector> &inputs, vectorinputTensors; + // Qnn_Tensor_t *outputs_ = outputsMap_[t_qnnModelIndex_]; + Qnn_Tensor_t *outputs_ = graphInfo->outputTensors; Qnn_ErrorHandle_t executeStatus = QNN_GRAPH_NO_ERROR; #ifdef DEBUGPRINT @@ -502,9 +588,7 @@ void QNNBackend::afterAllGraphsExecute() { inputBufferMap.clear(); outputBufferMap.clear(); - graphInfoMap_.clear(); - inputsMap_.clear(); - outputsMap_.clear(); + graphsInfo_.clear(); } std::string QNNBackend::getBackendBuildId() { @@ -521,6 +605,14 @@ qnn_wrapper_api::ModelError_t QNNBackend::graphAddNode(string name, std::vector outputTensors, std::vector params, string packageName) { + if (isFromCache) { + for (auto &qnnTensor : outputTensors) { + if (qnnTensor.v1.type == QNN_TENSOR_TYPE_APP_READ) { + qnnModels_[qnnModelIndex_].addTensor(qnnTensor.v1.name, qnnTensor); + } + } + return qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR; + } qnn_wrapper_api::ModelError_t err = qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR; Qnn_Param_t *paramsPtr = nullptr; if (!params.empty()) { @@ -542,27 +634,10 @@ qnn_wrapper_api::ModelError_t QNNBackend::graphAddNode(string name, return err; } -qnn_wrapper_api::ModelError_t QNNBackend::graphFinilize() { - // Populate the constructed graphs in provided output variables - qnn_wrapper_api::ModelError_t err = qnn_wrapper_api::MODEL_NO_ERROR; - qnn_wrapper_api::GraphInfo_t *graphInfo = nullptr; - - VALIDATE(getSingleGraphInfoFromModel(qnnModels_[qnnModelIndex_], &graphInfo), err); - - // Graph finalize - if (QNN_GRAPH_NO_ERROR != m_qnnFunctionPointers.qnnInterface.graphFinalize(graphInfo->graph, m_profileBackendHandle, nullptr)) { - return qnn_wrapper_api::ModelError_t::MODEL_GRAPH_ERROR; - } - if (ProfilingLevel::OFF != m_profilingLevel) { - extractBackendProfilingInfo(m_profileBackendHandle); - } - - graphInfoMap_[qnnModelIndex_] = graphInfo; - - return qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR; -} - qnn_wrapper_api::ModelError_t QNNBackend::modelAddTensor(std::string nodeName, Qnn_Tensor_t tensor) { + if (isFromCache && tensor.v1.type != QNN_TENSOR_TYPE_APP_READ) { + return qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR; + } return qnnModels_[qnnModelIndex_].addTensor(nodeName.c_str(), tensor); } @@ -766,6 +841,85 @@ StatusCode QNNBackend::freeDevice() { return StatusCode::SUCCESS; } +void QNNBackend::saveQNNContext() { + uint64_t binarySize, writtenSize; + m_qnnFunctionPointers.qnnInterface.contextGetBinarySize(m_context, &binarySize); + + std::unique_ptr binaryBuffer(new uint8_t[binarySize]); + + m_qnnFunctionPointers.qnnInterface.contextGetBinary(m_context, reinterpret_cast(binaryBuffer.get()), binarySize, &writtenSize); + + if (binarySize < writtenSize) { + QNN_ERROR( + "Illegal written buffer size [%d] bytes. Cannot exceed allocated memory of [%d] bytes", + binarySize, + writtenSize); + } + std::ofstream file("qnn_context.bin", std::ios::binary); + file.write(reinterpret_cast(binaryBuffer.get()), writtenSize); + file.close(); + + std::cout << "QNN context saved to qnn_context.bin written " << writtenSize << std::endl; +} + +StatusCode QNNBackend::retrieveQNNContext() { + auto returnStatus = StatusCode::SUCCESS; + // load qnn system function pointers + if (dynamicloadutil::StatusCode::SUCCESS != dynamicloadutil::getQnnSystemFunctionPointers("libQnnSystem.so", &m_qnnFunctionPointers)) { + reportError("Error initializing QNN System Function Pointers"); + } + + // Read the binary from qnn_context.bin and get the size in byte + std::ifstream file("qnn_context.bin", std::ios::binary | std::ios::ate); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + shared_ptr binaryBuffer(new uint8_t[size], std::default_delete()); + + file.read(reinterpret_cast(binaryBuffer.get()), size); + file.close(); + + // inspect binary info + + QnnSystemContext_Handle_t sysCtxHandle{nullptr}; + if (QNN_SUCCESS != m_qnnFunctionPointers.qnnSystemInterface.systemContextCreate(&sysCtxHandle)) { + QNN_ERROR("Could not create system handle."); + returnStatus = StatusCode::FAILURE; + } + const QnnSystemContext_BinaryInfo_t *binaryInfo{nullptr}; + Qnn_ContextBinarySize_t binaryInfoSize{0}; + if (StatusCode::SUCCESS == returnStatus && QNN_SUCCESS != m_qnnFunctionPointers.qnnSystemInterface.systemContextGetBinaryInfo(sysCtxHandle, static_cast(binaryBuffer.get()), size, &binaryInfo, &binaryInfoSize)) { + QNN_ERROR("Failed to get context binary info"); + returnStatus = StatusCode::FAILURE; + } + + qnn_wrapper_api::GraphInfo_t **graphsInfo = nullptr; + uint32_t graphNum; + // fill GraphInfo_t based on binary info + if (StatusCode::SUCCESS == returnStatus && !copyMetadataToGraphsInfo(binaryInfo, graphsInfo, graphNum)) { + QNN_ERROR("Failed to copy metadata."); + returnStatus = StatusCode::FAILURE; + } + m_qnnFunctionPointers.qnnSystemInterface.systemContextFree(sysCtxHandle); + sysCtxHandle = nullptr; + + graphsInfo_.assign(graphsInfo, graphsInfo + graphNum); + + Qnn_ContextBinarySize_t writtenSize = 0; + m_qnnFunctionPointers.qnnInterface.contextCreateFromBinary(m_backendHandle, m_deviceHandle, (const QnnContext_Config_t **)m_contextConfig, binaryBuffer.get(), size, &m_context, m_profileBackendHandle); + + for (auto &g : graphsInfo_) { + if (QNN_SUCCESS != m_qnnFunctionPointers.qnnInterface.graphRetrieve(m_context, g->graphName, &g->graph)) { + QNN_ERROR("Unable to retrieve graph handle"); + returnStatus = StatusCode::FAILURE; + } + } + + this->isFromCache = true; + + MLLM_LOG_INFO_STREAM << "QNN context retrieved from qnn_context.bin"; + return returnStatus; +} + std::vector QNNBackend::runFunc(std::vector out_names, TensorFuncType type, std::vector float_args, @@ -958,6 +1112,11 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input layer->op_ = layer->backend_->opCreate(layer->param_, layer->name_); #endif } + if (layer->param_["type"] == SUBGRAPHFINALIZE) { + for (auto &input : inputs) { + activation_tensors[input.name()]->setTtype(GRAPH_OUTPUT); + } + } if (module->doLoad) { layer->op_->load(*module->loader); layer->inited_loaded = true; @@ -992,6 +1151,17 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input } } next_name = Layer::layername_2_tensorname[layer_next_name]; + } else if (layer_next_name.find("visual") != string::npos) { + // QNN VLM trick: visual model use act tensor sharing + if (Layer::layername_2_tensorname.find(layer_next_name) == Layer::layername_2_tensorname.end()) { + if (layer->param_["type"] == KVCACHE) { + Layer::layername_2_tensorname[layer_next_name] = layer_next_name; + init_reset_KVCache(inputs[0].name(), module, layer->saved_list_idx, Layer::layername_2_tensorname, layer->backend_); + } else { + Layer::layername_2_tensorname[layer_next_name] = name_num_to_X(layer_next_name); + } + } + next_name = Layer::layername_2_tensorname[layer_next_name]; } else { next_name = layer_next_name; } @@ -1005,7 +1175,7 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input if (module->doLoad) { vector output_result = {}; for (const auto &layer_next_name : layer_next_names) { - string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name; + string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : (layer_next_name.find("visual") != string::npos ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name); output_result.push_back(*activation_tensors[next_name]); } return output_result; @@ -1035,7 +1205,7 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input } vector> output_tensors = {}; for (const auto &layer_next_name : layer_next_names) { - string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name; + string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : (layer_next_name.find("visual") != string::npos ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name); output_tensors.push_back(activation_tensors[next_name]); } #ifdef DEBUGOPTIME @@ -1043,17 +1213,25 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input #endif switch (Tensor::tensor_status) { case TENSOR_STATIC_INIT: { - layer->op_->reshape(input_tensors, output_tensors); - layer->op_->setUp(input_tensors, output_tensors); + if (!Module::isFirstChunk && layer->backend_->type() == MLLM_QNN) { + } else { + layer->op_->reshape(input_tensors, output_tensors); + layer->op_->setUp(input_tensors, output_tensors); + } break; } case TENSOR_STATIC_READY: { - layer->op_->execute(input_tensors, output_tensors); + if (!Module::isFirstChunk && layer->backend_->type() == MLLM_QNN && layer->param_["type"] != SUBGRAPHSTART) { + } else { + layer->op_->execute(input_tensors, output_tensors); + } break; } case TENSOR_STATIC_TRACE: { if (layer->backend_->type() == BackendType::MLLM_CPU) { Tracer::addOp(layer->op_, input_tensors, output_tensors); + } else if (layer->param_["type"] == SUBGRAPHSTART) { // begin of QNN graph + Tracer::addModule(input_tensors, {}, layer->op_->name()); } break; } @@ -1061,29 +1239,29 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input break; } } - // if (Backend::global_backends.size() == 1) { - // for (auto input_tensor : input_tensors) { - // if ((activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end())) { - // switch (Tensor::tensor_status) { - // case TENSOR_STATIC_INIT: { - // activation_tensors_num[input_tensor->name()] += 1; - // break; - // } - // case TENSOR_STATIC_READY: { - // activation_tensors_num[input_tensor->name()] -= 1; - // break; - // } - // default: { - // } - // } - // if (activation_tensors_num[input_tensor->name()] == 0 && activation_tensors[input_tensor->name()]->sequence() > 1 - // && activation_tensors[input_tensor->name()]->ttype() != GRAPH_OUTPUT) { - // activation_tensors[input_tensor->name()]->free(); - // // std::cout << input_tensor->name() << "|" << std::endl; - // } - // } - // } - // } +// if (Backend::global_backends.size() == 1) { +// for (auto input_tensor : input_tensors) { +// if ((activation_tensors_num.find(input_tensor->name()) != activation_tensors_num.end())) { +// switch (Tensor::tensor_status) { +// case TENSOR_STATIC_INIT: { +// activation_tensors_num[input_tensor->name()] += 1; +// break; +// } +// case TENSOR_STATIC_READY: { +// activation_tensors_num[input_tensor->name()] -= 1; +// break; +// } +// default: { +// } +// } +// if (activation_tensors_num[input_tensor->name()] == 0 && activation_tensors[input_tensor->name()]->sequence() > 1 +// && activation_tensors[input_tensor->name()]->ttype() != GRAPH_OUTPUT) { +// activation_tensors[input_tensor->name()]->free(); +// // std::cout << input_tensor->name() << "|" << std::endl; +// } +// } +// } +// } #ifdef DEBUGOPTIME if (Tensor::tensor_status == TENSOR_STATIC_READY) { auto end_t = mllm_time_us(); @@ -1092,7 +1270,7 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input #endif vector output_result = {}; for (const auto &layer_next_name : layer_next_names) { - string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name; + string next_name = Layer::use_layername_2_tensorname ? Layer::layername_2_tensorname[layer_next_name] : (layer_next_name.find("visual") != string::npos ? Layer::layername_2_tensorname[layer_next_name] : layer_next_name); #ifdef DEBUGSAVETENSOR activation_tensors[next_name]->saveNData(layer_next_name); #endif @@ -1102,19 +1280,19 @@ std::vector QNNBackend::runLayer(Layer *layer, std::vector input } std::vector QNNBackend::runForward(Module *module, std::vector inputs, std::vector args) { // set static tmp_device to device_ to init layers' op - auto previoud_device = Module::tmp_device; - Module::tmp_device = module->device_; + // auto previoud_device = Module::tmp_device; + // Module::tmp_device = module->device_; // Module Loading if (Module::llm_model_ptr && Module::llm_model_ptr->doLoad) { auto outputs = module->Forward(inputs, args); // for inner module, set output tensors to GRAPH_OUTPUT - if (inputs[0].ttype() != TensorType::INPUT_TENSOR) { // XPUs' module should not be the outermost input tensor - for (auto &output : outputs) { - inputs[0].module()->activation_tensors[output.name()]->setTtype(GRAPH_OUTPUT); - } - } - // set Module::tmp_device to previous device - Module::tmp_device = previoud_device; + // if (inputs[0].ttype() != TensorType::INPUT_TENSOR) { // XPUs' module should not be the outermost input tensor + // for (auto &output : outputs) { + // inputs[0].module()->activation_tensors[output.name()]->setTtype(GRAPH_OUTPUT); + // } + // } + // // set Module::tmp_device to previous device + // Module::tmp_device = previoud_device; return outputs; } // if (false) { diff --git a/src/backends/qnn/QNNBackend.hpp b/src/backends/qnn/QNNBackend.hpp index 0064ff753..6e5835a49 100644 --- a/src/backends/qnn/QNNBackend.hpp +++ b/src/backends/qnn/QNNBackend.hpp @@ -110,6 +110,10 @@ class QNNBackend : public Backend { dataLoader_ = dataLoader; } + void saveQNNContext(); + + StatusCode retrieveQNNContext(); + private: qnn_wrapper_api::ModelError_t graphFinilize(); qnn_wrapper_api::ModelError_t graphConfig(); @@ -171,7 +175,8 @@ class QNNBackend : public Backend { iotensor::InputDataType m_inputDataType; sample_app::ProfilingLevel m_profilingLevel; - std::map graphInfoMap_; + // std::map graphInfoMap_; + std::vector graphsInfo_; const QnnGraph_Config_t **graphConfigs = nullptr; // these two pointers is .so library handle @@ -187,8 +192,7 @@ class QNNBackend : public Backend { Qnn_BackendHandle_t m_backendHandle = nullptr; Qnn_DeviceHandle_t m_deviceHandle = nullptr; - std::map inputsMap_; - std::map outputsMap_; + bool isFromCache = false; }; } // namespace mllm diff --git a/src/backends/qnn/QnnTypeMacros.hpp b/src/backends/qnn/QnnTypeMacros.hpp index 99aecbecc..70c70d031 100644 --- a/src/backends/qnn/QnnTypeMacros.hpp +++ b/src/backends/qnn/QnnTypeMacros.hpp @@ -1,21 +1,17 @@ //============================================================================== // -// Copyright (c) 2022 Qualcomm Technologies, Inc. -// All Rights Reserved. +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. // Confidential and Proprietary - Qualcomm Technologies, Inc. // //============================================================================== +// TODO: remove once the SNPE build for QNN core is sorted out #pragma once -#include -#include -#include - #include "QnnTypes.h" -#include "WrapperUtils/QnnWrapperUtils.hpp" -namespace qnn_wrapper_api { +#define QNN_OP_CFG_VALID(opConfig) ((opConfig).version == QNN_OPCONFIG_VERSION_1) /** * @brief Verifies the tensor object passed is of supported Qnn_Tensor_t API version @@ -24,14 +20,8 @@ namespace qnn_wrapper_api { * * @return Error code */ -inline ModelError_t validateTensorVersion(Qnn_Tensor_t tensor) { - if (tensor.version != QNN_TENSOR_VERSION_1) { - PRINT_ERROR("validateTensorVersion() tensor %s, got unsupported version %d.", - tensor.v1.name, - tensor.version); - return MODEL_TENSOR_ERROR; - } - return MODEL_NO_ERROR; +inline bool validateTensorVersion(Qnn_Tensor_t tensor) { + return !(tensor.version != QNN_TENSOR_VERSION_1 && tensor.version != QNN_TENSOR_VERSION_2); } /** @@ -41,467 +31,683 @@ inline ModelError_t validateTensorVersion(Qnn_Tensor_t tensor) { * * @return Error code */ -inline ModelError_t validateOpConfigVersion(Qnn_OpConfig_t opConfig) { - if (opConfig.version != QNN_OPCONFIG_VERSION_1) { - PRINT_ERROR("validateOpConfigVersion() op %s, got unsupported version %d.", - opConfig.v1.name, - opConfig.version); - return MODEL_NODES_ERROR; - } - return MODEL_NO_ERROR; +inline bool validateOpConfigVersion(Qnn_OpConfig_t opConfig) { + return !(opConfig.version != QNN_OPCONFIG_VERSION_1); +} + +inline Qnn_OpConfig_t createQnnOpConfig(const Qnn_OpConfigVersion_t version) { + Qnn_OpConfig_t opConfig = QNN_OPCONFIG_INIT; + opConfig.version = version; + if (version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1 = QNN_OPCONFIG_V1_INIT; + } + return opConfig; } -inline const char* getQnnOpConfigName(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.name; - } - return nullptr; +inline const char *getQnnOpConfigName(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.name; + } + return NULL; } -inline const char* getQnnOpConfigName(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigName(*opConfig); +inline const char *getQnnOpConfigName(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigName(*opConfig); } -inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.packageName; - } - return nullptr; +inline const char *getQnnOpConfigPackageName(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.packageName; + } + return NULL; } -inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigPackageName(*opConfig); +inline const char *getQnnOpConfigPackageName(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigPackageName(*opConfig); } -inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.typeName; - } - return nullptr; +inline const char *getQnnOpConfigTypeName(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.typeName; + } + return NULL; } -inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigTypeName(*opConfig); +inline const char *getQnnOpConfigTypeName(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigTypeName(*opConfig); } -inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.numOfParams; - } - return 0u; +inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfParams; + } + return 0u; } -inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigNumParams(*opConfig); +inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigNumParams(*opConfig); } -inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.params; - } - return nullptr; +inline Qnn_Param_t *getQnnOpConfigParams(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.params; + } + return NULL; } -inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigParams(*opConfig); +inline Qnn_Param_t *getQnnOpConfigParams(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigParams(*opConfig); } -inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.numOfInputs; - } - return 0u; +inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfInputs; + } + return 0u; } -inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigNumInputs(*opConfig); +inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigNumInputs(*opConfig); } -inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.inputTensors; - } - return nullptr; +inline Qnn_Tensor_t *getQnnOpConfigInputs(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.inputTensors; + } + return NULL; } -inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigInputs(*opConfig); +inline Qnn_Tensor_t *getQnnOpConfigInputs(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigInputs(*opConfig); } -inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.numOfOutputs; - } - return 0u; +inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfOutputs; + } + return 0u; } -inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigNumOutputs(*opConfig); +inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigNumOutputs(*opConfig); } -inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t& opConfig) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - return opConfig.v1.outputTensors; - } - return nullptr; +inline Qnn_Tensor_t *getQnnOpConfigOutputs(const Qnn_OpConfig_t &opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.outputTensors; + } + return NULL; } -inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t* opConfig) { - return getQnnOpConfigOutputs(*opConfig); +inline Qnn_Tensor_t *getQnnOpConfigOutputs(const Qnn_OpConfig_t *opConfig) { + return getQnnOpConfigOutputs(*opConfig); } -inline void setQnnOpConfigName(Qnn_OpConfig_t& opConfig, const char* name) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.name = name; - } +inline void setQnnOpConfigName(Qnn_OpConfig_t &opConfig, const char *name) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.name = name; + } } -inline void setQnnOpConfigName(Qnn_OpConfig_t* opConfig, const char* name) { - setQnnOpConfigName(*opConfig, name); +inline void setQnnOpConfigName(Qnn_OpConfig_t *opConfig, const char *name) { + setQnnOpConfigName(*opConfig, name); } -inline void setQnnOpConfigPackageName(Qnn_OpConfig_t& opConfig, const char* packageName) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.packageName = packageName; - } +inline void setQnnOpConfigPackageName(Qnn_OpConfig_t &opConfig, const char *packageName) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.packageName = packageName; + } } -inline void setQnnOpConfigPackageName(Qnn_OpConfig_t* opConfig, const char* packageName) { - setQnnOpConfigPackageName(*opConfig, packageName); +inline void setQnnOpConfigPackageName(Qnn_OpConfig_t *opConfig, const char *packageName) { + setQnnOpConfigPackageName(*opConfig, packageName); } -inline void setQnnOpConfigTypeName(Qnn_OpConfig_t& opConfig, const char* typeName) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.typeName = typeName; - } +inline void setQnnOpConfigTypeName(Qnn_OpConfig_t &opConfig, const char *typeName) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.typeName = typeName; + } } -inline void setQnnOpConfigTypeName(Qnn_OpConfig_t* opConfig, const char* typeName) { - setQnnOpConfigTypeName(*opConfig, typeName); +inline void setQnnOpConfigTypeName(Qnn_OpConfig_t *opConfig, const char *typeName) { + setQnnOpConfigTypeName(*opConfig, typeName); } -inline void setQnnOpConfigParams(Qnn_OpConfig_t& opConfig, +inline void setQnnOpConfigParams(Qnn_OpConfig_t &opConfig, uint32_t numOfParams, - Qnn_Param_t* params) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.numOfParams = numOfParams; - opConfig.v1.params = params; - } + Qnn_Param_t *params) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfParams = numOfParams; + opConfig.v1.params = params; + } } -inline void setQnnOpConfigParams(Qnn_OpConfig_t* opConfig, +inline void setQnnOpConfigParams(Qnn_OpConfig_t *opConfig, uint32_t numOfParams, - Qnn_Param_t* params) { - setQnnOpConfigParams(*opConfig, numOfParams, params); + Qnn_Param_t *params) { + setQnnOpConfigParams(*opConfig, numOfParams, params); } -inline void setQnnOpConfigInputs(Qnn_OpConfig_t& opConfig, +inline void setQnnOpConfigInputs(Qnn_OpConfig_t &opConfig, uint32_t numOfInputs, - Qnn_Tensor_t* inputTensors) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.numOfInputs = numOfInputs; - opConfig.v1.inputTensors = inputTensors; - } + Qnn_Tensor_t *inputTensors) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfInputs = numOfInputs; + opConfig.v1.inputTensors = inputTensors; + } } -inline void setQnnOpConfigInputs(Qnn_OpConfig_t* opConfig, +inline void setQnnOpConfigInputs(Qnn_OpConfig_t *opConfig, uint32_t numOfInputs, - Qnn_Tensor_t* inputTensors) { - setQnnOpConfigInputs(*opConfig, numOfInputs, inputTensors); + Qnn_Tensor_t *inputTensors) { + setQnnOpConfigInputs(*opConfig, numOfInputs, inputTensors); } -inline void setQnnOpConfigOutputs(Qnn_OpConfig_t& opConfig, +inline void setQnnOpConfigOutputs(Qnn_OpConfig_t &opConfig, uint32_t numOfOutputs, - Qnn_Tensor_t* outputTensors) { - if (opConfig.version == QNN_OPCONFIG_VERSION_1) { - opConfig.v1.numOfOutputs = numOfOutputs; - opConfig.v1.outputTensors = outputTensors; - } + Qnn_Tensor_t *outputTensors) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfOutputs = numOfOutputs; + opConfig.v1.outputTensors = outputTensors; + } } -inline void setQnnOpConfigOutputs(Qnn_OpConfig_t* opConfig, +inline void setQnnOpConfigOutputs(Qnn_OpConfig_t *opConfig, uint32_t numOfOutputs, - Qnn_Tensor_t* outputTensors) { - setQnnOpConfigOutputs(*opConfig, numOfOutputs, outputTensors); + Qnn_Tensor_t *outputTensors) { + setQnnOpConfigOutputs(*opConfig, numOfOutputs, outputTensors); } -// inline Qnn_OpConfig_t +inline Qnn_Tensor_t createQnnTensor(const Qnn_TensorVersion_t version) { + Qnn_Tensor_t tensor = QNN_TENSOR_INIT; + tensor.version = version; + if (version == QNN_TENSOR_VERSION_1) { + tensor.v1 = QNN_TENSOR_V1_INIT; + } else if (version == QNN_TENSOR_VERSION_2) { + tensor.v2 = QNN_TENSOR_V2_INIT; + } + return tensor; +} -inline uint32_t getQnnTensorId(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline uint32_t getQnnTensorId(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.id; - } - return 0u; } -inline uint32_t getQnnTensorId(const Qnn_Tensor_t* tensor) { return getQnnTensorId(*tensor); } +inline uint32_t getQnnTensorId(const Qnn_Tensor_t *tensor) { + return getQnnTensorId(*tensor); +} -inline const char* getQnnTensorName(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline const char *getQnnTensorName(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.name; - } - return 0u; } - -inline const char* getQnnTensorName(const Qnn_Tensor_t* tensor) { - return getQnnTensorName(*tensor); +inline const char *getQnnTensorName(const Qnn_Tensor_t *tensor) { + return getQnnTensorName(*tensor); } -inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.type; - } - return QNN_TENSOR_TYPE_UNDEFINED; } -inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t* tensor) { - return getQnnTensorType(*tensor); +inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t *tensor) { + return getQnnTensorType(*tensor); } -inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.dataFormat; - } - return QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER; } -inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t* tensor) { - return getQnnTensorDataFormat(*tensor); +inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t *tensor) { + return getQnnTensorDataFormat(*tensor); } -inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.dataType; - } - return QNN_DATATYPE_UNDEFINED; } -inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t* tensor) { - return getQnnTensorDataType(*tensor); +inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t *tensor) { + return getQnnTensorDataType(*tensor); } -inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.quantizeParams; - } - return QNN_QUANTIZE_PARAMS_INIT; } -inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t* tensor) { - return getQnnTensorQuantParams(*tensor); +inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t *const tensor) { + if (tensor != nullptr) { + return getQnnTensorQuantParams(*tensor); + } + return QNN_QUANTIZE_PARAMS_INIT; } -inline uint32_t getQnnTensorRank(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline uint32_t getQnnTensorRank(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.rank; - } - return 0u; } -inline uint32_t getQnnTensorRank(const Qnn_Tensor_t* tensor) { return getQnnTensorRank(*tensor); } +inline uint32_t getQnnTensorRank(const Qnn_Tensor_t *const tensor) { + if (tensor != nullptr) { + return getQnnTensorRank(*tensor); + } + return 0u; +} -inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline uint32_t *getQnnTensorDimensions(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.dimensions; - } - return nullptr; } -inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t* tensor) { - return getQnnTensorDimensions(*tensor); +inline uint32_t *getQnnTensorDimensions(const Qnn_Tensor_t *tensor) { + return getQnnTensorDimensions(*tensor); +} + +inline uint8_t *getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t &tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + return tensor.v2.isDynamicDimensions; + } + return NULL; +} + +inline uint8_t *getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t *tensor) { + return getQnnTensorIsDynamicDimensions(*tensor); +} + +inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t &tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + return tensor.v2.sparseParams; + } + return QNN_SPARSE_PARAMS_INIT; +} + +inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t *tensor) { + return getQnnTensorSparseParams(*tensor); } -inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.memType; - } - return QNN_TENSORMEMTYPE_UNDEFINED; } -inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t* tensor) { - return getQnnTensorMemType(*tensor); +inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t *tensor) { + return getQnnTensorMemType(*tensor); } -inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.clientBuf; - } - return QNN_CLIENT_BUFFER_INIT; } -inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t* tensor) { - return getQnnTensorClientBuf(*tensor); +inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t *tensor) { + return getQnnTensorClientBuf(*tensor); } -inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t &tensor) { + // TensorCompatTest justifies no need to check version return tensor.v1.memHandle; - } - return nullptr; } -inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t* tensor) { - return getQnnTensorMemHandle(*tensor); +inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t *tensor) { + return getQnnTensorMemHandle(*tensor); } -inline void setQnnTensorId(Qnn_Tensor_t& tensor, uint32_t id) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorId(Qnn_Tensor_t &tensor, const uint32_t id) { + // TensorCompatTest justifies no need to check version tensor.v1.id = id; - } } -inline void setQnnTensorId(Qnn_Tensor_t* tensor, uint32_t id) { setQnnTensorId(*tensor, id); } +inline void setQnnTensorId(Qnn_Tensor_t *tensor, uint32_t id) { + setQnnTensorId(*tensor, id); +} -inline void setQnnTensorName(Qnn_Tensor_t& tensor, const char* name) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorName(Qnn_Tensor_t &tensor, const char *const name) { + // TensorCompatTest justifies no need to check version tensor.v1.name = name; - } } -inline void setQnnTensorName(Qnn_Tensor_t* tensor, const char* name) { - setQnnTensorName(*tensor, name); +inline void setQnnTensorName(Qnn_Tensor_t *tensor, const char *name) { + setQnnTensorName(*tensor, name); } -inline void setQnnTensorType(Qnn_Tensor_t& tensor, Qnn_TensorType_t type) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorType(Qnn_Tensor_t &tensor, Qnn_TensorType_t type) { + // TensorCompatTest justifies no need to check version tensor.v1.type = type; - } } -inline void setQnnTensorType(Qnn_Tensor_t* tensor, Qnn_TensorType_t type) { - setQnnTensorType(*tensor, type); +inline void setQnnTensorType(Qnn_Tensor_t *tensor, Qnn_TensorType_t type) { + setQnnTensorType(*tensor, type); } -inline void setQnnTensorDataFormat(Qnn_Tensor_t& tensor, Qnn_TensorDataFormat_t format) { - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.dataFormat = format; - } +inline void setQnnTensorDataFormat(Qnn_Tensor_t &tensor, const Qnn_TensorDataFormat_t dataFormat) { + // TensorCompatTest justifies no need to check version + tensor.v1.dataFormat = dataFormat; } -inline void setQnnTensorDataFormat(Qnn_Tensor_t* tensor, Qnn_TensorDataFormat_t format) { - setQnnTensorDataFormat(*tensor, format); +inline void setQnnTensorDataFormat(Qnn_Tensor_t *tensor, Qnn_TensorDataFormat_t format) { + setQnnTensorDataFormat(*tensor, format); } -inline void setQnnTensorDataType(Qnn_Tensor_t& tensor, Qnn_DataType_t dataType) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorDataType(Qnn_Tensor_t &tensor, const Qnn_DataType_t dataType) { + // TensorCompatTest justifies no need to check version tensor.v1.dataType = dataType; - } } -inline void setQnnTensorDataType(Qnn_Tensor_t* tensor, Qnn_DataType_t dataType) { - setQnnTensorDataType(*tensor, dataType); +inline void setQnnTensorDataType(Qnn_Tensor_t *tensor, Qnn_DataType_t dataType) { + setQnnTensorDataType(*tensor, dataType); } -inline void setQnnTensorQuantParams(Qnn_Tensor_t& tensor, Qnn_QuantizeParams_t params) { - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.quantizeParams = params; - } +inline void setQnnTensorQuantParams(Qnn_Tensor_t &tensor, + const Qnn_QuantizeParams_t quantizeParams) { + // TensorCompatTest justifies no need to check version + tensor.v1.quantizeParams = quantizeParams; } -inline void setQnnTensorQuantParams(Qnn_Tensor_t* tensor, Qnn_QuantizeParams_t params) { - setQnnTensorQuantParams(*tensor, params); +inline void setQnnTensorQuantParams(Qnn_Tensor_t *tensor, Qnn_QuantizeParams_t params) { + setQnnTensorQuantParams(*tensor, params); } -inline void setQnnTensorRank(Qnn_Tensor_t& tensor, uint32_t rank) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorRank(Qnn_Tensor_t &tensor, const uint32_t rank) { + // TensorCompatTest justifies no need to check version tensor.v1.rank = rank; - } } -inline void setQnnTensorRank(Qnn_Tensor_t* tensor, uint32_t rank) { - setQnnTensorRank(*tensor, rank); +inline void setQnnTensorRank(Qnn_Tensor_t *tensor, uint32_t rank) { + setQnnTensorRank(*tensor, rank); } -inline void setQnnTensorDimensions(Qnn_Tensor_t& tensor, uint32_t* dims) { - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.dimensions = dims; - } +inline void setQnnTensorDimensions(Qnn_Tensor_t &tensor, uint32_t *const dimensions) { + // TensorCompatTest justifies no need to check version + tensor.v1.dimensions = dimensions; } -inline void setQnnTensorDimensions(Qnn_Tensor_t* tensor, uint32_t* dims) { - setQnnTensorDimensions(*tensor, dims); +inline void setQnnTensorDimensions(Qnn_Tensor_t *tensor, uint32_t *dims) { + setQnnTensorDimensions(*tensor, dims); } -inline void setQnnTensorMemType(Qnn_Tensor_t& tensor, Qnn_TensorMemType_t memType) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorIsDynamicDimensions(Qnn_Tensor_t &tensor, uint8_t *isDynamic) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + tensor.v2.isDynamicDimensions = isDynamic; + } +} + +inline void setQnnTensorIsDynamicDimensions(Qnn_Tensor_t *tensor, uint8_t *isDynamic) { + setQnnTensorIsDynamicDimensions(*tensor, isDynamic); +} + +inline void setQnnTensorSparseParams(Qnn_Tensor_t &tensor, Qnn_SparseParams_t sparseParams) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + tensor.v2.sparseParams = sparseParams; + } +} + +inline void setQnnTensorSparseParams(Qnn_Tensor_t *tensor, Qnn_SparseParams_t sparseParams) { + setQnnTensorSparseParams(*tensor, sparseParams); +} + +inline void setQnnTensorMemType(Qnn_Tensor_t &tensor, const Qnn_TensorMemType_t memType) { + // TensorCompatTest justifies no need to check version tensor.v1.memType = memType; - } } -inline void setQnnTensorMemType(Qnn_Tensor_t* tensor, Qnn_TensorMemType_t memType) { - setQnnTensorMemType(*tensor, memType); +inline void setQnnTensorMemType(Qnn_Tensor_t *tensor, Qnn_TensorMemType_t memType) { + setQnnTensorMemType(*tensor, memType); } -inline void setQnnTensorClientBuf(Qnn_Tensor_t& tensor, Qnn_ClientBuffer_t clientBuf) { - if (tensor.version == QNN_TENSOR_VERSION_1) { +inline void setQnnTensorClientBuf(Qnn_Tensor_t &tensor, const Qnn_ClientBuffer_t clientBuf) { + // TensorCompatTest justifies no need to check version tensor.v1.clientBuf = clientBuf; - } } -inline void setQnnTensorClientBuf(Qnn_Tensor_t* tensor, Qnn_ClientBuffer_t clientBuf) { - setQnnTensorClientBuf(*tensor, clientBuf); +inline void setQnnTensorClientBuf(Qnn_Tensor_t *tensor, Qnn_ClientBuffer_t clientBuf) { + setQnnTensorClientBuf(*tensor, clientBuf); } -inline void setQnnTensorMemHandle(Qnn_Tensor_t& tensor, Qnn_MemHandle_t handle) { - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.memHandle = handle; - } +inline void setQnnTensorMemHandle(Qnn_Tensor_t &tensor, const Qnn_MemHandle_t memHandle) { + // TensorCompatTest justifies no need to check version + tensor.v1.memHandle = memHandle; } -inline void setQnnTensorMemHandle(Qnn_Tensor_t* tensor, Qnn_MemHandle_t handle) { - setQnnTensorMemHandle(*tensor, handle); +inline void setQnnTensorMemHandle(Qnn_Tensor_t *tensor, Qnn_MemHandle_t handle) { + setQnnTensorMemHandle(*tensor, handle); +} + +inline void setQnnTensorClientBufRetrieve(Qnn_Tensor_t &tensor, + Qnn_TensorRetrieveRaw_t *const retrieve) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + tensor.v2.retrieveRaw = retrieve; + } +} +inline void setQnnTensorClientBufRetrieve(Qnn_Tensor_t *const tensor, + Qnn_TensorRetrieveRaw_t *const retrieve) { + setQnnTensorClientBufRetrieve(*tensor, retrieve); +} +inline void setQnnTensorClientBufRetrieve(Qnn_Tensor_t &tensor, Qnn_TensorRetrieveRaw_t &retrieve) { + setQnnTensorClientBufRetrieve(tensor, &retrieve); +} +inline void setQnnTensorClientBufRetrieve(Qnn_Tensor_t *const tensor, + Qnn_TensorRetrieveRaw_t &retrieve) { + setQnnTensorClientBufRetrieve(*tensor, &retrieve); +} + +inline Qnn_TensorRetrieveRaw_t *getQnnTensorClientBufRetrieve(const Qnn_Tensor_t &tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + return tensor.v2.retrieveRaw; + } + return nullptr; +} +inline Qnn_TensorRetrieveRaw_t *getQnnTensorClientBufRetrieve(const Qnn_Tensor_t *const tensor) { + return getQnnTensorClientBufRetrieve(*tensor); +} + +inline Qnn_TensorSet_t createQnnTensorSet(const Qnn_TensorSetVersion_t version) { + Qnn_TensorSet_t tensorSet = QNN_TENSOR_SET_INIT; + tensorSet.version = version; + if (version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1 = QNN_TENSOR_SET_V1_INIT; + } + return tensorSet; +} + +inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t &tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.numInputs; + } + return 0; +} + +inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t *tensorSet) { + return getQnnTensorSetNumInputs(*tensorSet); +} + +inline Qnn_Tensor_t *getQnnTensorSetInputTensors(const Qnn_TensorSet_t &tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.inputs; + } + return 0; +} + +inline Qnn_Tensor_t *getQnnTensorSetInputTensors(const Qnn_TensorSet_t *tensorSet) { + return getQnnTensorSetInputTensors(*tensorSet); +} + +inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t &tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.numOutputs; + } + return 0; +} + +inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t *tensorSet) { + return getQnnTensorSetNumOutputs(*tensorSet); +} + +inline Qnn_Tensor_t *getQnnTensorSetOutputTensors(const Qnn_TensorSet_t &tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.outputs; + } + return 0; +} + +inline Qnn_Tensor_t *getQnnTensorSetOutputTensors(const Qnn_TensorSet_t *tensorSet) { + return getQnnTensorSetOutputTensors(*tensorSet); +} + +inline void setQnnTensorSetInputTensors(Qnn_TensorSet_t &tensorSet, + Qnn_Tensor_t *inputTensors, + uint32_t const numInputs) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1.inputs = inputTensors; + tensorSet.v1.numInputs = numInputs; + } +} + +inline void setQnnTensorSetInputTensors(Qnn_TensorSet_t *tensorSet, + Qnn_Tensor_t *inputTensors, + uint32_t const numInputs) { + setQnnTensorSetInputTensors(*tensorSet, inputTensors, numInputs); +} + +inline void setQnnTensorSetOutputTensors(Qnn_TensorSet_t &tensorSet, + Qnn_Tensor_t *outputTensors, + const uint32_t numOutputs) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1.outputs = outputTensors; + tensorSet.v1.numOutputs = numOutputs; + } +} + +inline void setQnnTensorSetOutputTensors(Qnn_TensorSet_t *tensorSet, + Qnn_Tensor_t *outputTensors, + const uint32_t numOutputs) { + setQnnTensorSetOutputTensors(*tensorSet, outputTensors, numOutputs); } // Validation -#define VALIDATE_TENSOR_VERSION(tensor, err) VALIDATE(validateTensorVersion(tensor), err) -#define VALIDATE_OP_CONFIG_VERSION(op, err) VALIDATE(validateOpConfigVersion(op), err) +#define VALIDATE_TENSOR_VERSION(tensor, err) validateTensorVersion(tensor) +#define VALIDATE_OP_CONFIG_VERSION(op, err) validateOpConfigVersion(op) + +// Creator for QNN Op Config +#define QNN_OP_CFG_CREATE(version) createQnnOpConfig(version) // Accessors for QNN Op Config -#define QNN_OP_CFG_GET_NAME(opConfig) getQnnOpConfigName(opConfig) +#define QNN_OP_CFG_GET_NAME(opConfig) getQnnOpConfigName(opConfig) #define QNN_OP_CFG_GET_PACKAGE_NAME(opConfig) getQnnOpConfigPackageName(opConfig) -#define QNN_OP_CFG_GET_TYPE_NAME(opConfig) getQnnOpConfigTypeName(opConfig) -#define QNN_OP_CFG_GET_NUM_PARAMS(opConfig) getQnnOpConfigNumParams(opConfig) -#define QNN_OP_CFG_GET_PARAMS(opConfig) getQnnOpConfigParams(opConfig) -#define QNN_OP_CFG_GET_NUM_INPUTS(opConfig) getQnnOpConfigNumInputs(opConfig) -#define QNN_OP_CFG_GET_INPUTS(opConfig) getQnnOpConfigInputs(opConfig) -#define QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) getQnnOpConfigNumOutputs(opConfig) -#define QNN_OP_CFG_GET_OUTPUTS(opConfig) getQnnOpConfigOutputs(opConfig) +#define QNN_OP_CFG_GET_TYPE_NAME(opConfig) getQnnOpConfigTypeName(opConfig) +#define QNN_OP_CFG_GET_NUM_PARAMS(opConfig) getQnnOpConfigNumParams(opConfig) +#define QNN_OP_CFG_GET_PARAMS(opConfig) getQnnOpConfigParams(opConfig) +#define QNN_OP_CFG_GET_NUM_INPUTS(opConfig) getQnnOpConfigNumInputs(opConfig) +#define QNN_OP_CFG_GET_INPUTS(opConfig) getQnnOpConfigInputs(opConfig) +#define QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) getQnnOpConfigNumOutputs(opConfig) +#define QNN_OP_CFG_GET_OUTPUTS(opConfig) getQnnOpConfigOutputs(opConfig) // Modifiers for QNN Op Config -#define QNN_OP_CFG_SET_NAME(opConfig, value) setQnnOpConfigName(opConfig, value) +#define QNN_OP_CFG_SET_NAME(opConfig, value) setQnnOpConfigName(opConfig, value) #define QNN_OP_CFG_SET_PACKAGE_NAME(opConfig, value) setQnnOpConfigPackageName(opConfig, value) -#define QNN_OP_CFG_SET_TYPE_NAME(opConfig, value) setQnnOpConfigTypeName(opConfig, value) +#define QNN_OP_CFG_SET_TYPE_NAME(opConfig, value) setQnnOpConfigTypeName(opConfig, value) #define QNN_OP_CFG_SET_PARAMS(opConfig, numOfParams, params) \ - setQnnOpConfigParams(opConfig, numOfParams, params) + setQnnOpConfigParams(opConfig, numOfParams, params) #define QNN_OP_CFG_SET_INPUTS(opConfig, numOfInputs, inputTensors) \ - setQnnOpConfigInputs(opConfig, numOfInputs, inputTensors) + setQnnOpConfigInputs(opConfig, numOfInputs, inputTensors) #define QNN_OP_CFG_SET_OUTPUTS(opConfig, numOfOutputs, outputTensors) \ - setQnnOpConfigOutputs(opConfig, numOfOutputs, outputTensors) + setQnnOpConfigOutputs(opConfig, numOfOutputs, outputTensors) + +// Creator for QNN Tensor +#define QNN_TENSOR_CREATE(version) createQnnTensor(version) // Accessors for QNN Tensor -#define QNN_TENSOR_GET_ID(tensor) getQnnTensorId(tensor) -#define QNN_TENSOR_GET_NAME(tensor) getQnnTensorName(tensor) -#define QNN_TENSOR_GET_TYPE(tensor) getQnnTensorType(tensor) -#define QNN_TENSOR_GET_DATA_FORMAT(tensor) getQnnTensorDataFormat(tensor) -#define QNN_TENSOR_GET_DATA_TYPE(tensor) getQnnTensorDataType(tensor) +#define QNN_TENSOR_GET_ID(tensor) getQnnTensorId(tensor) +#define QNN_TENSOR_GET_NAME(tensor) getQnnTensorName(tensor) +#define QNN_TENSOR_GET_TYPE(tensor) getQnnTensorType(tensor) +#define QNN_TENSOR_GET_DATA_FORMAT(tensor) getQnnTensorDataFormat(tensor) +#define QNN_TENSOR_GET_DATA_TYPE(tensor) getQnnTensorDataType(tensor) #define QNN_TENSOR_GET_QUANT_PARAMS(tensor) getQnnTensorQuantParams(tensor) -#define QNN_TENSOR_GET_RANK(tensor) getQnnTensorRank(tensor) -#define QNN_TENSOR_GET_DIMENSIONS(tensor) getQnnTensorDimensions(tensor) -#define QNN_TENSOR_GET_MEM_TYPE(tensor) getQnnTensorMemType(tensor) -#define QNN_TENSOR_GET_CLIENT_BUF(tensor) getQnnTensorClientBuf(tensor) -#define QNN_TENSOR_GET_MEM_HANDLE(tensor) getQnnTensorMemHandle(tensor) +#define QNN_TENSOR_GET_RANK(tensor) getQnnTensorRank(tensor) +#define QNN_TENSOR_GET_DIMENSIONS(tensor) getQnnTensorDimensions(tensor) +#define QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor) getQnnTensorIsDynamicDimensions(tensor) +#define QNN_TENSOR_GET_SPARSE_PARAMS(tensor) getQnnTensorSparseParams(tensor) +#define QNN_TENSOR_GET_MEM_TYPE(tensor) getQnnTensorMemType(tensor) +#define QNN_TENSOR_GET_CLIENT_BUF(tensor) getQnnTensorClientBuf(tensor) +#define QNN_TENSOR_GET_MEM_HANDLE(tensor) getQnnTensorMemHandle(tensor) +#define QNN_TENSOR_GET_CLIENT_BUF_RETRIEVE(tensor) getQnnTensorClientBufRetrieve(tensor) // Modifiers for QNN Tensor -#define QNN_TENSOR_SET_ID(tensor, value) setQnnTensorId(tensor, value) -#define QNN_TENSOR_SET_NAME(tensor, value) setQnnTensorName(tensor, value) -#define QNN_TENSOR_SET_TYPE(tensor, value) setQnnTensorType(tensor, value) -#define QNN_TENSOR_SET_DATA_FORMAT(tensor, value) setQnnTensorDataFormat(tensor, value) -#define QNN_TENSOR_SET_DATA_TYPE(tensor, value) setQnnTensorDataType(tensor, value) +#define QNN_TENSOR_SET_ID(tensor, value) setQnnTensorId(tensor, value) +#define QNN_TENSOR_SET_NAME(tensor, value) setQnnTensorName(tensor, value) +#define QNN_TENSOR_SET_TYPE(tensor, value) setQnnTensorType(tensor, value) +#define QNN_TENSOR_SET_DATA_FORMAT(tensor, value) setQnnTensorDataFormat(tensor, value) +#define QNN_TENSOR_SET_DATA_TYPE(tensor, value) setQnnTensorDataType(tensor, value) #define QNN_TENSOR_SET_QUANT_PARAMS(tensor, value) setQnnTensorQuantParams(tensor, value) -#define QNN_TENSOR_SET_RANK(tensor, value) setQnnTensorRank(tensor, value) -#define QNN_TENSOR_SET_DIMENSIONS(tensor, value) setQnnTensorDimensions(tensor, value) -#define QNN_TENSOR_SET_MEM_TYPE(tensor, value) setQnnTensorMemType(tensor, value) -#define QNN_TENSOR_SET_CLIENT_BUF(tensor, value) setQnnTensorClientBuf(tensor, value) -#define QNN_TENSOR_SET_MEM_HANDLE(tensor, value) setQnnTensorMemHandle(tensor, value) - -} // namespace qnn_wrapper_api +#define QNN_TENSOR_SET_RANK(tensor, value) setQnnTensorRank(tensor, value) +#define QNN_TENSOR_SET_DIMENSIONS(tensor, value) setQnnTensorDimensions(tensor, value) +#define QNN_TENSOR_SET_IS_DYNAMIC_DIMENSIONS(tensor, value) \ + setQnnTensorIsDynamicDimensions(tensor, value) +#define QNN_TENSOR_SET_SPARSE_PARAMS(tensor, value) setQnnTensorSparseParams(tensor, value) +#define QNN_TENSOR_SET_MEM_TYPE(tensor, value) setQnnTensorMemType(tensor, value) +#define QNN_TENSOR_SET_CLIENT_BUF(tensor, value) setQnnTensorClientBuf(tensor, value) +#define QNN_TENSOR_SET_MEM_HANDLE(tensor, value) setQnnTensorMemHandle(tensor, value) +#define QNN_TENSOR_SET_CLIENT_BUF_RETRIEVE(tensor, value) \ + setQnnTensorClientBufRetrieve(tensor, value) + +// Creator for QNN Tensor Set +#define QNN_TENSORSET_CREATE(version) createQnnTensorSet(version) + +// Accessors for QNN Tensor Set +#define QNN_TENSORSET_GET_NUM_INPUTS(tensorSet) getQnnTensorSetNumInputs(tensorSet) +#define QNN_TENSORSET_GET_INPUT_TENSORS(tensorSet) getQnnTensorSetInputTensors(tensorSet) +#define QNN_TENSORSET_GET_NUM_OUTPUTS(tensorSet) getQnnTensorSetNumOutputs(tensorSet) +#define QNN_TENSORSET_GET_OUTPUT_TENSORS(tensorSet) getQnnTensorSetOutputTensors(tensorSet) + +// Modifiers for QNN Tensor Set +#define QNN_TENSORSET_SET_INPUT_TENSORS(tensorSet, inputTensors, numInputs) \ + setQnnTensorSetInputTensors(tensorSet, inputTensors, numInputs) +#define QNN_TENSORSET_SET_OUTPUT_TENSORS(tensorSet, outputTensors, numOutputs) \ + setQnnTensorSetOutputTensors(tensorSet, outputTensors, numOutputs) + +inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t &tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + if (tensor.v2.isDynamicDimensions != NULL) { + return false; + } + if (tensor.v2.dataFormat == QNN_TENSOR_DATA_FORMAT_SPARSE) { + return false; + } + } + return true; +} +inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t *const tensor) { + return isQnnTensorV1Compatible(*tensor); +} +inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t &opConfig) { + if ((QNN_OP_CFG_GET_INPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_INPUTS(opConfig) > 0u)) { + for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_INPUTS(opConfig); tensorIdx++) { + if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_INPUTS(opConfig)[tensorIdx])) { + return false; + } + } + } + if ((QNN_OP_CFG_GET_OUTPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) > 0u)) { + for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig); tensorIdx++) { + if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_OUTPUTS(opConfig)[tensorIdx])) { + return false; + } + } + } + if ((QNN_OP_CFG_GET_PARAMS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_PARAMS(opConfig) > 0)) { + for (uint32_t paramIdx = 0u; paramIdx < QNN_OP_CFG_GET_NUM_PARAMS(opConfig); paramIdx++) { + const Qnn_Param_t ¶m = QNN_OP_CFG_GET_PARAMS(opConfig)[paramIdx]; + if (QNN_PARAMTYPE_TENSOR == param.paramType) { + if (!isQnnTensorV1Compatible(param.tensorParam)) { + return false; + } + } + } + } + return true; +} +inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t *const opConfig) { + return isQnnTensorV1Compatible(*opConfig); +} \ No newline at end of file diff --git a/src/backends/qnn/README.md b/src/backends/qnn/README.md index b7f38d908..6081c9398 100644 --- a/src/backends/qnn/README.md +++ b/src/backends/qnn/README.md @@ -7,7 +7,7 @@ This section is basically following the QNN documentation, for more details, see The QNN backend relies on the Qualcomm QNN SDK and Hexagon SDK to compile QNN Backends and LLM-specific operators. The QNN SDK can be downloaded [here](https://www.qualcomm.com/developer/software/qualcomm-ai-engine-direct-sdk). The Hexagon SDK can be downloaded using [QPM](https://qpm.qualcomm.com/). The compiling environment only supports Linux now. Version requirements: -* QNN: [Linux v2.20+](https://qpm.qualcomm.com/#/main/tools/details/qualcomm_neural_processing_sdk) +* QNN: [Linux v2.31+](https://qpm.qualcomm.com/#/main/tools/details/qualcomm_neural_processing_sdk) * Hexagon SDK: [Linux 5.x](https://qpm.qualcomm.com/#/main/tools/details/HexagonSDK5.x) (Some accounts may have no permission to access this SDK and may need to contact Qualcomm for support.) **NOTE:** After downloading the QNN SDK, unzip the file and move the folder name like `qairt/2.31.0.250130` to `src/backends/qnn/` and rename the version to 'sdk'. The folder structure should be like `src/backends/qnn/sdk`. diff --git a/src/backends/qnn/Utils/QnnSampleAppUtils.cpp b/src/backends/qnn/Utils/QnnSampleAppUtils.cpp index d38bf9948..6a213ed02 100644 --- a/src/backends/qnn/Utils/QnnSampleAppUtils.cpp +++ b/src/backends/qnn/Utils/QnnSampleAppUtils.cpp @@ -1,7 +1,7 @@ //============================================================================== // -// Copyright (c) 2019-2023 Qualcomm Technologies, Inc. -// All Rights Reserved. +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. // Confidential and Proprietary - Qualcomm Technologies, Inc. // //============================================================================== @@ -14,12 +14,14 @@ #include #include -#include "Log/Logger.hpp" +#include "Logger.hpp" +#ifndef __hexagon__ #include "PAL/Directory.hpp" #include "PAL/FileOp.hpp" #include "PAL/Path.hpp" +#endif #include "PAL/StringOp.hpp" -#include "Utils/QnnSampleAppUtils.hpp" +#include "QnnSampleAppUtils.hpp" #include "QnnTypeMacros.hpp" using namespace qnn; @@ -29,312 +31,411 @@ using namespace qnn_wrapper_api; void sample_app::split(std::vector &splitString, const std::string &tokenizedString, const char separator) { - splitString.clear(); - std::istringstream tokenizedStringStream(tokenizedString); - while (!tokenizedStringStream.eof()) { - std::string value; - getline(tokenizedStringStream, value, separator); - if (!value.empty()) { - splitString.push_back(value); + splitString.clear(); + std::istringstream tokenizedStringStream(tokenizedString); + while (!tokenizedStringStream.eof()) { + std::string value; + getline(tokenizedStringStream, value, separator); + if (!value.empty()) { + splitString.push_back(value); + } } - } } void sample_app::parseInputFilePaths(std::vector &inputFilePaths, std::vector &paths, std::string separator) { - for (auto &inputInfo : inputFilePaths) { - auto position = inputInfo.find(separator); - if (position != std::string::npos) { - auto path = inputInfo.substr(position + separator.size()); - paths.push_back(path); - } else { - paths.push_back(inputInfo); + for (auto &inputInfo : inputFilePaths) { + auto position = inputInfo.find(separator); + if (position != std::string::npos) { + auto path = inputInfo.substr(position + separator.size()); + paths.push_back(path); + } else { + paths.push_back(inputInfo); + } } - } } sample_app::ReadInputListsRetType_t sample_app::readInputLists( std::vector inputFileListPaths) { - std::vector>> filePathsLists; - for (auto const &path : inputFileListPaths) { - bool readSuccess; - std::vector> filePathList; - std::tie(filePathList, readSuccess) = readInputList(path); - if (!readSuccess) { - filePathsLists.clear(); - return std::make_tuple(filePathsLists, false); + std::vector>> filePathsLists; + std::vector> inputNameToIndexMaps; + for (auto const &path : inputFileListPaths) { + bool readSuccess; + std::vector> filePathList; + std::unordered_map inputNameToIndex; + std::tie(filePathList, inputNameToIndex, readSuccess) = readInputList(path); + if (!readSuccess) { + filePathsLists.clear(); + return std::make_tuple(filePathsLists, inputNameToIndexMaps, false); + } + filePathsLists.push_back(filePathList); + inputNameToIndexMaps.push_back(inputNameToIndex); } - filePathsLists.push_back(filePathList); - } - return std::make_tuple(filePathsLists, true); + return std::make_tuple(filePathsLists, inputNameToIndexMaps, true); } sample_app::ReadInputListRetType_t sample_app::readInputList(const std::string inputFileListPath) { - std::queue lines; - std::ifstream fileListStream(inputFileListPath); - if (!fileListStream) { - QNN_ERROR("Failed to open input file: %s", inputFileListPath.c_str()); - std::vector> result; - return std::make_tuple(result, false); - } - std::string fileLine; - while (std::getline(fileListStream, fileLine)) { - if (fileLine.empty()) continue; - lines.push(fileLine); - } - if (!lines.empty() && lines.front().compare(0, 1, "#") == 0) { - lines.pop(); - } - std::string separator = ":="; - std::vector> filePathsList; - while (!lines.empty()) { - std::vector paths{}; + std::queue lines; + std::ifstream fileListStream(inputFileListPath); + if (!fileListStream) { + QNN_ERROR("Failed to open input file: %s", inputFileListPath.c_str()); + return std::make_tuple(std::vector>{}, + std::unordered_map{}, + false); + } + + std::string fileLine; + while (std::getline(fileListStream, fileLine)) { + if (fileLine.empty()) continue; + lines.push(fileLine); + } + + if (!lines.empty() && lines.front().compare(0, 1, "#") == 0) { + lines.pop(); + } + + if (!lines.empty() && lines.front().compare(0, 1, "%") == 0) { + lines.pop(); + } + + std::string separator = ":="; + std::vector> filePathsList; + std::unordered_map inputNameToIndex; + if (!lines.empty()) { + inputNameToIndex = extractInputNameIndices(lines.front(), separator); + } + while (!lines.empty()) { + std::vector paths{}; + std::vector inputFilePaths; + split(inputFilePaths, lines.front(), ' '); + parseInputFilePaths(inputFilePaths, paths, separator); + filePathsList.reserve(paths.size()); + for (size_t idx = 0; idx < paths.size(); idx++) { + if (idx >= filePathsList.size()) { + filePathsList.push_back(std::vector()); + } + filePathsList[idx].push_back(paths[idx]); + } + lines.pop(); + } + return std::make_tuple(filePathsList, inputNameToIndex, true); +} + +std::unordered_map sample_app::extractInputNameIndices( + const std::string &inputLine, const std::string &separator) { std::vector inputFilePaths; - split(inputFilePaths, lines.front(), ' '); - parseInputFilePaths(inputFilePaths, paths, separator); - // TODO: multi input support - filePathsList.reserve(paths.size()); - for (size_t idx = 0; idx < paths.size(); idx++) { - if (idx >= filePathsList.size()) { - filePathsList.push_back(std::queue()); - } - filePathsList.back().push(paths[idx]); + std::unordered_map inputNameToIndex; + split(inputFilePaths, inputLine, ' '); + size_t inputCount = 0; + for (uint32_t idx = 0; idx < inputFilePaths.size(); idx++) { + auto position = inputFilePaths[idx].find(separator); + if (position != std::string::npos) { + auto unsanitizedTensorName = inputFilePaths[idx].substr(0, position); + auto sanitizedTensorName = sanitizeTensorName(unsanitizedTensorName); + if (sanitizedTensorName != unsanitizedTensorName) { + inputNameToIndex[unsanitizedTensorName] = idx; + } + inputNameToIndex[sanitizedTensorName] = idx; + inputCount = inputCount + 1; + } } - lines.pop(); - } - return std::make_tuple(filePathsList, true); + return inputCount == inputFilePaths.size() ? inputNameToIndex : std::unordered_map{}; +} + +std::string sample_app::sanitizeTensorName(std::string name) { + std::string sanitizedName = std::regex_replace(name, std::regex("\\W+"), "_"); + if (!std::isalpha(sanitizedName[0]) && sanitizedName[0] != '_') { + sanitizedName = "_" + sanitizedName; + } + return sanitizedName; } sample_app::ProfilingLevel sample_app::parseProfilingLevel(std::string profilingLevelString) { - std::transform(profilingLevelString.begin(), - profilingLevelString.end(), - profilingLevelString.begin(), - ::tolower); - ProfilingLevel parsedProfilingLevel = ProfilingLevel::INVALID; - if (profilingLevelString == "off") { - parsedProfilingLevel = ProfilingLevel::OFF; - } else if (profilingLevelString == "basic") { - parsedProfilingLevel = ProfilingLevel::BASIC; - } else if (profilingLevelString == "detailed") { - parsedProfilingLevel = ProfilingLevel::DETAILED; - } - return parsedProfilingLevel; + std::transform(profilingLevelString.begin(), + profilingLevelString.end(), + profilingLevelString.begin(), + ::tolower); + ProfilingLevel parsedProfilingLevel = ProfilingLevel::INVALID; + if (profilingLevelString == "off") { + parsedProfilingLevel = ProfilingLevel::OFF; + } else if (profilingLevelString == "basic") { + parsedProfilingLevel = ProfilingLevel::BASIC; + } else if (profilingLevelString == "detailed") { + parsedProfilingLevel = ProfilingLevel::DETAILED; + } + return parsedProfilingLevel; } bool sample_app::deepCopyQnnTensorInfo(Qnn_Tensor_t *dst, const Qnn_Tensor_t *src) { - if (nullptr == dst || nullptr == src) { - QNN_ERROR("Received nullptr"); - return false; - } - // set tensor.version before using QNN_TENSOR_SET macros, as they require the version to be set - // to correctly assign values - dst->version = src->version; - const char *tensorName = QNN_TENSOR_GET_NAME(src); - if (!tensorName) { - QNN_TENSOR_SET_NAME(dst, nullptr); - } else { - QNN_TENSOR_SET_NAME(dst, pal::StringOp::strndup(tensorName, strlen(tensorName))); - } - QNN_TENSOR_SET_ID(dst, QNN_TENSOR_GET_ID(src)); - QNN_TENSOR_SET_TYPE(dst, QNN_TENSOR_GET_TYPE(src)); - QNN_TENSOR_SET_DATA_FORMAT(dst, QNN_TENSOR_GET_DATA_FORMAT(src)); - QNN_TENSOR_SET_DATA_TYPE(dst, QNN_TENSOR_GET_DATA_TYPE(src)); - Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT; - qParams.encodingDefinition = QNN_TENSOR_GET_QUANT_PARAMS(src).encodingDefinition; - qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; - if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; - qParams.scaleOffsetEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).scaleOffsetEncoding; - } else if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; - qParams.axisScaleOffsetEncoding.axis = - QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.axis; - qParams.axisScaleOffsetEncoding.numScaleOffsets = - QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; - if (QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets > 0) { - qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t *)malloc( - QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets * - sizeof(Qnn_ScaleOffset_t)); - if (qParams.axisScaleOffsetEncoding.scaleOffset) { - for (size_t idx = 0; - idx < QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; - idx++) { - qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale = - QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.scaleOffset[idx].scale; - qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset = - QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.scaleOffset[idx].offset; + if (nullptr == dst || nullptr == src) { + QNN_ERROR("Received nullptr"); + return false; + } + // set tensor.version before using QNN_TENSOR_SET macros, as they require the version to be set + // to correctly assign values + dst->version = src->version; + const char *tensorName = QNN_TENSOR_GET_NAME(src); + if (!tensorName) { + QNN_TENSOR_SET_NAME(dst, nullptr); + } else { + QNN_TENSOR_SET_NAME(dst, pal::StringOp::strndup(tensorName, strlen(tensorName))); + } + QNN_TENSOR_SET_ID(dst, QNN_TENSOR_GET_ID(src)); + QNN_TENSOR_SET_TYPE(dst, QNN_TENSOR_GET_TYPE(src)); + QNN_TENSOR_SET_DATA_FORMAT(dst, QNN_TENSOR_GET_DATA_FORMAT(src)); + QNN_TENSOR_SET_DATA_TYPE(dst, QNN_TENSOR_GET_DATA_TYPE(src)); + Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT; + qParams.encodingDefinition = QNN_TENSOR_GET_QUANT_PARAMS(src).encodingDefinition; + qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; + qParams.scaleOffsetEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).scaleOffsetEncoding; + } else if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; + qParams.axisScaleOffsetEncoding.axis = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.axis; + qParams.axisScaleOffsetEncoding.numScaleOffsets = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; + if (QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets > 0) { + qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t *)malloc( + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets * sizeof(Qnn_ScaleOffset_t)); + if (qParams.axisScaleOffsetEncoding.scaleOffset) { + for (size_t idx = 0; + idx < QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; + idx++) { + qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.scaleOffset[idx].scale; + qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.scaleOffset[idx].offset; + } + } } - } } - } - QNN_TENSOR_SET_QUANT_PARAMS(dst, qParams); - QNN_TENSOR_SET_RANK(dst, QNN_TENSOR_GET_RANK(src)); - QNN_TENSOR_SET_DIMENSIONS(dst, nullptr); - if (QNN_TENSOR_GET_RANK(src) > 0) { - QNN_TENSOR_SET_DIMENSIONS(dst, (uint32_t *)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t))); - if (QNN_TENSOR_GET_DIMENSIONS(dst)) { - pal::StringOp::memscpy(QNN_TENSOR_GET_DIMENSIONS(dst), - QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t), - QNN_TENSOR_GET_DIMENSIONS(src), - QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t)); + QNN_TENSOR_SET_QUANT_PARAMS(dst, qParams); + QNN_TENSOR_SET_RANK(dst, QNN_TENSOR_GET_RANK(src)); + QNN_TENSOR_SET_DIMENSIONS(dst, nullptr); + if (QNN_TENSOR_GET_RANK(src) > 0) { + QNN_TENSOR_SET_DIMENSIONS(dst, (uint32_t *)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t))); + if (QNN_TENSOR_GET_DIMENSIONS(dst)) { + pal::StringOp::memscpy(QNN_TENSOR_GET_DIMENSIONS(dst), + QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t), + QNN_TENSOR_GET_DIMENSIONS(src), + QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t)); + } + if (QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(src)) { + QNN_TENSOR_SET_IS_DYNAMIC_DIMENSIONS( + dst, (uint8_t *)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint8_t))); + pal::StringOp::memscpy(QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(dst), + QNN_TENSOR_GET_RANK(src) * sizeof(uint8_t), + QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(src), + QNN_TENSOR_GET_RANK(src) * sizeof(uint8_t)); + } } - } - return true; + QNN_TENSOR_SET_SPARSE_PARAMS(dst, QNN_TENSOR_GET_SPARSE_PARAMS(src)); + return true; } bool sample_app::copyTensorsInfo(const Qnn_Tensor_t *tensorsInfoSrc, Qnn_Tensor_t *&tensorWrappers, uint32_t tensorsCount) { - QNN_FUNCTION_ENTRY_LOG; - auto returnStatus = true; - tensorWrappers = (Qnn_Tensor_t *)calloc(tensorsCount, sizeof(Qnn_Tensor_t)); - if (nullptr == tensorWrappers) { - QNN_ERROR("Failed to allocate memory for tensorWrappers."); - return false; - } - if (returnStatus) { - for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) { - QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", tIdx); - tensorWrappers[tIdx] = QNN_TENSOR_INIT; - deepCopyQnnTensorInfo(&tensorWrappers[tIdx], &tensorsInfoSrc[tIdx]); + QNN_FUNCTION_ENTRY_LOG; + auto returnStatus = true; + tensorWrappers = (Qnn_Tensor_t *)calloc(tensorsCount, sizeof(Qnn_Tensor_t)); + if (nullptr == tensorWrappers) { + QNN_ERROR("Failed to allocate memory for tensorWrappers."); + return false; + } + if (returnStatus) { + for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) { + QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", tIdx); + tensorWrappers[tIdx] = QNN_TENSOR_INIT; + deepCopyQnnTensorInfo(&tensorWrappers[tIdx], &tensorsInfoSrc[tIdx]); + } } - } - QNN_FUNCTION_EXIT_LOG; - return returnStatus; + QNN_FUNCTION_EXIT_LOG; + return returnStatus; } bool sample_app::copyGraphsInfoV1(const QnnSystemContext_GraphInfoV1_t *graphInfoSrc, qnn_wrapper_api::GraphInfo_t *graphInfoDst) { - graphInfoDst->graphName = nullptr; - if (graphInfoSrc->graphName) { - graphInfoDst->graphName = - pal::StringOp::strndup(graphInfoSrc->graphName, strlen(graphInfoSrc->graphName)); - } - graphInfoDst->inputTensors = nullptr; - graphInfoDst->numInputTensors = 0; - if (graphInfoSrc->graphInputs) { - if (!copyTensorsInfo( - graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) { - return false; + graphInfoDst->graphName = nullptr; + if (graphInfoSrc->graphName) { + graphInfoDst->graphName = + pal::StringOp::strndup(graphInfoSrc->graphName, strlen(graphInfoSrc->graphName)); + } + graphInfoDst->inputTensors = nullptr; + graphInfoDst->numInputTensors = 0; + if (graphInfoSrc->graphInputs) { + if (!copyTensorsInfo( + graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) { + return false; + } + graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs; + } + graphInfoDst->outputTensors = nullptr; + graphInfoDst->numOutputTensors = 0; + if (graphInfoSrc->graphOutputs) { + if (!copyTensorsInfo(graphInfoSrc->graphOutputs, + graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs)) { + return false; + } + graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs; } - graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs; - } - graphInfoDst->outputTensors = nullptr; - graphInfoDst->numOutputTensors = 0; - if (graphInfoSrc->graphOutputs) { - if (!copyTensorsInfo(graphInfoSrc->graphOutputs, - graphInfoDst->outputTensors, - graphInfoSrc->numGraphOutputs)) { - return false; + return true; +} + +bool sample_app::copyGraphsInfoV3(const QnnSystemContext_GraphInfoV3_t *graphInfoSrc, + qnn_wrapper_api::GraphInfo_t *graphInfoDst) { + graphInfoDst->graphName = nullptr; + if (graphInfoSrc->graphName) { + graphInfoDst->graphName = + pal::StringOp::strndup(graphInfoSrc->graphName, strlen(graphInfoSrc->graphName)); + } + graphInfoDst->inputTensors = nullptr; + graphInfoDst->numInputTensors = 0; + if (graphInfoSrc->graphInputs) { + if (!copyTensorsInfo( + graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) { + return false; + } + graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs; + } + graphInfoDst->outputTensors = nullptr; + graphInfoDst->numOutputTensors = 0; + if (graphInfoSrc->graphOutputs) { + if (!copyTensorsInfo(graphInfoSrc->graphOutputs, + graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs)) { + return false; + } + graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs; } - graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs; - } - return true; + return true; } bool sample_app::copyGraphsInfo(const QnnSystemContext_GraphInfo_t *graphsInput, const uint32_t numGraphs, qnn_wrapper_api::GraphInfo_t **&graphsInfo) { - QNN_FUNCTION_ENTRY_LOG; - if (!graphsInput) { - QNN_ERROR("Received nullptr for graphsInput."); - return false; - } - auto returnStatus = true; - graphsInfo = - (qnn_wrapper_api::GraphInfo_t **)calloc(numGraphs, sizeof(qnn_wrapper_api::GraphInfo_t *)); - qnn_wrapper_api::GraphInfo_t *graphInfoArr = - (qnn_wrapper_api::GraphInfo_t *)calloc(numGraphs, sizeof(qnn_wrapper_api::GraphInfo_t)); - if (nullptr == graphsInfo || nullptr == graphInfoArr) { - QNN_ERROR("Failure to allocate memory for *graphInfo"); - returnStatus = false; - } - if (true == returnStatus) { - for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) { - QNN_DEBUG("Extracting graphsInfo for graph Idx: %d", gIdx); - if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { - copyGraphsInfoV1(&graphsInput[gIdx].graphInfoV1, &graphInfoArr[gIdx]); - } - graphsInfo[gIdx] = graphInfoArr + gIdx; + QNN_FUNCTION_ENTRY_LOG; + if (!graphsInput) { + QNN_ERROR("Received nullptr for graphsInput."); + return false; } - } - if (true != returnStatus) { - QNN_ERROR("Received an ERROR during extractGraphsInfo. Freeing resources."); - if (graphsInfo) { - for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) { - if (graphsInfo[gIdx]) { - if (nullptr != graphsInfo[gIdx]->graphName) { - free(graphsInfo[gIdx]->graphName); - graphsInfo[gIdx]->graphName = nullptr; - } - qnn_wrapper_api::freeQnnTensors(graphsInfo[gIdx]->inputTensors, - graphsInfo[gIdx]->numInputTensors); - qnn_wrapper_api::freeQnnTensors(graphsInfo[gIdx]->outputTensors, - graphsInfo[gIdx]->numOutputTensors); + auto returnStatus = true; + graphsInfo = + (qnn_wrapper_api::GraphInfo_t **)calloc(numGraphs, sizeof(qnn_wrapper_api::GraphInfo_t *)); + qnn_wrapper_api::GraphInfo_t *graphInfoArr = + (qnn_wrapper_api::GraphInfo_t *)calloc(numGraphs, sizeof(qnn_wrapper_api::GraphInfo_t)); + if (nullptr == graphsInfo || nullptr == graphInfoArr) { + QNN_ERROR("Failure to allocate memory for *graphInfo"); + returnStatus = false; + } + if (true == returnStatus) { + for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) { + QNN_DEBUG("Extracting graphsInfo for graph Idx: %d", gIdx); + if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + copyGraphsInfoV1(&graphsInput[gIdx].graphInfoV1, &graphInfoArr[gIdx]); + } else if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + copyGraphsInfoV3(&graphsInput[gIdx].graphInfoV3, &graphInfoArr[gIdx]); + } + graphsInfo[gIdx] = graphInfoArr + gIdx; } - } - free(*graphsInfo); } - free(graphsInfo); - graphsInfo = nullptr; - } - QNN_FUNCTION_EXIT_LOG; - return true; + if (true != returnStatus) { + QNN_ERROR("Received an ERROR during extractGraphsInfo. Freeing resources."); + if (graphsInfo) { + for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) { + if (graphsInfo[gIdx]) { + if (nullptr != graphsInfo[gIdx]->graphName) { + free(graphsInfo[gIdx]->graphName); + graphsInfo[gIdx]->graphName = nullptr; + } + qnn_wrapper_api::freeQnnTensors(graphsInfo[gIdx]->inputTensors, + graphsInfo[gIdx]->numInputTensors); + qnn_wrapper_api::freeQnnTensors(graphsInfo[gIdx]->outputTensors, + graphsInfo[gIdx]->numOutputTensors); + } + } + free(*graphsInfo); + } + free(graphsInfo); + graphsInfo = nullptr; + } + QNN_FUNCTION_EXIT_LOG; + return true; } bool sample_app::copyMetadataToGraphsInfo(const QnnSystemContext_BinaryInfo_t *binaryInfo, qnn_wrapper_api::GraphInfo_t **&graphsInfo, uint32_t &graphsCount) { - if (nullptr == binaryInfo) { - QNN_ERROR("binaryInfo is nullptr."); - return false; - } - graphsCount = 0; - if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { - if (binaryInfo->contextBinaryInfoV1.graphs) { - if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV1.graphs, - binaryInfo->contextBinaryInfoV1.numGraphs, - graphsInfo)) { - QNN_ERROR("Failed while copying graphs Info."); + if (nullptr == binaryInfo) { + QNN_ERROR("binaryInfo is nullptr."); return false; - } - graphsCount = binaryInfo->contextBinaryInfoV1.numGraphs; - return true; } - } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { - if (binaryInfo->contextBinaryInfoV2.graphs) { - if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV2.graphs, - binaryInfo->contextBinaryInfoV2.numGraphs, - graphsInfo)) { - QNN_ERROR("Failed while copying graphs Info."); - return false; - } - graphsCount = binaryInfo->contextBinaryInfoV2.numGraphs; - return true; + graphsCount = 0; + if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + if (binaryInfo->contextBinaryInfoV1.graphs) { + if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV1.graphs, + binaryInfo->contextBinaryInfoV1.numGraphs, + graphsInfo)) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + graphsCount = binaryInfo->contextBinaryInfoV1.numGraphs; + return true; + } + } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + if (binaryInfo->contextBinaryInfoV2.graphs) { + if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV2.graphs, + binaryInfo->contextBinaryInfoV2.numGraphs, + graphsInfo)) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + graphsCount = binaryInfo->contextBinaryInfoV2.numGraphs; + return true; + } + } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + if (binaryInfo->contextBinaryInfoV3.graphs) { + if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV3.graphs, + binaryInfo->contextBinaryInfoV3.numGraphs, + graphsInfo)) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + graphsCount = binaryInfo->contextBinaryInfoV3.numGraphs; + return true; + } } - } - QNN_ERROR("Unrecognized system context binary info version."); - return false; + QNN_ERROR("Unrecognized system context binary info version."); + return false; } QnnLog_Level_t sample_app::parseLogLevel(std::string logLevelString) { - QNN_FUNCTION_ENTRY_LOG; - std::transform(logLevelString.begin(), logLevelString.end(), logLevelString.begin(), ::tolower); - QnnLog_Level_t parsedLogLevel = QNN_LOG_LEVEL_MAX; - if (logLevelString == "error") { - parsedLogLevel = QNN_LOG_LEVEL_ERROR; - } else if (logLevelString == "warn") { - parsedLogLevel = QNN_LOG_LEVEL_WARN; - } else if (logLevelString == "info") { - parsedLogLevel = QNN_LOG_LEVEL_INFO; - } else if (logLevelString == "verbose") { - parsedLogLevel = QNN_LOG_LEVEL_VERBOSE; - } else if (logLevelString == "debug") { - parsedLogLevel = QNN_LOG_LEVEL_DEBUG; - } - QNN_FUNCTION_EXIT_LOG; - return parsedLogLevel; + QNN_FUNCTION_ENTRY_LOG; + std::transform(logLevelString.begin(), logLevelString.end(), logLevelString.begin(), ::tolower); + QnnLog_Level_t parsedLogLevel = QNN_LOG_LEVEL_MAX; + if (logLevelString == "error") { + parsedLogLevel = QNN_LOG_LEVEL_ERROR; + } else if (logLevelString == "warn") { + parsedLogLevel = QNN_LOG_LEVEL_WARN; + } else if (logLevelString == "info") { + parsedLogLevel = QNN_LOG_LEVEL_INFO; + } else if (logLevelString == "verbose") { + parsedLogLevel = QNN_LOG_LEVEL_VERBOSE; + } else if (logLevelString == "debug") { + parsedLogLevel = QNN_LOG_LEVEL_DEBUG; + } + QNN_FUNCTION_EXIT_LOG; + return parsedLogLevel; +} + +unsigned int sample_app::parseNumInferences(std::string numString) { + unsigned int num = 0; + std::stringstream numStream; + numStream << numString; + numStream >> num; + return num; } diff --git a/src/backends/qnn/Utils/QnnSampleAppUtils.hpp b/src/backends/qnn/Utils/QnnSampleAppUtils.hpp index d9f223230..4fcbfc148 100644 --- a/src/backends/qnn/Utils/QnnSampleAppUtils.hpp +++ b/src/backends/qnn/Utils/QnnSampleAppUtils.hpp @@ -1,7 +1,7 @@ //============================================================================== // -// Copyright (c) 2019-2022 Qualcomm Technologies, Inc. -// All Rights Reserved. +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. // Confidential and Proprietary - Qualcomm Technologies, Inc. // //============================================================================== @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include #include #include @@ -20,16 +22,27 @@ namespace qnn { namespace tools { namespace sample_app { -enum class ProfilingLevel { OFF, BASIC, DETAILED, INVALID }; +enum class ProfilingLevel { OFF, + BASIC, + DETAILED, + INVALID }; -using ReadInputListRetType_t = std::tuple>, bool>; +using ReadInputListRetType_t = std:: + tuple>, std::unordered_map, bool>; ReadInputListRetType_t readInputList(std::string inputFileListPath); -using ReadInputListsRetType_t = std::tuple>>, bool>; +using ReadInputListsRetType_t = std::tuple>>, + std::vector>, + bool>; ReadInputListsRetType_t readInputLists(std::vector inputFileListPath); +std::unordered_map extractInputNameIndices(const std::string &inputLine, + const std::string &separator); + +std::string sanitizeTensorName(std::string name); + ProfilingLevel parseProfilingLevel(std::string profilingLevelString); void parseInputFilePaths(std::vector &inputFilePaths, @@ -51,6 +64,9 @@ bool copyGraphsInfo(const QnnSystemContext_GraphInfo_t *graphsInput, bool copyGraphsInfoV1(const QnnSystemContext_GraphInfoV1_t *graphInfoSrc, qnn_wrapper_api::GraphInfo_t *graphInfoDst); +bool copyGraphsInfoV3(const QnnSystemContext_GraphInfoV3_t *graphInfoSrc, + qnn_wrapper_api::GraphInfo_t *graphInfoDst); + bool copyTensorsInfo(const Qnn_Tensor_t *tensorsInfoSrc, Qnn_Tensor_t *&tensorWrappers, uint32_t tensorsCount); @@ -59,11 +75,13 @@ bool deepCopyQnnTensorInfo(Qnn_Tensor_t *dst, const Qnn_Tensor_t *src); QnnLog_Level_t parseLogLevel(std::string logLevelString); +unsigned int parseNumInferences(std::string numString); + void inline exitWithMessage(std::string &&msg, int code) { - std::cerr << msg << std::endl; - std::exit(code); + std::cerr << msg << std::endl; + std::exit(code); } -} // namespace sample_app -} // namespace tools -} // namespace qnn \ No newline at end of file +} +} +} // namespace qnn::tools::sample_app \ No newline at end of file diff --git a/src/backends/qnn/WrapperUtils/QnnWrapperUtils.cpp b/src/backends/qnn/WrapperUtils/QnnWrapperUtils.cpp index b70180308..afbec492d 100644 --- a/src/backends/qnn/WrapperUtils/QnnWrapperUtils.cpp +++ b/src/backends/qnn/WrapperUtils/QnnWrapperUtils.cpp @@ -1,198 +1,55 @@ //============================================================================== // -// Copyright (c) 2019-2022 Qualcomm Technologies, Inc. -// All Rights Reserved. +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. // Confidential and Proprietary - Qualcomm Technologies, Inc. // //============================================================================== -#include -#include -#include +#include -#include "QnnModelPal.hpp" #include "QnnTypeMacros.hpp" #include "QnnWrapperUtils.hpp" -namespace qnn_wrapper_api { -size_t memscpy(void *dst, size_t dstSize, const void *src, size_t copySize) { - if (!dst || !src || !dstSize || !copySize) return 0; - - size_t minSize = dstSize < copySize ? dstSize : copySize; - - memcpy(dst, src, minSize); - - return minSize; -} - -ModelError_t getQnnGraphConfigFromInfo(const char *graphName, - const GraphConfigInfo_t **graphsConfigInfo, - const uint32_t numGraphsConfigInfo, - const QnnGraph_Config_t **&graphConfigs) { - if (!graphsConfigInfo || numGraphsConfigInfo == 0) { - PRINT_DEBUG("getQnnGraphConfigFromInfo() no custom configs passed for graph:%s.\n", graphName); - return MODEL_NO_ERROR; - } - - size_t found = 0; - - for (uint32_t i = 0; i < numGraphsConfigInfo; i++) { - if (!graphsConfigInfo[i]) { - PRINT_ERROR( - "getQnnGraphConfigFromInfo() lookup error while trying to query graphName:%s. " - "numGraphsConfigInfo > num of element in graphsConfigInfo\n", - graphName); - return MODEL_INVALID_ARGUMENT_ERROR; +qnn_wrapper_api::ModelError_t qnn_wrapper_api::freeQnnTensor(Qnn_Tensor_t &tensor) { + // free all pointer allocations in struct + free((void *)QNN_TENSOR_GET_NAME(tensor)); + free(QNN_TENSOR_GET_DIMENSIONS(tensor)); + if (QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor)) { + free(QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor)); } - if (strcmp(graphsConfigInfo[i]->graphName, graphName) == 0) { - graphConfigs = graphsConfigInfo[i]->graphConfigs; - found++; + auto quant = QNN_TENSOR_GET_QUANT_PARAMS(tensor); + auto encoding = quant.quantizationEncoding; + if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + if (quant.axisScaleOffsetEncoding.scaleOffset != nullptr) { + free(quant.axisScaleOffsetEncoding.scaleOffset); + } } - } - - if (!found) { - PRINT_ERROR( - "getQnnGraphConfigFromInfo() unable to find graphName:%s in provided " - "graphsConfigInfo object.\n", - graphName); - return MODEL_INVALID_ARGUMENT_ERROR; - } else if (found > 1) { - PRINT_ERROR( - "getQnnGraphConfigFromInfo() duplicate GraphConfigInfo entries found with " - "graphName:%s.\n", - graphName); - return MODEL_INVALID_ARGUMENT_ERROR; - } else { return MODEL_NO_ERROR; - } } -ModelError_t deepCopyQnnTensors(Qnn_Tensor_t &src, Qnn_Tensor_t &dst) { - ModelError_t err; - VALIDATE_TENSOR_VERSION(src, err); - - dst.version = src.version; - QNN_TENSOR_SET_NAME( - dst, strnDup(QNN_TENSOR_GET_NAME(src), std::string(QNN_TENSOR_GET_NAME(src)).size())); - if (QNN_TENSOR_GET_NAME(dst) == nullptr) { - return MODEL_TENSOR_ERROR; - } - QNN_TENSOR_SET_ID(dst, QNN_TENSOR_GET_ID(src)); - QNN_TENSOR_SET_TYPE(dst, QNN_TENSOR_GET_TYPE(src)); - QNN_TENSOR_SET_DATA_FORMAT(dst, QNN_TENSOR_GET_DATA_FORMAT(src)); - QNN_TENSOR_SET_DATA_TYPE(dst, QNN_TENSOR_GET_DATA_TYPE(src)); - QNN_TENSOR_SET_MEM_TYPE(dst, QNN_TENSOR_GET_MEM_TYPE(src)); - - // Only metadata (i.e. non-static data) is copied from source to destination. The union still - // must be initialized so that the clientBuf/memHandle do not contain garbage data - if (QNN_TENSOR_GET_MEM_TYPE(src) == QNN_TENSORMEMTYPE_RAW) { - Qnn_ClientBuffer_t clientBuf = {nullptr, 0}; - QNN_TENSOR_SET_CLIENT_BUF(dst, clientBuf); - } else if (QNN_TENSOR_GET_MEM_TYPE(src) == QNN_TENSORMEMTYPE_MEMHANDLE) { - QNN_TENSOR_SET_MEM_HANDLE(dst, nullptr); - } else { - return MODEL_TENSOR_ERROR; - } - - Qnn_QuantizeParams_t srcQParam = QNN_TENSOR_GET_QUANT_PARAMS(src); - Qnn_QuantizationEncoding_t encoding = srcQParam.quantizationEncoding; - if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - // need to allocate and copy memory for scaleOffset as it is a pointer array - Qnn_QuantizeParams_t srcQParamCpy = srcQParam; - Qnn_AxisScaleOffset_t &axisScaleOffset = srcQParamCpy.axisScaleOffsetEncoding; - Qnn_ScaleOffset_t **scaleOffset = &axisScaleOffset.scaleOffset; - size_t scaleOffsetSize = axisScaleOffset.numScaleOffsets * sizeof(Qnn_ScaleOffset_t); - *scaleOffset = (Qnn_ScaleOffset_t *)malloc(scaleOffsetSize); - memscpy(*scaleOffset, - scaleOffsetSize, - srcQParam.axisScaleOffsetEncoding.scaleOffset, - scaleOffsetSize); - QNN_TENSOR_SET_QUANT_PARAMS(dst, srcQParamCpy); - } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { - // need to allocate and copy memory for scaleOffset as it is a pointer array - Qnn_QuantizeParams_t srcQParamCpy = srcQParam; - Qnn_BwAxisScaleOffset_t &bwAxisScaleOffset = srcQParamCpy.bwAxisScaleOffsetEncoding; - size_t scaleSize = bwAxisScaleOffset.numElements * sizeof(float); - float **scales = &bwAxisScaleOffset.scales; - int32_t **offsets = &bwAxisScaleOffset.offsets; - *scales = (float *)malloc(scaleSize); - memscpy(*scales, scaleSize, srcQParam.bwAxisScaleOffsetEncoding.scales, scaleSize); - - // Only copy offsets if present, nullptr implies all offsets are 0 - if (bwAxisScaleOffset.offsets != nullptr) { - size_t offsetSize = bwAxisScaleOffset.numElements * sizeof(int32_t); - *offsets = (int32_t *)malloc(offsetSize); - memscpy(*offsets, offsetSize, srcQParam.bwAxisScaleOffsetEncoding.offsets, offsetSize); +qnn_wrapper_api::ModelError_t qnn_wrapper_api::freeQnnTensors(Qnn_Tensor_t *&tensors, + uint32_t numTensors) { + // free all pointer allocations in struct + for (size_t i = 0; i < numTensors; i++) { + freeQnnTensor(tensors[i]); } - QNN_TENSOR_SET_QUANT_PARAMS(dst, srcQParamCpy); - } else { - QNN_TENSOR_SET_QUANT_PARAMS(dst, srcQParam); - } - - // need to allocate and copy memory for all the pointer members - uint32_t rank = QNN_TENSOR_GET_RANK(src); - QNN_TENSOR_SET_RANK(dst, rank); - size_t dimSize = rank * sizeof(uint32_t); - uint32_t *dimensions = (uint32_t *)malloc(dimSize); - if (dimensions == nullptr) { - PRINT_ERROR("deepCopyQnnTensors() Allocation error while copying tensor %s", - QNN_TENSOR_GET_NAME(src)); - return MODEL_TENSOR_ERROR; - } - memscpy(dimensions, dimSize, QNN_TENSOR_GET_DIMENSIONS(src), dimSize); - QNN_TENSOR_SET_DIMENSIONS(dst, dimensions); - - return err; -} - -ModelError_t freeQnnTensor(Qnn_Tensor_t &tensor) { - ModelError_t err; - VALIDATE_TENSOR_VERSION(tensor, err); - - // free all pointer allocations in struct - free((void *)QNN_TENSOR_GET_NAME(tensor)); - free(QNN_TENSOR_GET_DIMENSIONS(tensor)); - - return MODEL_NO_ERROR; -} - -ModelError_t freeQnnTensors(Qnn_Tensor_t *&tensors, uint32_t numTensors) { - // free all pointer allocations in struct - for (size_t i = 0; i < numTensors; i++) { - freeQnnTensor(tensors[i]); - } - free(tensors); - - return MODEL_NO_ERROR; + free(tensors); + return MODEL_NO_ERROR; } -std::string getModelErrorName(ModelError_t modelError) { - switch (modelError) { - case MODEL_NO_ERROR: - return "MODEL_NO_ERROR"; - case MODEL_TENSOR_ERROR: - return "MODEL_TENSOR_ERROR"; - case MODEL_PARAMS_ERROR: - return "MODEL_PARAMS_ERROR"; - case MODEL_NODES_ERROR: - return "MODEL_NODES_ERROR"; - case MODEL_GRAPH_ERROR: - return "MODEL_GRAPH_ERROR"; - case MODEL_CONTEXT_ERROR: - return "MODEL_CONTEXT_ERROR"; - case MODEL_GENERATION_ERROR: - return "MODEL_GENERATION_ERROR"; - case MODEL_SETUP_ERROR: - return "MODEL_SETUP_ERROR"; - case MODEL_UNKNOWN_ERROR: - return "MODEL_UNKNOWN_ERROR"; - case MODEL_INVALID_ARGUMENT_ERROR: - return "MODEL_INVALID_ARGUMENT_ERROR"; - case MODEL_FILE_ERROR: - return "MODEL_FILE_ERROR"; - default: - return "INVALID_ERROR_CODE"; - } +qnn_wrapper_api::ModelError_t qnn_wrapper_api::freeGraphsInfo(GraphInfoPtr_t **graphsInfo, + uint32_t numGraphs) { + if (graphsInfo == nullptr || *graphsInfo == nullptr) { + return MODEL_TENSOR_ERROR; + } + for (uint32_t i = 0; i < numGraphs; i++) { + free((*graphsInfo)[i]->graphName); + freeQnnTensors((*graphsInfo)[i]->inputTensors, (*graphsInfo)[i]->numInputTensors); + freeQnnTensors((*graphsInfo)[i]->outputTensors, (*graphsInfo)[i]->numOutputTensors); + } + free(**graphsInfo); + free(*graphsInfo); + *graphsInfo = nullptr; + return MODEL_NO_ERROR; } - -} // namespace qnn_wrapper_api diff --git a/src/backends/qnn/WrapperUtils/QnnWrapperUtils.hpp b/src/backends/qnn/WrapperUtils/QnnWrapperUtils.hpp index a51e3f0e8..d82126a28 100644 --- a/src/backends/qnn/WrapperUtils/QnnWrapperUtils.hpp +++ b/src/backends/qnn/WrapperUtils/QnnWrapperUtils.hpp @@ -1,7 +1,7 @@ //============================================================================== // -// Copyright (c) 2019-2022 Qualcomm Technologies, Inc. -// All Rights Reserved. +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. // Confidential and Proprietary - Qualcomm Technologies, Inc. // //============================================================================== @@ -12,7 +12,6 @@ #include "QnnGraph.h" #include "QnnTensor.h" #include "QnnTypes.h" -#include namespace qnn_wrapper_api { @@ -21,159 +20,116 @@ namespace qnn_wrapper_api { // Enables FILE[LINE]: FMT for VALIDATE macro #ifdef QNN_ENABLE_DEBUG -#define PRINTF(fmt, ...) \ - do { \ - printf("%s[%d]: ", __FILE__, __LINE__); \ - printf((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINTF(fmt, ...) \ + do { \ + printf("%s[%d]: ", __FILE__, __LINE__); \ + printf((fmt), ##__VA_ARGS__); \ + } while (0) #else -#define PRINTF(fmt, ...) \ - do { \ - printf((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINTF(fmt, ...) \ + do { \ + printf((fmt), ##__VA_ARGS__); \ + } while (0) #endif #ifdef QNN_ENABLE_DEBUG -#define PRINT_DEBUG(fmt, ...) \ - do { \ - printf("[ DEBUG ] "); \ - PRINTF((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINT_DEBUG(fmt, ...) \ + do { \ + printf("[ DEBUG ] "); \ + PRINTF((fmt), ##__VA_ARGS__); \ + } while (0) #else #define PRINT_DEBUG(fmt, ...) #endif // Enables ERROR tag for errors -#define PRINT_ERROR(fmt, ...) \ - do { \ - printf("[ ERROR ] "); \ - PRINTF((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINT_ERROR(fmt, ...) \ + do { \ + printf("[ ERROR ] "); \ + PRINTF((fmt), ##__VA_ARGS__); \ + } while (0) // Enables WARNING tag for errors -#define PRINT_WARNING(fmt, ...) \ - do { \ - printf("[ WARNING ] "); \ - PRINTF((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINT_WARNING(fmt, ...) \ + do { \ + printf("[ WARNING ] "); \ + PRINTF((fmt), ##__VA_ARGS__); \ + } while (0) // Enables INFO tag for errors -#define PRINT_INFO(fmt, ...) \ - do { \ - printf("[ INFO ] "); \ - PRINTF((fmt), ##__VA_ARGS__); \ - } while (0) +#define PRINT_INFO(fmt, ...) \ + do { \ + printf("[ INFO ] "); \ + PRINTF((fmt), ##__VA_ARGS__); \ + } while (0) -#define STRINGFY(str) str +#define STRINGFY(str) str #define STRINGFYVALUE(str) STRINGFY(str) // Ensures ModelError_t returning functions return MODEL_NO_ERROR // retStatus should be set to MODEL_NO_ERROR before passing to macro -#define VALIDATE(value, retStatus) \ - do { \ - retStatus = value; \ - if (retStatus != qnn_wrapper_api::MODEL_NO_ERROR) { \ - PRINT_ERROR( \ - "%s expected MODEL_NO_ERROR, got %s\n", #value, getModelErrorName(retStatus).c_str()); \ - return retStatus; \ - } \ - } while (0) +#define VALIDATE(value, retStatus) \ + do { \ + retStatus = value; \ + if (retStatus != qnn_wrapper_api::MODEL_NO_ERROR) { \ + PRINT_ERROR( \ + "%s expected MODEL_NO_ERROR, got %d\n", #value, retStatus); \ + return retStatus; \ + } \ + } while (0) // macros for retrieving binary data -#define BINVARSTART(NAME) \ - ({ \ - extern const uint8_t _binary_obj_binary_##NAME##_raw_start[]; \ - (void *)_binary_obj_binary_##NAME##_raw_start; \ - }) -#define BINVAREND(NAME) \ - ({ \ - extern const uint8_t _binary_obj_binary_##NAME##_raw_end[]; \ - (void *)_binary_obj_binary_##NAME##_raw_end; \ - }) -#define BINLEN(NAME) \ - ({ \ - extern const uint8_t _binary_obj_binary_##NAME##_raw_start[]; \ - extern const uint8_t _binary_obj_binary_##NAME##_raw_end[]; \ - (uint32_t)((_binary_obj_binary_##NAME##_raw_end) - (_binary_obj_binary_##NAME##_raw_start)); \ - }) +#define BINVARSTART(NAME) \ + ({ \ + extern const uint8_t _binary_obj_binary_##NAME##_raw_start[]; \ + (void *)_binary_obj_binary_##NAME##_raw_start; \ + }) +#define BINVAREND(NAME) \ + ({ \ + extern const uint8_t _binary_obj_binary_##NAME##_raw_end[]; \ + (void *)_binary_obj_binary_##NAME##_raw_end; \ + }) +#define BINLEN(NAME) \ + ({ \ + extern const uint8_t _binary_obj_binary_##NAME##_raw_start[]; \ + extern const uint8_t _binary_obj_binary_##NAME##_raw_end[]; \ + (uint32_t)((_binary_obj_binary_##NAME##_raw_end) - (_binary_obj_binary_##NAME##_raw_start)); \ + }) typedef enum ModelError { - MODEL_NO_ERROR = 0, - MODEL_TENSOR_ERROR = 1, - MODEL_PARAMS_ERROR = 2, - MODEL_NODES_ERROR = 3, - MODEL_GRAPH_ERROR = 4, - MODEL_CONTEXT_ERROR = 5, - MODEL_GENERATION_ERROR = 6, - MODEL_SETUP_ERROR = 7, - MODEL_INVALID_ARGUMENT_ERROR = 8, - MODEL_FILE_ERROR = 9, - MODEL_MEMORY_ALLOCATE_ERROR = 10, - // Value selected to ensure 32 bits. - MODEL_UNKNOWN_ERROR = 0x7FFFFFFF + MODEL_NO_ERROR = 0, + MODEL_TENSOR_ERROR = 1, + MODEL_PARAMS_ERROR = 2, + MODEL_NODES_ERROR = 3, + MODEL_GRAPH_ERROR = 4, + MODEL_CONTEXT_ERROR = 5, + MODEL_GENERATION_ERROR = 6, + MODEL_SETUP_ERROR = 7, + MODEL_INVALID_ARGUMENT_ERROR = 8, + MODEL_FILE_ERROR = 9, + MODEL_MEMORY_ALLOCATE_ERROR = 10, + // Value selected to ensure 32 bits. + MODEL_UNKNOWN_ERROR = 0x7FFFFFFF } ModelError_t; -/** - * @brief Returns the error message associated with a given error code - * - * @param[in] modelError ModelError_t error code - * - * @return string message - */ -std::string getModelErrorName(ModelError_t modelError); - typedef struct GraphInfo { - Qnn_GraphHandle_t graph; - char *graphName; - Qnn_Tensor_t *inputTensors; - uint32_t numInputTensors; - Qnn_Tensor_t *outputTensors; - uint32_t numOutputTensors; + Qnn_GraphHandle_t graph; + char *graphName; + Qnn_Tensor_t *inputTensors; + uint32_t numInputTensors; + Qnn_Tensor_t *outputTensors; + uint32_t numOutputTensors; } GraphInfo_t; typedef GraphInfo_t *GraphInfoPtr_t; typedef struct GraphConfigInfo { - char *graphName; - const QnnGraph_Config_t **graphConfigs; + char *graphName; + const QnnGraph_Config_t **graphConfigs; } GraphConfigInfo_t; -/** - * @brief Helper function to get Qnn GraphConfig structure from provided GraphConfigInfo using - * graphName. - * - * @param[in] graphName the Qnn graphName to use for lookup - * - * @param[in] graphsConfigInfo array of GraphConfig_t objects - * - * @param[in] numGraphsConfigInfo the number of array elements in graphConfigInfo - * - * @param[out] graphConfigs the result of query of graphName from graphsConfigInfo if successful. - * - * @return Error code - * - */ -ModelError_t getQnnGraphConfigFromInfo(const char *graphName, - const GraphConfigInfo_t **graphsConfigInfo, - const uint32_t numGraphsConfigInfo, - const QnnGraph_Config_t **&graphConfigs); - -/** - * @brief Deep Copies QnnTensor_t structs to a pointer array destination location. - * Note: The copy will be stored on the heap and as such requires caller to make - * appropriate free call(s) using function below. - * Note 2: deepCopy is only done for metadata - * - * @param[in] source tensor object to copy from - * - * @param[in] destination tensor object to copy to - * - * @return Error code - */ -ModelError_t deepCopyQnnTensors(Qnn_Tensor_t &source, Qnn_Tensor_t &destination); - /** * @brief Frees all memory allocated tensor attributes. * @@ -195,6 +151,16 @@ ModelError_t freeQnnTensor(Qnn_Tensor_t &tensor); */ ModelError_t freeQnnTensors(Qnn_Tensor_t *&tensors, uint32_t numTensors); -size_t memscpy(void *dst, size_t dstSize, const void *src, size_t copySize); +/** + * @brief A helper function to free memory malloced for communicating the Graph for a model(s) + * + * @param[in] graphsInfo Pointer pointing to location of graph objects + * + * @param[in] numGraphs The number of graph objects the above pointer is pointing to + * + * @return Error code + * + */ +ModelError_t freeGraphsInfo(GraphInfoPtr_t **graphsInfo, uint32_t numGraphs); -} // namespace qnn_wrapper_api +} // namespace qnn_wrapper_api diff --git a/src/backends/qnn/op/QNNCommonOp.cpp b/src/backends/qnn/op/QNNCommonOp.cpp index e448429a9..1196f4601 100644 --- a/src/backends/qnn/op/QNNCommonOp.cpp +++ b/src/backends/qnn/op/QNNCommonOp.cpp @@ -29,26 +29,31 @@ ErrorCode QNNCommonOp::graphAddNode(string name, string nodeType, vector(output->sequence()); } - // TODO tensor type = MLLM_TYPE_I8 - auto data_type = QNN_DATATYPE_FLOAT_32; - if (output->dtype() == MLLM_TYPE_I8) { - data_type = QNN_DATATYPE_SFIXED_POINT_8; - } - - if (output->dtype() == MLLM_TYPE_F16) { - data_type = QNN_DATATYPE_FLOAT_16; - } - - float quantScale = 0.0f; auto quantDefine = QNN_DEFINITION_UNDEFINED; auto quantType = QNN_QUANTIZATION_ENCODING_UNDEFINED; - - if (scale != nullptr) { - quantScale = scale->hostPtr()[0] / 127.0; - quantScale = roundf(quantScale * 100000) / 100000; + auto data_type = QNN_DATATYPE_FLOAT_32; + switch (output->dtype()) { + case MLLM_TYPE_I8: + data_type = QNN_DATATYPE_SFIXED_POINT_8; + quantScale = scale->hostPtr()[0] / (pow(2, 7) - 1); + // quantScale = roundf(quantScale * 100000) / 100000; quantDefine = QNN_DEFINITION_DEFINED; quantType = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + break; + case MLLM_TYPE_I16: + data_type = QNN_DATATYPE_SFIXED_POINT_16; + quantScale = scale->hostPtr()[0] / (pow(2, 15) - 1); + // quantScale = roundf(quantScale * 100000) / 100000; + quantDefine = QNN_DEFINITION_DEFINED; + quantType = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + break; + case MLLM_TYPE_F16: + data_type = QNN_DATATYPE_FLOAT_16; + break; + default: + data_type = QNN_DATATYPE_FLOAT_32; + break; } inputTensorNames_.push_back(new string(output->name())); @@ -66,7 +71,7 @@ ErrorCode QNNCommonOp::graphAddNode(string name, string nodeType, vectorgraphAddNode(name, nodeType, inputTensorNames, outputTensors, params, packageName)) { @@ -87,7 +92,7 @@ ErrorCode QNNCommonOp::graphAddNode(string name, string nodeType, vector Qnn_TensorType_t QNNCommonOp::getOutputTensorType(shared_ptr tensor) const { if (tensor->ttype() == GRAPH_OUTPUT) { // in Module API, the outputs of a graph is not allocated before setUp, alloc here - if(tensor->allocted() == 0) { + if (tensor->allocted() == 0) { tensor->alloc(); } qnnBackend_->pushOutputBuffers(tensor->hostPtr()); diff --git a/src/backends/qnn/op/QNNDequantize.cpp b/src/backends/qnn/op/QNNDequantize.cpp index 569deb73d..699a51666 100644 --- a/src/backends/qnn/op/QNNDequantize.cpp +++ b/src/backends/qnn/op/QNNDequantize.cpp @@ -6,11 +6,13 @@ #include namespace mllm { -QNNDequantize::QNNDequantize(Backend *bn, string opName, bool isNSHD, bool isFP32) : +QNNDequantize::QNNDequantize(Backend *bn, string opName, bool isNSHD, bool isFP32, DataType type) : QNNCommonOp(bn, opName) { isNSHD_ = isNSHD; isFP32_ = isFP32; + activation_dtype_ = type; scale_.setBackend(bn); + bias_.setBackend(bn); } ErrorCode QNNDequantize::reshape(vector> inputs, vector> outputs) { @@ -20,7 +22,6 @@ ErrorCode QNNDequantize::reshape(vector> inputs, vector> inputs, vector> outputs) { - auto outName = outputs[0]->name(); uint32_t dimensionsOutput[4]; @@ -37,94 +38,226 @@ ErrorCode QNNDequantize::setUp(vector> inputs, vector()[0] / 127.0; - dequantScale = roundf(dequantScale * 100000) / 100000; - - if (name().find("q_proj") != -1) { - dequantScale = dequantScale / std::sqrt(outputs[0]->dimension()); + switch (activation_dtype_) { + case MLLM_TYPE_I8: + dequantScale = scale_.hostPtr()[0] / (pow(2, 7) - 1); + break; + case MLLM_TYPE_I16: + dequantScale = scale_.hostPtr()[0] / (pow(2, 15) - 1); + break; + default: + return NOT_SUPPORT; } + // dequantScale = roundf(dequantScale * 100000) / 100000; + + // if (name().find("q_proj") != -1) { + // dequantScale = dequantScale / std::sqrt(outputs[0]->dimension()); + // } + + if (name().find("q_proj") != -1 || name().find("k_proj") != -1 || name().find("v_proj") != -1 ) { + if (isFP32_) { + uint32_t paramsDeQuantizeDimension[1] = {1}; + auto paramsDeQuantizeName = name() + "dequantize_params"; + + vector paramsDeQuantize = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "scale", + .tensorParam = + (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = paramsDeQuantizeName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}}, + .rank = 1, + .dimensions = paramsDeQuantizeDimension, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = (uint8_t *)&dequantScale, + .dataSize = sizeof(float)}}}}}; + + uint32_t dimensionsBias[4] = {1, 1, 1, static_cast(bias_.dimension())}; + qnnBackend_->modelAddTensor(bias_.name(), (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = bias_.name().c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .rank = 4, + .dimensions = dimensionsBias, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = bias_.hostPtr(), + .dataSize = (uint32_t)bias_.cntSize()}}}); + + vector outputTensor = {{.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}; + return graphAddNode(name(), "LLaMADequantizeAdd", {inputs[0]->name(), bias_.name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + } else { + outputs[0]->setDtype(MLLM_TYPE_F16); + uint32_t paramsDeQuantizeDimension[1] = {1}; + auto paramsDeQuantizeName = name() + "dequantize_params"; - if (isFP32_) { - uint32_t paramsDeQuantizeDimension[1] = {1}; - auto paramsDeQuantizeName = name() + "dequantize_params"; - vector paramsDeQuantize = { - {.paramType = QNN_PARAMTYPE_TENSOR, - .name = "scale", - .tensorParam = - (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, - .v1 = { - .id = 0, - .name = paramsDeQuantizeName.c_str(), - .type = QNN_TENSOR_TYPE_STATIC, - .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, - .dataType = QNN_DATATYPE_FLOAT_32, - .quantizeParams = {QNN_DEFINITION_UNDEFINED, - QNN_QUANTIZATION_ENCODING_UNDEFINED, - {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}}, - .rank = 1, - .dimensions = paramsDeQuantizeDimension, - .memType = QNN_TENSORMEMTYPE_RAW, - .clientBuf = {.data = (uint8_t *)&dequantScale, - .dataSize = sizeof(float)}}}}}; - - vector outputTensor = {{.version = QNN_TENSOR_VERSION_1, - .v1 = { - .id = 0, - .name = outName.c_str(), - .type = getOutputTensorType(outputs[0]), - .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, - .dataType = QNN_DATATYPE_FLOAT_32, - .quantizeParams = {QNN_DEFINITION_DEFINED, - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, - {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, - .rank = 4, - .dimensions = dimensionsOutput, - .memType = QNN_TENSORMEMTYPE_RAW, - .clientBuf = {.data = nullptr, - .dataSize = 0}}}}; - return graphAddNode(name(), "LLaMADequantize", {inputs[0]->name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + vector paramsDeQuantize = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "scale", + .tensorParam = + (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = paramsDeQuantizeName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, + .offset = 0}}}, + .rank = 1, + .dimensions = paramsDeQuantizeDimension, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = (uint8_t *)&dequantScale, + .dataSize = sizeof(float)}}}}}; + + + uint32_t dimensionsBias[4] = {1, 1, 1, static_cast(bias_.dimension())}; + qnnBackend_->modelAddTensor(bias_.name(), (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = bias_.name().c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .rank = 4, + .dimensions = dimensionsBias, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = bias_.hostPtr(), + .dataSize = (uint32_t)bias_.cntSize()}}}); + + vector outputTensor = {{QNN_TENSOR_VERSION_1, + {.v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_16, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}}; + return graphAddNode(name(), "LLaMADequantizeAdd", {inputs[0]->name(), bias_.name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + } } else { - outputs[0]->setDtype(MLLM_TYPE_F16); - uint32_t paramsDeQuantizeDimension[1] = {1}; - auto paramsDeQuantizeName = name() + "dequantize_params"; - vector paramsDeQuantize = { - {.paramType = QNN_PARAMTYPE_TENSOR, - .name = "scale", - .tensorParam = - (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, - .v1 = { - .id = 0, - .name = paramsDeQuantizeName.c_str(), - .type = QNN_TENSOR_TYPE_STATIC, - .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, - .dataType = QNN_DATATYPE_FLOAT_32, - .quantizeParams = {QNN_DEFINITION_UNDEFINED, - QNN_QUANTIZATION_ENCODING_UNDEFINED, - {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, - .offset = 0}}}, - .rank = 1, - .dimensions = paramsDeQuantizeDimension, - .memType = QNN_TENSORMEMTYPE_RAW, - .clientBuf = {.data = (uint8_t *)&dequantScale, - .dataSize = sizeof(float)}}}}}; - - vector outputTensor = {{QNN_TENSOR_VERSION_1, - {.v1 = { - .id = 0, - .name = outName.c_str(), - .type = getOutputTensorType(outputs[0]), - .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, - .dataType = QNN_DATATYPE_FLOAT_16, - .quantizeParams = {QNN_DEFINITION_DEFINED, - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, - {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, - .rank = 4, - .dimensions = dimensionsOutput, - .memType = QNN_TENSORMEMTYPE_RAW, - .clientBuf = {.data = nullptr, - .dataSize = 0}}}}}; - return graphAddNode(name(), "LLaMADequantize", {inputs[0]->name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + + if (isFP32_) { + uint32_t paramsDeQuantizeDimension[1] = {1}; + auto paramsDeQuantizeName = name() + "dequantize_params"; + + vector paramsDeQuantize = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "scale", + .tensorParam = + (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = paramsDeQuantizeName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, .offset = 0}}}, + .rank = 1, + .dimensions = paramsDeQuantizeDimension, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = (uint8_t *)&dequantScale, + .dataSize = sizeof(float)}}}}}; + + vector outputTensor = {{.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}; + return graphAddNode(name(), "LLaMADequantize", {inputs[0]->name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + } else { + outputs[0]->setDtype(MLLM_TYPE_F16); + uint32_t paramsDeQuantizeDimension[1] = {1}; + auto paramsDeQuantizeName = name() + "dequantize_params"; + + vector paramsDeQuantize = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "scale", + .tensorParam = + (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = paramsDeQuantizeName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, + .offset = 0}}}, + .rank = 1, + .dimensions = paramsDeQuantizeDimension, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = (uint8_t *)&dequantScale, + .dataSize = sizeof(float)}}}}}; + + vector outputTensor = {{QNN_TENSOR_VERSION_1, + {.v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_16, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = dequantScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}}; + return graphAddNode(name(), "LLaMADequantize", {inputs[0]->name()}, outputTensor, paramsDeQuantize, "LLaMAPackage"); + } + } + + } ErrorCode QNNDequantize::load(AbstructLoader &loader) { @@ -150,6 +283,38 @@ ErrorCode QNNDequantize::load(AbstructLoader &loader) { scale_.alloc(); loader.load(&scale_); + + if (name().find("q_proj") != -1 || name().find("k_proj") != -1 || name().find("v_proj") != -1 ) { + + // std::cout << name() << std::endl; + + string biasName = name(); + wordToRemove = "dequantize"; + string biasTypeName = "bias"; + + int pos = biasName.find(wordToRemove); + if (pos != -1) { + biasName.erase(pos, wordToRemove.length()); + } + + // std::cout << biasName + biasTypeName << std::endl; + + int hidden_size = 1536; + if (name().find("k_proj") != -1 || name().find("v_proj") != -1 ) + hidden_size = 256; + + bias_.setName(biasName + biasTypeName); + bias_.reshape(1, 1, 1, hidden_size); + bias_.setDtype(MLLM_TYPE_F32); + bias_.alloc(); + loader.load(&bias_); + + // bias_.printData(); + + } + + + return Op::load(loader); } } // namespace mllm diff --git a/src/backends/qnn/op/QNNDequantize.hpp b/src/backends/qnn/op/QNNDequantize.hpp index b4b78fc22..229d11db4 100644 --- a/src/backends/qnn/op/QNNDequantize.hpp +++ b/src/backends/qnn/op/QNNDequantize.hpp @@ -3,10 +3,11 @@ #define MLLM_QNNDEQUANTIZE_H #include "QNNCommonOp.hpp" +#include "Types.hpp" namespace mllm { class QNNDequantize : public QNNCommonOp { public: - QNNDequantize(Backend *bn, string opName, bool isNSHD, bool isFP32); + QNNDequantize(Backend *bn, string opName, bool isNSHD, bool isFP32, DataType type = MLLM_TYPE_I8); virtual ~QNNDequantize() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode setUp(vector> inputs, vector> outputs) override; @@ -15,12 +16,13 @@ class QNNDequantize : public QNNCommonOp { bool isNSHD_; bool isFP32_; Tensor scale_; + Tensor bias_; }; class QNNDequantizeCreator : public QNNBackend::Creator { public: virtual Op *create(OpParam op_param, Backend *bn, string name) const { - return new QNNDequantize(bn, name, (bool)op_param["isNSHD"], (bool)op_param["isFP32"]); + return new QNNDequantize(bn, name, (bool)op_param["isNSHD"], (bool)op_param["isFP32"], (DataType)op_param["inType"]); } }; diff --git a/src/backends/qnn/op/QNNIRoPE.cpp b/src/backends/qnn/op/QNNIRoPE.cpp index d78c8752a..a728aba73 100644 --- a/src/backends/qnn/op/QNNIRoPE.cpp +++ b/src/backends/qnn/op/QNNIRoPE.cpp @@ -92,7 +92,7 @@ ErrorCode QNNIRoPE::setUp(vector> inputs, vector()[0] / 127.0; - dequantScale = roundf(dequantScale * 100000) / 100000; + // dequantScale = roundf(dequantScale * 100000) / 100000; if (name().find("q_proj") != -1) { dequantScale = dequantScale / std::sqrt(outputs[0]->dimension()); diff --git a/src/backends/qnn/op/QNNLinearINT8.cpp b/src/backends/qnn/op/QNNLinearINT8.cpp index a1510a943..33b313f2a 100755 --- a/src/backends/qnn/op/QNNLinearINT8.cpp +++ b/src/backends/qnn/op/QNNLinearINT8.cpp @@ -1,5 +1,6 @@ #include "QNNLinearINT8.hpp" +#include "Backend.hpp" #include "QnnTypes.h" #include "Types.hpp" #include "QNNCommonOp.hpp" @@ -9,13 +10,12 @@ namespace mllm { QNNLinearINT8::QNNLinearINT8(Backend *bn, string opName, int in_features, int out_features, bool bias) : QNNCommonOp(bn, opName), in_features_(in_features), out_features_(out_features), support_bias_(bias) { - weight_.setBackend(bn); - bias_.setBackend(bn); + weight_.setBackend(Backend::global_backends[MLLM_CPU]); + bias_.setBackend(Backend::global_backends[MLLM_CPU]); - weightScale_.setBackend(bn); - biasScale_.setBackend(bn); - outputScale_.setBackend(bn); - inputScale_.setBackend(bn); + weightScale_.setBackend(Backend::global_backends[MLLM_CPU]); + biasScale_.setBackend(Backend::global_backends[MLLM_CPU]); + outputScale_.setBackend(Backend::global_backends[MLLM_CPU]); } ErrorCode QNNLinearINT8::reshape(vector> inputs, vector> outputs) { @@ -38,6 +38,17 @@ ErrorCode QNNLinearINT8::reshape(vector> inputs, vector> inputs, vector> outputs) { + switch (inputs[0]->dtype()) { + case MLLM_TYPE_I8: + return setUpW8A8(inputs, outputs); + case MLLM_TYPE_I16: + return setUpW8A16(inputs, outputs); + default: + return NOT_SUPPORT; + } +} + +ErrorCode QNNLinearINT8::setUpW8A8(vector> &inputs, vector> &outputs) { outputs[0]->setDtype(MLLM_TYPE_I8); // add matmul param to qnn vector paramsMatmul = { @@ -142,8 +153,8 @@ ErrorCode QNNLinearINT8::setUp(vector> inputs, vector()[0] / 127.0; - outputScale = roundf(outputScale * 100000) / 100000; + outputScale = outputScale_.hostPtr()[0] / (pow(2, 7) - 1); + // outputScale = roundf(outputScale * 100000) / 100000; vector matmulOut = {{QNN_TENSOR_VERSION_1, {.v1 = { @@ -170,6 +181,13 @@ ErrorCode QNNLinearINT8::setUp(vector> inputs, vector()[0]; + auto biasBuffer = (int8_t *)malloc(bias_.count() * sizeof(int8_t)); +#pragma omp parallel for + for (int i = 0; i < out_features_; i++) { + int32_t val = bias_.dataAt(0, 0, 0, i) + 128; + biasBuffer[i] = val; + } + qnnBackend_->modelAddTensor(bias_.name(), (Qnn_Tensor_t){ .version = QNN_TENSOR_VERSION_1, .v1 = { @@ -184,14 +202,14 @@ ErrorCode QNNLinearINT8::setUp(vector> inputs, vector(), - .dataSize = (uint32_t)bias_.cntSize()}}}); + .clientBuf = {.data = biasBuffer, + .dataSize = (uint32_t)(bias_.count() * sizeof(int8_t)) }}}); // free bias host memory bias_.free(); float outputScale = 0; - outputScale = outputScale_.hostPtr()[0] / 127.0; - outputScale = roundf(outputScale * 100000) / 100000; + outputScale = outputScale_.hostPtr()[0] / (pow(2, 7) - 1); + // outputScale = roundf(outputScale * 100000) / 100000; // final output vector biasOutput = {{QNN_TENSOR_VERSION_1, @@ -209,7 +227,191 @@ ErrorCode QNNLinearINT8::setUp(vector> inputs, vectorname(), weight_.name(), bias_.name()}, biasOutput, params_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D); + return graphAddNode(name() + ".linear_w8a8", "Conv2d", {inputs[0]->name(), weight_.name(), bias_.name()}, biasOutput, params_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D); +} + +ErrorCode QNNLinearINT8::setUpW8A16(vector> &inputs, vector> &outputs) { + outputs[0]->setDtype(MLLM_TYPE_I16); + // add matmul param to qnn + vector paramsMatmul = { + {.paramType = QNN_PARAMTYPE_SCALAR, + .name = "transpose_in0", + .scalarParam = (Qnn_Scalar_t){QNN_DATATYPE_BOOL_8, {.bool8Value = 0}}}, + {.paramType = QNN_PARAMTYPE_SCALAR, + .name = "transpose_in1", + .scalarParam = (Qnn_Scalar_t){QNN_DATATYPE_BOOL_8, {.bool8Value = 1}}}}; + + uint32_t dimensions_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_dilation[] = {2}; + uint32_t InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_dilation[] = {1, 1}; + uint32_t dimensions_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_pad_amount[] = {2, 2}; + uint32_t InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_pad_amount[] = {0, 0, 0, 0}; + uint32_t dimensions_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_stride[] = {2}; + uint32_t InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_stride[] = {1, 1}; + + string strideName = name() + ".stride"; + string padName = name() + ".pad"; + vector params_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "stride", + .tensorParam = + (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = {.id = 0, + .name = strideName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_UINT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, + .offset = 0}}}, + .rank = 1, + .dimensions = dimensions_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_stride, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = + {.data = (uint8_t *)InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_stride, + .dataSize = 8}}}}, + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "pad_amount", + .tensorParam = + (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = {.id = 0, + .name = padName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_UINT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, + .offset = 0}}}, + .rank = 2, + .dimensions = + dimensions_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_pad_amount, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = + {.data = (uint8_t *) + InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D_pad_amount, + .dataSize = 16}}}}, + + }; + + // add weight tensor to qnn + uint32_t dimensionsWeight[4] = {1, 1, static_cast(weight_.sequence()), static_cast(weight_.dimension())}; + + auto qnnQuantDefined = QNN_DEFINITION_UNDEFINED; + float weightScale = 0; + + qnnQuantDefined = QNN_DEFINITION_DEFINED; + weightScale = weightScale_.hostPtr()[0]; + + qnnBackend_->modelAddTensor(weight_.name(), (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = weight_.name().c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_SFIXED_POINT_8, + .quantizeParams = {qnnQuantDefined, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = weightScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsWeight, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = weight_.hostPtr(), + .dataSize = (uint32_t)weight_.cntSize()}}}); + // free weight host memory + weight_.free(); + + // dimensions of matmul output and bias + uint32_t dimensionsOutput[4] = {static_cast(outputs[0]->batch()), + static_cast(outputs[0]->sequence()), + static_cast(outputs[0]->head()), + static_cast(outputs[0]->dimension())}; + + auto outName = outputs[0]->name(); + + // if don't support bias, just dequantize and write to tensor with name of outputs[0] + if (!support_bias_) { + float outputScale = 0; + outputScale = outputScale_.hostPtr()[0] / (pow(2, 15) - 1); + // outputScale = roundf(outputScale * 100000) / 100000; + + vector matmulOut = {{QNN_TENSOR_VERSION_1, + {.v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_SFIXED_POINT_16, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = outputScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}}; + return graphAddNode(name() + ".linearint8", "Conv2d", {inputs[0]->name(), weight_.name()}, matmulOut, params_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D); + } + + // add bias tensor to qnn + uint32_t dimensionsBias[1] = {(uint32_t)out_features_}; + float biasScale = 0; + + qnnQuantDefined = QNN_DEFINITION_DEFINED; + biasScale = biasScale_.hostPtr()[0]; + // create a int32 buffer, convert the bias to int32 + auto biasBuffer = (int32_t *)malloc(bias_.count() * sizeof(int32_t)); +#pragma omp parallel for + for (int i = 0; i < out_features_; i++) { + // int32_t val = bias_.dataAt(0, 0, 0, i) - 128; + int32_t val = bias_.dataAt(0, 0, 0, i); + biasBuffer[i] = val; + } + + qnnBackend_->modelAddTensor(bias_.name(), (Qnn_Tensor_t){ + .version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = bias_.name().c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_SFIXED_POINT_32, + .quantizeParams = {qnnQuantDefined, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = biasScale, .offset = 0}}}, + .rank = 1, + .dimensions = dimensionsBias, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = biasBuffer, + .dataSize = (uint32_t)(bias_.count() * sizeof(int32_t))}}}); + // free bias host memory + bias_.free(); + delete biasBuffer; + + float outputScale = 0; + outputScale = outputScale_.hostPtr()[0] / (pow(2, 15) - 1); + // outputScale = roundf(outputScale * 100000) / 100000; + + // final output + vector biasOutput = {{QNN_TENSOR_VERSION_1, + {.v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_SFIXED_POINT_16, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = outputScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}}; + return graphAddNode(name() + ".linear_w8a16", "Conv2d", {inputs[0]->name(), weight_.name(), bias_.name()}, biasOutput, params_InceptionV3_InceptionV3_Conv2d_1a_3x3_Conv2D); } ErrorCode QNNLinearINT8::load(AbstructLoader &loader) { @@ -221,16 +423,16 @@ ErrorCode QNNLinearINT8::load(AbstructLoader &loader) { bias_.setName(name() + ".bias"); bias_.reshape(1, 1, 1, out_features_); - bias_.setDtype(MLLM_TYPE_I8); + bias_.setDtype(MLLM_TYPE_I32); bias_.alloc(); if (support_bias_) { loader.load(&bias_); - // sign to unsign - for (int i = 0; i < out_features_; i++) { - int32_t val = bias_.dataAt(0, 0, 0, i); - val += 128; - bias_.setDataAt(0, 0, 0, i, (uint8_t)val); - } + // // sign to unsign + // for (int i = 0; i < out_features_; i++) { + // int32_t val = bias_.dataAt(0, 0, 0, i); + // val += 128; + // bias_.setDataAt(0, 0, 0, i, (uint8_t)val); + // } } else { memset(bias_.hostPtr(), 0, bias_.cntSize()); } @@ -253,12 +455,6 @@ ErrorCode QNNLinearINT8::load(AbstructLoader &loader) { outputScale_.alloc(); loader.load(&outputScale_); - inputScale_.setName(name() + ".input_scale"); - inputScale_.reshape(1, 1, 1, 1); - inputScale_.setDtype(MLLM_TYPE_F32); - inputScale_.alloc(); - loader.load(&inputScale_); - return Op::load(loader); } diff --git a/src/backends/qnn/op/QNNLinearINT8.hpp b/src/backends/qnn/op/QNNLinearINT8.hpp index ea9395eb3..e53fd3ccb 100644 --- a/src/backends/qnn/op/QNNLinearINT8.hpp +++ b/src/backends/qnn/op/QNNLinearINT8.hpp @@ -19,12 +19,15 @@ class QNNLinearINT8 : public QNNCommonOp { bool support_bias_; Tensor weight_; Tensor bias_; -// #ifdef SMOOTHQUANT + Tensor weightScale_; Tensor biasScale_; -// #endif + Tensor outputScale_; Tensor inputScale_; + + ErrorCode setUpW8A8(vector>& inputs, vector>& outputs); + ErrorCode setUpW8A16(vector>& inputs, vector>& outputs); }; class QNNLinearINT8Creator : public QNNBackend::Creator { diff --git a/src/backends/qnn/op/QNNQuantize.cpp b/src/backends/qnn/op/QNNQuantize.cpp index cbf4937e8..00f381b95 100644 --- a/src/backends/qnn/op/QNNQuantize.cpp +++ b/src/backends/qnn/op/QNNQuantize.cpp @@ -7,9 +7,11 @@ #include namespace mllm { -QNNQuantize::QNNQuantize(Backend *bn, string opName, bool isNSHD) : +QNNQuantize::QNNQuantize(Backend *bn, string opName, DataType type, bool isNSHD) : QNNCommonOp(bn, opName) { isNSHD_ = isNSHD; + assert(type == MLLM_TYPE_I8 || type == MLLM_TYPE_I16); + activation_dtype_ = type; scale_.setBackend(bn); } @@ -20,6 +22,17 @@ ErrorCode QNNQuantize::reshape(vector> inputs, vector> inputs, vector> outputs) { + switch (activation_dtype_) { + case MLLM_TYPE_I8: + return setUpI8(inputs, outputs); + case MLLM_TYPE_I16: + return setUpI16(inputs, outputs); + default: + return NOT_SUPPORT; + } +} + +ErrorCode QNNQuantize::setUpI8(vector> &inputs, vector> &outputs) { outputs[0]->setDtype(MLLM_TYPE_I8); auto outName = outputs[0]->name(); @@ -38,8 +51,8 @@ ErrorCode QNNQuantize::setUp(vector> inputs, vector()[0] / 127.0; - quantScale = roundf(quantScale * 100000) / 100000; + quantScale = scale_.hostPtr()[0] / (pow(2, 7) - 1); + // quantScale = roundf(quantScale * 100000) / 100000; uint32_t paramsQuantizeDimension[1] = {1}; auto paramsQuantizeName = name() + "quantize_params"; @@ -81,6 +94,70 @@ ErrorCode QNNQuantize::setUp(vector> inputs, vectorname()}, outputTensor, paramsQuantize, "LLaMAPackage"); } + +ErrorCode QNNQuantize::setUpI16(vector> &inputs, vector> &outputs) { + outputs[0]->setDtype(MLLM_TYPE_I16); + auto outName = outputs[0]->name(); + + uint32_t dimensionsOutput[4]; + + if (isNSHD_) { + dimensionsOutput[0] = static_cast(outputs[0]->batch()); + dimensionsOutput[1] = static_cast(outputs[0]->sequence()); + dimensionsOutput[2] = static_cast(outputs[0]->head()); + dimensionsOutput[3] = static_cast(outputs[0]->dimension()); + } else { + dimensionsOutput[0] = static_cast(outputs[0]->batch()); + dimensionsOutput[1] = static_cast(outputs[0]->head()); + dimensionsOutput[2] = static_cast(outputs[0]->sequence()); + dimensionsOutput[3] = static_cast(outputs[0]->dimension()); + } + + float quantScale = 0; + quantScale = scale_.hostPtr()[0] / (pow(2, 15) - 1); + // quantScale = roundf(quantScale * 100000) / 100000; + + uint32_t paramsQuantizeDimension[1] = {1}; + auto paramsQuantizeName = name() + "quantize_params"; + vector paramsQuantize = { + {.paramType = QNN_PARAMTYPE_TENSOR, + .name = "scale", + .tensorParam = + (Qnn_Tensor_t){.version = QNN_TENSOR_VERSION_1, + .v1 = { + .id = 0, + .name = paramsQuantizeName.c_str(), + .type = QNN_TENSOR_TYPE_STATIC, + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_FLOAT_32, + .quantizeParams = {QNN_DEFINITION_UNDEFINED, + QNN_QUANTIZATION_ENCODING_UNDEFINED, + {.scaleOffsetEncoding = {.scale = 0.0000000000000000f, + .offset = 0}}}, + .rank = 1, + .dimensions = paramsQuantizeDimension, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = (uint8_t *)&quantScale, + .dataSize = sizeof(float)}}}}}; + + vector outputTensor = {{QNN_TENSOR_VERSION_1, + {.v1 = { + .id = 0, + .name = outName.c_str(), + .type = getOutputTensorType(outputs[0]), + .dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER, + .dataType = QNN_DATATYPE_SFIXED_POINT_16, + .quantizeParams = {QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + {.scaleOffsetEncoding = {.scale = quantScale, .offset = 0}}}, + .rank = 4, + .dimensions = dimensionsOutput, + .memType = QNN_TENSORMEMTYPE_RAW, + .clientBuf = {.data = nullptr, + .dataSize = 0}}}}}; + return graphAddNode(name(), "LLaMAQuantize", {inputs[0]->name()}, outputTensor, paramsQuantize, "LLaMAPackage"); +} + ErrorCode QNNQuantize::load(AbstructLoader &loader) { string scaleName = name(); diff --git a/src/backends/qnn/op/QNNQuantize.hpp b/src/backends/qnn/op/QNNQuantize.hpp index dbe15852e..d08e044d2 100644 --- a/src/backends/qnn/op/QNNQuantize.hpp +++ b/src/backends/qnn/op/QNNQuantize.hpp @@ -6,20 +6,24 @@ namespace mllm { class QNNQuantize : public QNNCommonOp { public: - QNNQuantize(Backend *bn, string opName, bool isNSHD); + QNNQuantize(Backend *bn, string opName, DataType type, bool isNSHD); virtual ~QNNQuantize() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode setUp(vector> inputs, vector> outputs) override; virtual ErrorCode load(AbstructLoader &loader) override; + private: bool isNSHD_; Tensor scale_; + + ErrorCode setUpI8(vector> &inputs, vector> &outputs); + ErrorCode setUpI16(vector> &inputs, vector> &outputs); }; class QNNQuantizeCreator : public QNNBackend::Creator { public: virtual Op *create(OpParam op_param, Backend *bn, string name) const { - return new QNNQuantize(bn, name, (bool)op_param["isNSHD"]); + return new QNNQuantize(bn, name, (DataType)op_param["dtype"], (bool)op_param["isNSHD"]); } }; diff --git a/src/backends/qnn/op/QNNRMSNorm.cpp b/src/backends/qnn/op/QNNRMSNorm.cpp index 708cc2bbe..f79297c31 100644 --- a/src/backends/qnn/op/QNNRMSNorm.cpp +++ b/src/backends/qnn/op/QNNRMSNorm.cpp @@ -20,7 +20,7 @@ ErrorCode QNNRMSNorm::reshape(vector> inputs, vector> inputs, vector> outputs) { float quantScale = 0; quantScale = scale_.hostPtr()[0] / 127.0; - quantScale = roundf(quantScale * 100000) / 100000; + // quantScale = roundf(quantScale * 100000) / 100000; uint32_t dimWeight[4] = {(uint32_t)normSize_}; qnnBackend_->modelAddTensor(weight_.name(), (Qnn_Tensor_t){ diff --git a/src/backends/qnn/op/QNNRoPE.cpp b/src/backends/qnn/op/QNNRoPE.cpp index 183633a27..4bfa028f6 100644 --- a/src/backends/qnn/op/QNNRoPE.cpp +++ b/src/backends/qnn/op/QNNRoPE.cpp @@ -124,7 +124,7 @@ ErrorCode QNNRoPE::setUp(vector> inputs, vector()[0] / 127.0; - dequantScale = roundf(dequantScale * 100000) / 100000; + // dequantScale = roundf(dequantScale * 100000) / 100000; if (name().find("q_proj") != -1) { dequantScale = dequantScale / std::sqrt(outputs[0]->dimension()); diff --git a/src/backends/qnn/op/QNNSiLU.cpp b/src/backends/qnn/op/QNNSiLU.cpp index 7af121896..7a40caef8 100644 --- a/src/backends/qnn/op/QNNSiLU.cpp +++ b/src/backends/qnn/op/QNNSiLU.cpp @@ -1,4 +1,3 @@ - #include "QNNSiLU.hpp" #include "Types.hpp" #include "QNNCommonOp.hpp" @@ -23,7 +22,6 @@ ErrorCode QNNSiLU::setUp(vector> inputs, vector(outputs[0]->head()); dimensionsOutput[3] = static_cast(outputs[0]->dimension()); - auto type = QNN_DATATYPE_FLOAT_32; outputs[0]->setDtype(MLLM_TYPE_F32); @@ -31,7 +29,6 @@ ErrorCode QNNSiLU::setUp(vector> inputs, vectorsetDtype(MLLM_TYPE_F16); } - vector outputTensor = {{QNN_TENSOR_VERSION_1, {.v1 = { @@ -51,4 +48,4 @@ ErrorCode QNNSiLU::setUp(vector> inputs, vectorname()}, outputTensor, {}, "LLaMAPackage"); } -} // namespace mllm +} // namespace mllm \ No newline at end of file diff --git a/src/backends/qnn/op/QNNSubGraphFinalize.cpp b/src/backends/qnn/op/QNNSubGraphFinalize.cpp new file mode 100644 index 000000000..55df06768 --- /dev/null +++ b/src/backends/qnn/op/QNNSubGraphFinalize.cpp @@ -0,0 +1,32 @@ + +#include "QNNSubGraphFinalize.hpp" +#include "Types.hpp" +#include "QNNCommonOp.hpp" +#include + +namespace mllm { +QNNSubGraphFinalize::QNNSubGraphFinalize(Backend *bn, string opName) : + QNNCommonOp(bn, opName) { +} + +ErrorCode QNNSubGraphFinalize::reshape(vector> inputs, vector> outputs) { + for(auto& t : inputs) { + t->setTtype(GRAPH_OUTPUT); + } + return Op::reshape(inputs, outputs); +} + +ErrorCode QNNSubGraphFinalize::setUp(vector> inputs, vector> outputs) { + for (auto input : inputs) { + input->to(MLLM_CPU); + } + + this->backend_->onSetUpEnd(inputs, outputs); + return MLLM_NO_ERROR; +} + +ErrorCode QNNSubGraphFinalize::free(vector> inputs, vector> outputs) { + return MLLM_NO_ERROR; +} + +} // namespace mllm diff --git a/src/backends/qnn/op/QNNSubGraphFinalize.hpp b/src/backends/qnn/op/QNNSubGraphFinalize.hpp new file mode 100644 index 000000000..f1cb2cddd --- /dev/null +++ b/src/backends/qnn/op/QNNSubGraphFinalize.hpp @@ -0,0 +1,25 @@ + +#ifndef MLLM_QNNSUBGRAPHFINALIZE_H +#define MLLM_QNNSUBGRAPHFINALIZE_H + +#include "QNNCommonOp.hpp" +namespace mllm { +class QNNSubGraphFinalize : public QNNCommonOp { +public: + QNNSubGraphFinalize(Backend *bn, string opName); + virtual ~QNNSubGraphFinalize() = default; + virtual ErrorCode reshape(vector> inputs, vector> outputs) override; + virtual ErrorCode setUp(vector> inputs, vector> outputs) override; + virtual ErrorCode free(vector> inputs, vector> outputs) override; +}; + +class QNNSubGraphFinalizeCreator : public QNNBackend::Creator { +public: + virtual Op *create(OpParam op_param, Backend *bn, string name) const { + return new QNNSubGraphFinalize(bn, name); + } +}; + +} // namespace mllm + +#endif diff --git a/src/backends/qnn/op/QNNSubGraphStart.cpp b/src/backends/qnn/op/QNNSubGraphStart.cpp new file mode 100644 index 000000000..1bf35c61d --- /dev/null +++ b/src/backends/qnn/op/QNNSubGraphStart.cpp @@ -0,0 +1,37 @@ + +#include "QNNSubGraphStart.hpp" +#include "Types.hpp" +#include "QNNCommonOp.hpp" +#include + +namespace mllm { +QNNSubGraphStart::QNNSubGraphStart(Backend *bn, string opName) : + QNNCommonOp(bn, opName) { +} + +ErrorCode QNNSubGraphStart::reshape(vector> inputs, vector> outputs) { + return Op::reshape(inputs, outputs); +} + +ErrorCode QNNSubGraphStart::setUp(vector> inputs, vector> outputs) { + for(auto input : inputs) { + input->to(MLLM_QNN); + input->alloc(); + } + + this->backend_->onSetUpStart(inputs, outputs, name_); + return MLLM_NO_ERROR; +} + +ErrorCode QNNSubGraphStart::free(vector> inputs, vector> outputs) { + return MLLM_NO_ERROR; +} + +ErrorCode QNNSubGraphStart::execute(vector> inputs, vector> outputs) { + this->backend_->onExecuteStart(inputs, outputs, name_); + return MLLM_NO_ERROR; +} + + + +} // namespace mllm diff --git a/src/backends/qnn/op/QNNSubGraphStart.hpp b/src/backends/qnn/op/QNNSubGraphStart.hpp new file mode 100644 index 000000000..ebb15824e --- /dev/null +++ b/src/backends/qnn/op/QNNSubGraphStart.hpp @@ -0,0 +1,26 @@ + +#ifndef MLLM_QNNSUBGRAPHSTART_H +#define MLLM_QNNSUBGRAPHSTART_H + +#include "QNNCommonOp.hpp" +namespace mllm { +class QNNSubGraphStart : public QNNCommonOp { +public: + QNNSubGraphStart(Backend *bn, string opName); + virtual ~QNNSubGraphStart() = default; + virtual ErrorCode reshape(vector> inputs, vector> outputs) override; + virtual ErrorCode setUp(vector> inputs, vector> outputs) override; + virtual ErrorCode free(vector> inputs, vector> outputs) override; + virtual ErrorCode execute(vector> inputs, vector> outputs) override; +}; + +class QNNSubGraphStartCreator : public QNNBackend::Creator { +public: + virtual Op *create(OpParam op_param, Backend *bn, string name) const { + return new QNNSubGraphStart(bn, name); + } +}; + +} // namespace mllm + +#endif diff --git a/src/backends/qnn/op/QNNSuperSiLU.cpp b/src/backends/qnn/op/QNNSuperSiLU.cpp index 2e0dcc051..11a2f33cc 100644 --- a/src/backends/qnn/op/QNNSuperSiLU.cpp +++ b/src/backends/qnn/op/QNNSuperSiLU.cpp @@ -31,15 +31,15 @@ ErrorCode QNNSuperSiLU::setUp(vector> inputs, vector()[0] / 127.0; - aScale = roundf(aScale * 100000) / 100000; + // aScale = roundf(aScale * 100000) / 100000; float bScale = 0; bScale = b_scale_.hostPtr()[0] / 127.0; - bScale = roundf(bScale * 100000) / 100000; + // bScale = roundf(bScale * 100000) / 100000; float oScale = 0; oScale = o_scale_.hostPtr()[0] / 127.0; - oScale = roundf(oScale * 100000) / 100000; + // oScale = roundf(oScale * 100000) / 100000; auto paramsSuperSiLuNameA = name() + ".supersilu_params.a_scale"; auto paramsSuperSiLuNameB = name() + ".supersilu_params.b_scale"; diff --git a/src/backends/qnn/op/QNNView.cpp b/src/backends/qnn/op/QNNView.cpp index e383ce226..0a971bcd8 100644 --- a/src/backends/qnn/op/QNNView.cpp +++ b/src/backends/qnn/op/QNNView.cpp @@ -82,7 +82,7 @@ ErrorCode QNNView::reshape(vector> inputs, vector> inputs, vector> outputs) { outputs[0]->setDtype(inputs[0]->dtype()); - if (outputs[0]->dtype() == MLLM_TYPE_I8) + if (outputs[0]->dtype() == MLLM_TYPE_I8 || outputs[0]->dtype() == MLLM_TYPE_I16) return graphAddNode(name(), "Reshape", inputs, outputs, {}, "qti.aisw", true, &scale_); else { return graphAddNode(name(), "Reshape", inputs, outputs, {}, "qti.aisw", true, nullptr); diff --git a/src/models/bert/modeling_bert.hpp b/src/models/bert/modeling_bert.hpp index adb344372..d2f9012c5 100644 --- a/src/models/bert/modeling_bert.hpp +++ b/src/models/bert/modeling_bert.hpp @@ -39,7 +39,7 @@ class BertLayer : public Module { BertLayer() = default; BertLayer(const BertConfig &config, const string &base_name) { // base_name: encoder.layer.n. - attention = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_attention_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, RoPEType::NONE, -1, -1, 0, false, true, config.names_config, base_name + config.names_config._attn_base_name); + attention = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_attention_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, RoPEType::NONE, -1, -1, 0, false, true, config.attn_implementation, config.names_config, base_name + config.names_config._attn_base_name); feed_forward = FeedForward(config.hidden_size, config.intermediate_size, config.hidden_act, true, config.names_config, base_name); diff --git a/src/models/clip/modeling_clip.hpp b/src/models/clip/modeling_clip.hpp index c2cb383b0..c47996b9e 100644 --- a/src/models/clip/modeling_clip.hpp +++ b/src/models/clip/modeling_clip.hpp @@ -23,7 +23,7 @@ class ClipVisionEmbedding final : public Module { position_ids = Parameter(1, std::ceil(img_hw / patch) * std::ceil(img_hw / patch) + 1, 1, 1, base_name + names._position_ids_name); position_embedding = Embedding(std::ceil(img_hw / patch) * std::ceil(img_hw / patch) + 1, hidden_dim, base_name + names._position_embeddings_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto embd = patch_embedding(inputs[0]); embd = embd.transpose({{SEQUENCE, DIMENSION}, {HEAD, SEQUENCE}}); // BSHD->BDHS->BDSH embd = embd.flatten(HEAD, SEQUENCE); @@ -42,13 +42,14 @@ class CLipVisionModel final : public Module { public: CLipVisionModel() = default; CLipVisionModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, int patch, int img_hw, int block_num, + string attn_implementation, const ViTNameConfig &names, const string &base_name) { embedding = ClipVisionEmbedding(hidden_dim, patch, img_hw, names, base_name + names._embd_name); pre_layrnorm = LayerNorm(hidden_dim, true, 1e-6, base_name + names._vision_pre_layrnorm_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, names, base_name + names._layer_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, attn_implementation, names, base_name + names._layer_name); norm = LayerNorm(hidden_dim, true, 1e-6, base_name + names._post_norm_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = embedding(inputs)[0]; x = pre_layrnorm(x); for (auto &block : blocks) { @@ -70,7 +71,7 @@ class ClipTextMLP final : public Module { up_proj = Linear(hidden_dim, ffn_hidden, true, base_name + names._up_proj_name); act = ACT_FN[act_fn_type](base_name + names._ffn_base_name + "act"); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = up_proj(inputs[0]); x = act(x); return {x}; @@ -86,15 +87,18 @@ class ClipTextBlock final : public Module { public: ClipTextBlock() = default; - ClipTextBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, const ClipTextNameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPEType::NONE, -1,-1, 0, true, true, names, base_name + names._attn_base_name); + ClipTextBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, + string attn_implementation, const ClipTextNameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, head_size, + hidden_dim / head_size, SPLIT_NONE, false, false, + RoPEType::NONE, -1, -1, 0, true, true, attn_implementation, + names, base_name + names._attn_base_name); mlp = ClipTextMLP(hidden_dim, ffn_hidden, act_fn_type, names, base_name + names._ffn_base_name); down_proj = Linear(ffn_hidden, hidden_dim, true, base_name + names._down_proj_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); norm2 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._ffn_norm_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = norm1(inputs[0]); x = attention({x, x, x})[0]; auto tmp = x + inputs[0]; @@ -118,7 +122,7 @@ class ClipTextEmbedding final : public Module { position_ids = Parameter(1, max_position_embeddings, 1, 1, base_name + names._position_ids_name); position_embedding = Embedding(max_position_embeddings, hidden_dim, base_name + names._position_embeddings_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto embd = token_embedding(inputs[0]); auto pos_embd = position_ids().clip({}, {}, {0, embd.sequence()}, {}); auto p_embd = position_embedding(pos_embd); @@ -134,14 +138,17 @@ class CLipTextModel final : public Module { public: CLipTextModel() = default; - CLipTextModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, int max_position_embeddings, int vocab_size, int block_num, + CLipTextModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, + int max_position_embeddings, int vocab_size, int block_num, + string attn_implementation, const ClipTextNameConfig &names, const string &base_name) { embedding = ClipTextEmbedding(vocab_size, hidden_dim, max_position_embeddings, names, base_name + names._embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, names, base_name + names._layer_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, + attn_implementation, names, base_name + names._layer_name); norm = LayerNorm(hidden_dim, true, 1e-6, base_name + names._post_norm_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = embedding(inputs)[0]; for (auto &block : blocks) { x = block({x})[0]; @@ -164,22 +171,27 @@ class CLipModel final : public Module { config.hidden_dim, config.head_size, config.ffn_hidden, config.act_fn_type, config.max_position_embeddings, config.text_vocab_size, config.text_block_num, config.patch, config.img_hw, config.block_num, + config.attn_implementation, config.text_names_config, "text_model", config.names_config, "vision_model"){}; CLipModel(int text_hidden_dim, int text_head_size, int text_ffn_hidden, int vision_hidden_dim, int vision_head_size, int vision_ffn_hidden, const string &act_fn_type, int max_position_embeddings, int vocab_size, int text_block_num, int patch, int img_hw, int vision_block_num, + string attn_implementation, const ClipTextNameConfig &text_names, const string &text_base_name, const ViTNameConfig &vit_names, const string &vision_base_name) { - text_model = CLipTextModel(text_hidden_dim, text_head_size, text_ffn_hidden, act_fn_type, max_position_embeddings, vocab_size, text_block_num, + text_model = CLipTextModel(text_hidden_dim, text_head_size, text_ffn_hidden, + act_fn_type, max_position_embeddings, + vocab_size, text_block_num, + attn_implementation, text_names, text_base_name); text_projection = Linear(text_hidden_dim, text_hidden_dim, false, "text_projection"); vision_model = CLipVisionModel(vision_hidden_dim, vision_head_size, vision_ffn_hidden, act_fn_type, patch, img_hw, vision_block_num, - vit_names, vision_base_name); + attn_implementation, vit_names, vision_base_name); visual_projection = Linear(vision_hidden_dim, text_hidden_dim, false, "visual_projection"); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto text = text_model({inputs[0]})[0]; text = text_projection(text); text = text / text.norm(2); diff --git a/src/models/dclm/modeling_dclm.hpp b/src/models/dclm/modeling_dclm.hpp index 8aa7cca4a..99e805702 100644 --- a/src/models/dclm/modeling_dclm.hpp +++ b/src/models/dclm/modeling_dclm.hpp @@ -58,6 +58,7 @@ class DCLMAttention final : public Module { int attn_hidden_dim_; int head_dim_; int n_heads_; + string attn_implementation_; public: DCLMAttention() = default; @@ -66,14 +67,15 @@ class DCLMAttention final : public Module { attn_hidden_dim_ = cfg.n_heads * head_dim; head_dim_ = head_dim; n_heads_ = cfg.n_heads; + attn_implementation_ = cfg.attn_implementation; in_proj = Linear(cfg.dim, 3 * cfg.n_heads * head_dim, false, base_name + "in_proj"); out_proj = Linear(cfg.n_heads * head_dim, cfg.dim, false, base_name + "out_proj"); q_norm = LayerNorm(cfg.n_heads * head_dim, false, cfg.norm_eps, base_name + "q_norm"); k_norm = LayerNorm(cfg.n_heads * head_dim, false, cfg.norm_eps, base_name + "k_norm"); q_rope = RoPE(cfg.RoPE_type, 10000, cfg.seq_len, base_name + "q_rope"); k_rope = RoPE(cfg.RoPE_type, 10000, cfg.seq_len, base_name + "k_rope"); - k_cache = KVCache(cfg.n_heads, head_dim, 1, cfg.cache_limit, base_name + "k_cache"); - v_cache = KVCache(cfg.n_heads, head_dim, 1, cfg.cache_limit, base_name + "v_cache"); + k_cache = KVCache(cfg.n_heads, head_dim, 1, cfg.cache_limit, (cfg.attn_implementation == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(cfg.n_heads, head_dim, 1, cfg.cache_limit, (cfg.attn_implementation == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -98,13 +100,17 @@ class DCLMAttention final : public Module { k = k_cache(k); v = v_cache(v); - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk / std::sqrt(head_dim_); + Tensor o; + if (attn_implementation_ == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, true); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(head_dim_); - qk = softmax(qk, k_cache.getCacheSeqLen()); - - auto o = Tensor::mm(qk, v); + qk = softmax(qk, k_cache.getCacheSeqLen()); + o = Tensor::mm(qk, v); + } o = o.view(-1, 1, -1, n_heads_ * head_dim_); o = out_proj(o); return {o}; diff --git a/src/models/fuyu/configuration_fuyu.hpp b/src/models/fuyu/configuration_fuyu.hpp index e240d2c4e..95be8228e 100644 --- a/src/models/fuyu/configuration_fuyu.hpp +++ b/src/models/fuyu/configuration_fuyu.hpp @@ -55,13 +55,14 @@ class FuyuConfig { block_num = 36; patch_size = 30; chl_size = 3; - max_position_embeddings= 16384; + max_position_embeddings = 16384; rope_theta = 25000; } else { throw std::runtime_error("Unsupported model size"); } cache_limit = token_limit; } + string attn_implementation = "flash_attention_2"; // Options: "flash_attention_2", "eager" }; #endif // CONFIG_FUYU_HPP diff --git a/src/models/fuyu/modeling_fuyu.hpp b/src/models/fuyu/modeling_fuyu.hpp index e459fcd95..4ae3e7f98 100644 --- a/src/models/fuyu/modeling_fuyu.hpp +++ b/src/models/fuyu/modeling_fuyu.hpp @@ -22,9 +22,14 @@ class PersimmonBlock final : public Module { public: PersimmonBlock() = default; - PersimmonBlock(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, int cache_limit, const FuyuNameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_D_HD, true, false, - PERSIMMONROPE, rope_theta, max_position_embeddings, cache_limit, true, true, names, base_name + names._attn_base_name); + PersimmonBlock(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, int cache_limit, + string attn_implementation, + const FuyuNameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, + SPLIT_D_HD, true, false, + PERSIMMONROPE, rope_theta, max_position_embeddings, cache_limit, true, true, + attn_implementation, + names, base_name + names._attn_base_name); mlp = FeedForward(hidden_dim, ffn_hidden, "ReLU2", true, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); @@ -52,8 +57,14 @@ class Persimmon final : public Module { public: Persimmon() = default; - Persimmon(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, int cache_limit, int block_num, int vocab_size, const FuyuNameConfig &names) { - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, rope_theta, max_position_embeddings, cache_limit, names, names.blk_name); + Persimmon(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, + int cache_limit, int block_num, int vocab_size, + string attn_implementation, + const FuyuNameConfig &names) { + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, + rope_theta, max_position_embeddings, cache_limit, + attn_implementation, + names, names.blk_name); norm = LayerNorm(hidden_dim, true, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } @@ -89,15 +100,17 @@ class FuyuModel final : public Module { FuyuModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.patch_size, config.chl_size, + config.attn_implementation, config.name_config) { } FuyuModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, float rope_theta, int max_position_embeddings, int cache_limit, int patch_size, int chl_size, + string attn_implementation, const FuyuNameConfig &names) { embed_tokens = Embedding(vocab_size, hidden_dim, names.token_embd_name); vision_embed_tokens = Linear(patch_size * patch_size * chl_size, hidden_dim, true, names.vision_embed_tokens_name); - persimmon = Persimmon(hidden_dim, head_size, ffn_hidden, rope_theta, max_position_embeddings, cache_limit, block_num, vocab_size, names); + persimmon = Persimmon(hidden_dim, head_size, ffn_hidden, rope_theta, max_position_embeddings, cache_limit, block_num, vocab_size, attn_implementation, names); } vector Forward(vector inputs, vector args) override { auto input_ids = embed_tokens(inputs[0]); diff --git a/src/models/gemma/modeling_gemma.hpp b/src/models/gemma/modeling_gemma.hpp index 08ed1788f..8ebd491f8 100644 --- a/src/models/gemma/modeling_gemma.hpp +++ b/src/models/gemma/modeling_gemma.hpp @@ -53,9 +53,12 @@ class GemmaDecoder final : public Module { public: GemmaDecoder() = default; GemmaDecoder(const GemmaConfig &config, const GemmaNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, true, false, names, base_name + names._attn_base_name); + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, true, false, + config.attn_implementation, + names, base_name + names._attn_base_name); mlp = GemmaMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._ffn_norm_name); diff --git a/src/models/gemma2/modeling_gemma2.hpp b/src/models/gemma2/modeling_gemma2.hpp index 5fb9984db..6f5042ab7 100644 --- a/src/models/gemma2/modeling_gemma2.hpp +++ b/src/models/gemma2/modeling_gemma2.hpp @@ -20,7 +20,7 @@ class Gemma2Attention final : public Module { head_dim = 2048 / num_heads; num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - + attn_impl = config.attn_implementation; // init layers q_proj = Linear(hidden_size, head_dim * num_heads, false, base_name + names._q_proj_name); k_proj = Linear(hidden_size, head_dim * num_key_value_heads, false, @@ -32,8 +32,8 @@ class Gemma2Attention final : public Module { base_name + "q_rope"); k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); - k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); + k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -56,15 +56,17 @@ class Gemma2Attention final : public Module { key_states = k_cache(key_states); value_states = v_cache(value_states); - // attention weight - auto atten_weight = - Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) - / std::sqrt(head_dim); - - atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); - - // attention output - auto atten_output = Tensor::mm(atten_weight, value_states); + Tensor atten_output; + if (attn_impl == "flash_attention_2") { + atten_output = Tensor::flash_attention2_forward(query_states, key_states, value_states, true); + } else { // eager implementation + // attention weight + auto atten_weight = + Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) + / std::sqrt(head_dim); + atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); + atten_output = Tensor::mm(atten_weight, value_states); + } atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); atten_output = o_proj(atten_output); return {atten_output}; @@ -93,6 +95,7 @@ class Gemma2Attention final : public Module { KVCache k_cache; KVCache v_cache; Softmax softmax; + string attn_impl; }; class Gemma2MLP final : public Module { diff --git a/src/models/imagebind/modeling_imagebind.hpp b/src/models/imagebind/modeling_imagebind.hpp index 2fe620851..be8179dc7 100644 --- a/src/models/imagebind/modeling_imagebind.hpp +++ b/src/models/imagebind/modeling_imagebind.hpp @@ -20,7 +20,9 @@ class EncoderBlock final : public Module { public: EncoderBlock() = default; - EncoderBlock(int hidden_dim, int head_size, int ffn_hidden, const string &model_type, const ImagebindNameConfig &names, const string &base_name) { + EncoderBlock(int hidden_dim, int head_size, int ffn_hidden, const string &model_type, + string attn_implementation, + const ImagebindNameConfig &names, const string &base_name) { bool do_mask = false; bool bias_kv_cat = false; if (model_type == "text") { @@ -31,6 +33,7 @@ class EncoderBlock final : public Module { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_HD, false, bias_kv_cat, RoPEType::NONE, -1, -1, 0, do_mask, true, + attn_implementation, names, base_name + names._attn_base_name); ffn = FeedForward(hidden_dim, ffn_hidden, "GELU", true, names, base_name + names._ffn_base_name); @@ -83,13 +86,16 @@ class ImagebindVisionModel final : public Module { ImagebindVisionModel(const ImagebindConfig &config) : ImagebindVisionModel(config.vision_hidden_dim, config.vision_head_size, config.vision_ffn_hidden, config.head_hidden_dim, config.patch, config.patch_time, config.img_hw, config.vision_block_num, + config.attn_implementation, config.names_config){}; ImagebindVisionModel(int hidden_dim, int head_size, int ffn_hidden, int head_hidden_dim, int patch, int patch_time, int img_hw, int block_num, + string attn_implementation, const ImagebindNameConfig &names) { embedding = ImagebindVisionEmbedding(hidden_dim, patch, patch_time, img_hw, names, names._vision_embd_name); pre_transformer_layer = LayerNorm(hidden_dim, true, 1e-6, names.vision_pre_transformer_layer_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "vision", names, names._vision_blocks_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "vision", + attn_implementation, names, names._vision_blocks_name); norm = LayerNorm(hidden_dim, true, 1e-6, names.vision_post_norm_name); head = Linear(hidden_dim, head_hidden_dim, false, names.vision_head_name); } @@ -133,14 +139,17 @@ class ImagebindTextModel final : public Module { public: ImagebindTextModel() = default; ImagebindTextModel(const ImagebindConfig &config) : - ImagebindTextModel(config.text_hidden_dim, config.text_head_size, config.text_ffn_hidden, config.head_hidden_dim, + ImagebindTextModel(config.text_hidden_dim, config.text_head_size, + config.text_ffn_hidden, config.head_hidden_dim, config.vocab_size, config.max_position_embeddings, config.text_block_num, + config.attn_implementation, config.names_config){}; ImagebindTextModel(int hidden_dim, int head_size, int ffn_hidden, int head_hidden_dim, int vocab_size, int max_position_embeddings, int block_num, + string attn_implementation, const ImagebindNameConfig &names) { embedding = ImagebindTextEmbedding(vocab_size, hidden_dim, max_position_embeddings, names, names._text_embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "text", names, names._text_blocks_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "text", attn_implementation, names, names._text_blocks_name); norm = LayerNorm(hidden_dim, true, 1e-6, names.text_post_norm_name); head = Linear(hidden_dim, head_hidden_dim, false, names.text_head_name); } @@ -197,12 +206,15 @@ class ImagebindAudioModel final : public Module { ImagebindAudioModel(config.audio_hidden_dim, config.audio_head_size, config.audio_ffn_hidden, config.head_hidden_dim, config.audio_kernal, config.audio_stride, config.audio_h, config.audio_w, config.audio_block_num, + config.attn_implementation, config.names_config){}; ImagebindAudioModel(int hidden_dim, int head_size, int ffn_hidden, int head_hidden_dim, int patch, int stride, int img_h, int img_w, int block_num, + string attn_implementation, const ImagebindNameConfig &names) { embedding = ImagebindAudioEmbedding(hidden_dim, patch, stride, img_h, img_w, names, names._audio_embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "audio", names, names._audio_blocks_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "audio", + attn_implementation, names, names._audio_blocks_name); norm = LayerNorm(hidden_dim, true, 1e-6, names.audio_post_norm_name); head = Linear(hidden_dim, head_hidden_dim, false, names.audio_head_name); } @@ -235,18 +247,24 @@ class ImagebindModel final : public Module { config.text_hidden_dim, config.text_head_size, config.text_ffn_hidden, config.vocab_size, config.max_position_embeddings, config.text_block_num, config.audio_hidden_dim, config.audio_head_size, config.audio_ffn_hidden, config.audio_kernal, config.audio_stride, config.audio_h, config.audio_w, config.audio_block_num, config.head_hidden_dim, + config.attn_implementation, config.names_config){}; - ImagebindModel(int vision_hidden_dim, int vision_head_size, int vision_ffn_hidden, int patch, int patch_time, int img_hw, int vision_block_num, + ImagebindModel(int vision_hidden_dim, int vision_head_size, int vision_ffn_hidden, int patch, int patch_time, + int img_hw, int vision_block_num, int text_hidden_dim, int text_head_size, int text_ffn_hidden, int vocab_size, int max_position_embeddings, int text_block_num, int audio_hidden_dim, int audio_head_size, int audio_ffn_hidden, int audio_kernal, int audio_stride, int audio_h, int audio_w, int audio_block_num, int head_hidden_dim, + string attn_implementation, const ImagebindNameConfig &names) { - text_model = ImagebindTextModel(text_hidden_dim, text_head_size, text_ffn_hidden, head_hidden_dim, - vocab_size, max_position_embeddings, text_block_num, names); - vision_model = ImagebindVisionModel(vision_hidden_dim, vision_head_size, vision_ffn_hidden, head_hidden_dim, - patch, patch_time, img_hw, vision_block_num, names); - audio_model = ImagebindAudioModel(audio_hidden_dim, audio_head_size, audio_ffn_hidden, head_hidden_dim, - audio_kernal, audio_stride, audio_h, audio_w, audio_block_num, names); + text_model = ImagebindTextModel(text_hidden_dim, text_head_size, + text_ffn_hidden, head_hidden_dim, + vocab_size, max_position_embeddings, text_block_num, attn_implementation, names); + vision_model = ImagebindVisionModel(vision_hidden_dim, vision_head_size, + vision_ffn_hidden, head_hidden_dim, + patch, patch_time, img_hw, vision_block_num, attn_implementation, names); + audio_model = ImagebindAudioModel(audio_hidden_dim, audio_head_size, + audio_ffn_hidden, head_hidden_dim, + audio_kernal, audio_stride, audio_h, audio_w, audio_block_num, attn_implementation, names); softmax = Softmax(DIMENSION, "final.softmax1"); softmax2 = Softmax(DIMENSION, "final.softmax2"); } diff --git a/src/models/llama/configuration_llama.hpp b/src/models/llama/configuration_llama.hpp index f97ede2be..46256e6a6 100644 --- a/src/models/llama/configuration_llama.hpp +++ b/src/models/llama/configuration_llama.hpp @@ -75,9 +75,11 @@ class LLaMAConfig : public TransformerConfig { float rope_theta; int max_position_embeddings; - explicit LLaMAConfig(int token_limit, string billions = "7B", RoPEType type = LLAMAROPE, int vocab = 32000) { + explicit LLaMAConfig(int token_limit, string billions = "7B", RoPEType type = LLAMAROPE, int vocab = 32000, + string attn_implementation_ = "flash_attention_2") { names_config.init(type); vocab_size = vocab; + attn_implementation = attn_implementation_; if (billions == "7B" || billions == "7b") { hidden_dim = 4096; head_size = 32; diff --git a/src/models/llama/modeling_llama.hpp b/src/models/llama/modeling_llama.hpp index 5e559ca59..7d8eabe26 100644 --- a/src/models/llama/modeling_llama.hpp +++ b/src/models/llama/modeling_llama.hpp @@ -44,9 +44,15 @@ class LLaMABlock final : public Module { public: LLaMABlock() = default; - LLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); + LLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + string attn_implementation, + const LLaMANameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, + SPLIT_NONE, false, false, + RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, + attn_implementation, + names, base_name + names._attn_base_name); mlp = LLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -74,14 +80,24 @@ class LLaMAModel final : public Module { public: explicit LLaMAModel(const LLaMAConfig &config) : - LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, + LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, + config.num_key_value_heads, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, + config.max_position_embeddings, config.cache_limit, + config.attn_implementation, config.names_config, config.names_config.blk_name) { } - LLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + LLaMAModel(int vocab_size, int hidden_dim, int head_size, + int kv_head_size, int ffn_hidden, int block_num, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + string attn_implementation, const LLaMANameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, + kv_head_size, ffn_hidden, RoPE_type, + rope_theta, max_position_embeddings, cache_limit, + attn_implementation, + names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/llama/modeling_sparse_llama.hpp b/src/models/llama/modeling_sparse_llama.hpp index c496eb471..d59363239 100644 --- a/src/models/llama/modeling_sparse_llama.hpp +++ b/src/models/llama/modeling_sparse_llama.hpp @@ -24,13 +24,13 @@ class SparseLLaMAMLP final : public Module { gate_proj = Linear(hidden_dim, ffn_hidden, false, base_name + names._gate_proj_name); relu = ReLU(base_name + "act"); up_proj = SparseIdLinear(hidden_dim, ffn_hidden, base_name + names._up_proj_name); - if(is_down_sparse) { + if (is_down_sparse) { down_proj = SparseLinear(ffn_hidden, hidden_dim, base_name + names._down_proj_name); - }else{ + } else { down_proj = Linear(ffn_hidden, hidden_dim, false, base_name + names._down_proj_name); } } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = inputs[0]; auto id = gate_proj(inputs[0]); auto gate = relu(id); @@ -49,14 +49,17 @@ class SparseLLaMABlock final : public Module { public: SparseLLaMABlock() = default; - SparseLLaMABlock(bool is_down_sparse, int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); + SparseLLaMABlock(bool is_down_sparse, int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const LLaMANameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, + SPLIT_NONE, false, false, + RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, + attn_implementation, + names, base_name + names._attn_base_name); mlp = SparseLLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name, is_down_sparse); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = norm1(inputs[0]); x = attention({x, x, x})[0]; auto tmp = x + inputs[0]; @@ -75,19 +78,22 @@ class SparseLLaMAModel final : public Module { public: explicit SparseLLaMAModel(const LLaMAConfig &config, bool is_down_sparse = false) : - SparseLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, - config.rope_theta, config.max_position_embeddings, config.cache_limit, - config.names_config, config.names_config.blk_name, is_down_sparse) { + SparseLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, + config.ffn_hidden, config.block_num, config.RoPE_type, + config.rope_theta, config.max_position_embeddings, config.cache_limit, + config.attn_implementation, + config.names_config, config.names_config.blk_name, is_down_sparse) { } - SparseLLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, - float rope_theta, int max_position_embeddings, int cache_limit, - const LLaMANameConfig &names, const string &base_name, bool is_down_sparse) { + SparseLLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, + float rope_theta, int max_position_embeddings, int cache_limit, + string attn_implementation, + const LLaMANameConfig &names, const string &base_name, bool is_down_sparse) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, is_down_sparse, hidden_dim, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + blocks = List(block_num, is_down_sparse, hidden_dim, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } - vector Forward(vector inputs, vector args) override { + vector Forward(vector inputs, vector args) override { auto x = embedding(inputs[0]); for (auto &block : blocks) { x = block({x})[0]; diff --git a/src/models/llama3/modeling_llama3.hpp b/src/models/llama3/modeling_llama3.hpp index 66020eb7f..63eeb10b4 100644 --- a/src/models/llama3/modeling_llama3.hpp +++ b/src/models/llama3/modeling_llama3.hpp @@ -45,16 +45,18 @@ class Llama3Attention final : public Module { int head_size_; // Size of each attention head int kv_head_size_; // Size of each key/value head int hidden_dim_; // Hidden dimension size + string attn_impl; public: Llama3Attention() = default; Llama3Attention(int hidden_dim, int head_size, int kv_head_size, RoPEType RoPE_type, float rope_theta, - int max_position_embeddings, int cache_limit, const TransformerNameConfig &names, - const string &base_name, const RoPEConfig &rope_config = {}) { + int max_position_embeddings, int cache_limit, string attn_implementation, + const TransformerNameConfig &names, const string &base_name, const RoPEConfig &rope_config = {}) { hidden_dim_ = hidden_dim; head_size_ = head_size; kv_head_size_ = kv_head_size; + attn_impl = attn_implementation; // Initialize projections q_proj = Linear(hidden_dim, head_size * (hidden_dim / head_size), false, base_name + names._q_proj_name); @@ -73,8 +75,8 @@ class Llama3Attention final : public Module { // Initialize KV cache if (cache_limit > 0) { - k_cache = KVCache(kv_head_size, hidden_dim / head_size, head_size / kv_head_size, cache_limit, base_name + "k_cache"); - v_cache = KVCache(kv_head_size, hidden_dim / head_size, head_size / kv_head_size, cache_limit, base_name + "v_cache"); + k_cache = KVCache(kv_head_size, hidden_dim / head_size, head_size / kv_head_size, cache_limit, (attn_impl == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(kv_head_size, hidden_dim / head_size, head_size / kv_head_size, cache_limit, (attn_impl == "flash_attention_2"), base_name + "v_cache"); } // Initialize softmax @@ -102,23 +104,27 @@ class Llama3Attention final : public Module { k = k_cache(k); v = v_cache(v); } - - // Transpose keys for dot product - k = k.transpose(SEQUENCE, DIMENSION); - - // Compute attention scores - Tensor qk = Tensor::mm(q, k); // Dot product of queries and keys - qk = qk / std::sqrt(hidden_dim_ / head_size_); // Scale by sqrt(d_k) - - // Apply softmax - if (k_cache.ready() && v_cache.ready()) { - qk = softmax(qk, k_cache.getCacheSeqLen()); // Masked softmax if cache is used - } else { - qk = softmax(qk); // Regular softmax + Tensor o; + if (attn_impl == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, true); + } else { // eager implementation + // Transpose keys for dot product + k = k.transpose(SEQUENCE, DIMENSION); + + // Compute attention scores + Tensor qk = Tensor::mm(q, k); // Dot product of queries and keys + qk = qk / std::sqrt(hidden_dim_ / head_size_); // Scale by sqrt(d_k) + + // Apply softmax + if (k_cache.ready() && v_cache.ready()) { + qk = softmax(qk, k_cache.getCacheSeqLen()); // Masked softmax if cache is used + } else { + qk = softmax(qk); // Regular softmax + } + + // Compute attention output + o = Tensor::mm(qk, v); // Weighted sum of values } - - // Compute attention output - Tensor o = Tensor::mm(qk, v); // Weighted sum of values o = o.view(-1, 1, -1, hidden_dim_); // Reshape to original dimensions o = o_proj(o); // Output projection @@ -154,7 +160,8 @@ class Llama3Block final : public Module { } attention = Llama3Attention(hidden_dim, head_size, kv_head_size, RoPE_type, rope_theta, - max_position_embeddings, cache_limit, names, base_name + names._attn_base_name, rope_config); + max_position_embeddings, cache_limit, config.attn_implementation, + names, base_name + names._attn_base_name, rope_config); mlp = Llama3MLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -198,8 +205,8 @@ class Llama3Model final : public Module { // we just simply use the token embedding as the lm_head // but now we are not really tying the word embeddings auto lm_head_name = names.lm_head_name; - if (config.tie_word_embeddings) - lm_head_name = names.token_embd_name; + assert(config.tie_word_embeddings); + lm_head_name = names.token_embd_name; lm_head = Parameter(1, vocab_size, 1, hidden_dim, lm_head_name + ".weight"); } vector Forward(vector inputs, vector args) override { diff --git a/src/models/llava/modeling_llava.hpp b/src/models/llava/modeling_llava.hpp index eafa26b92..8aca0088f 100644 --- a/src/models/llava/modeling_llava.hpp +++ b/src/models/llava/modeling_llava.hpp @@ -20,9 +20,8 @@ class LLaMABodyModel final : public Module { public: LLaMABodyModel() = default; - LLaMABodyModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, - const LLaMANameConfig &names, const string &base_name) { - blocks = List(block_num, hidden_dim, head_size, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + LLaMABodyModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const LLaMANameConfig &names, const string &base_name) { + blocks = List(block_num, hidden_dim, head_size, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } @@ -74,10 +73,11 @@ class LLaVAVisionModel final : public Module { public: LLaVAVisionModel() = default; LLaVAVisionModel(int hidden_dim, int head_size, int ffn_hidden, int patch, int img_hw, int block_num, + string attn_implementation, const ViTNameConfig &names, const string &base_name) { embedding = LLaVAVisionEmbedding(hidden_dim, patch, img_hw, names, base_name + names._embd_name); pre_layrnorm = LayerNorm(hidden_dim, true, 1e-6, base_name + names._vision_pre_layrnorm_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "QuickGELU", names, base_name + names._layer_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, "QuickGELU", attn_implementation, names, base_name + names._layer_name); clip_len_ = std::ceil(img_hw / patch) * std::ceil(img_hw / patch) + 1; linear_1 = Linear(hidden_dim, ffn_hidden, true, "multi_modal_projector.linear_1"); gelu = GELU("multi_modal_projector.act"); @@ -106,19 +106,21 @@ class LLaVAModel final : public Module { LLaVAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.names_config, - config.vision_hidden_dim, config.vision_head_size, config.vision_ffn_hidden, config.patch, config.img_hw, config.vision_block_num, + config.vision_hidden_dim, config.vision_head_size, config.vision_ffn_hidden, config.patch, config.img_hw, config.vision_block_num, config.attn_implementation, config.vit_names_config) { } LLaVAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names_config, int vision_hidden_dim, int vision_head_size, int vision_ffn_hidden, int patch, int img_hw, int vision_block_num, + string attn_implementation, const ViTNameConfig &vit_names_config) { text_embedding = Embedding(vocab_size, hidden_dim, names_config.token_embd_name); llama_body = LLaMABodyModel(vocab_size, hidden_dim, head_size, ffn_hidden, block_num, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, + RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names_config, names_config.blk_name); vision_tower = LLaVAVisionModel(vision_hidden_dim, vision_head_size, vision_ffn_hidden, patch, img_hw, vision_block_num, + attn_implementation, vit_names_config, vit_names_config.vison_model_name); } vector Forward(vector inputs, vector args) override { diff --git a/src/models/minicpm/modeling_minicpm.hpp b/src/models/minicpm/modeling_minicpm.hpp index 163bc0b47..502357032 100644 --- a/src/models/minicpm/modeling_minicpm.hpp +++ b/src/models/minicpm/modeling_minicpm.hpp @@ -41,10 +41,13 @@ class MiniCPMDecoder final : public Module { public: MiniCPMDecoder() = default; MiniCPMDecoder(const MiniCPMConfig &config, const MiniCPMNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, - true, false, names, base_name + names._attn_base_name); + true, false, + config.attn_implementation, + names, base_name + names._attn_base_name); mlp = MiniCPMMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); diff --git a/src/models/minicpm_moe/mbm/modeling_minicpm_moe_mbm.hpp b/src/models/minicpm_moe/mbm/modeling_minicpm_moe_mbm.hpp index bdb8f7865..c2ebd5567 100644 --- a/src/models/minicpm_moe/mbm/modeling_minicpm_moe_mbm.hpp +++ b/src/models/minicpm_moe/mbm/modeling_minicpm_moe_mbm.hpp @@ -98,116 +98,99 @@ class MiniCPMMoE final : public Module { expert_weights = expert_weights.view(-1, -1, 1, 1); // 1, k* batch*seq, 1, 1 auto idxs = expert_indices.argsort(); // 1, 1, 1, k* batch*seq auto tokens_per_expert = expert_indices.bincount(); // (1, 1, 1, 0) 1, 1, 1, k - /* - load_experts_1th(tokens_per_expert); - auto expert_cache = moe_infer(hidden_states, tokens_per_expert, expert_weights, idxs); - */ + Tensor expert_cache; #ifdef MTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY && hidden_states.sequence() == 1) { - std::cout << "attn || exe time: " << (mllm_time_us() - end_infer_last) / 1000.0F << "ms" << std::endl; - } + std::cout << "attn || exe time: " << (mllm_time_us() - end_infer_last) / 1000.0F << "ms" << std::endl; #endif - if (Tensor::tensor_status == TENSOR_STATIC_READY) { - vector tokens_per_expert_vector; - for (int i = 0; i < tokens_per_expert.dimension(); ++i) { - if (tokens_per_expert.d(0, 0, 0, i)) { - tokens_per_expert_vector.push_back(i); - } - } - // - if (layer_idx < 39 && tokens_per_expert_vector.size() == 2) { - if (mbm_maps[layer_idx].find(tokens_per_expert_vector) != mbm_maps[layer_idx].end()) { - mbm_load_expert_idxs.clear(); - auto c = mbm_maps[layer_idx][tokens_per_expert_vector]; - mbm_load_expert_idxs = c[0]; - mbm_load_layer_idx = layer_idx + 1; - do_mbm_load = true; - } - } else if (layer_idx == 39 && tokens_per_expert_vector.size() == 2) { - if (mbm_maps[layer_idx].find(tokens_per_expert_vector) != mbm_maps[layer_idx].end()) { - mbm_load_expert_idxs.clear(); - auto c = mbm_maps[layer_idx][tokens_per_expert_vector]; - mbm_load_expert_idxs = c[0]; - mbm_load_layer_idx = 0; - do_mbm_load = true; - } - } - /* - mbm_load_expert_idxs = mbm_idxs; - mbm_load_layer_idx = layer_idx; - do_mbm_load = true; - */ - if (mbm_idxs_size == 2 && tokens_per_expert_vector.size() == 2) { // layer_idx > 0 && && layer_idx < 39 - int &done = dones[layer_idx]; // 标志变量,用于表示数据是否已被修改 - cvs[layer_idx]->wait(locks[layer_idx], [&done] { return done; }); // 等待条件满足 - assert(dones[layer_idx]); + vector tokens_per_expert_vector; + for (int i = 0; i < tokens_per_expert.dimension(); ++i) { + if (tokens_per_expert.d(0, 0, 0, i)) { + tokens_per_expert_vector.push_back(i); } - if (!experts_loaded(tokens_per_expert_vector)) { - load_experts(tokens_per_expert_vector); + } + // + if (layer_idx < 39 && tokens_per_expert_vector.size() == 2) { + if (mbm_maps[layer_idx].find(tokens_per_expert_vector) != mbm_maps[layer_idx].end()) { + mbm_load_expert_idxs.clear(); + auto c = mbm_maps[layer_idx][tokens_per_expert_vector]; + mbm_load_expert_idxs = c[0]; + mbm_load_layer_idx = layer_idx + 1; + do_mbm_load = true; } - assert(experts_loaded(tokens_per_expert_vector)); - expert_cache = moe_infer(hidden_states, tokens_per_expert, expert_weights, idxs); - if (mbm_idxs_size == 2 && tokens_per_expert_vector.size() == 2) { // layer_idx > 0 && && layer_idx < 39 - reset_syntax_mbm(layer_idx); + } else if (layer_idx == 39 && tokens_per_expert_vector.size() == 2) { + if (mbm_maps[layer_idx].find(tokens_per_expert_vector) != mbm_maps[layer_idx].end()) { + mbm_load_expert_idxs.clear(); + auto c = mbm_maps[layer_idx][tokens_per_expert_vector]; + mbm_load_expert_idxs = c[0]; + mbm_load_layer_idx = 0; + do_mbm_load = true; } - if (layer_idx == 0) - mbm_idxs_size = tokens_per_expert_vector.size(); - } else { - expert_cache = moe_infer(hidden_states, tokens_per_expert, expert_weights, idxs); } -#ifdef MTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY && hidden_states.sequence() == 1) { - end_infer_last = mllm_time_us(); + if (mbm_idxs_size == 2 && tokens_per_expert_vector.size() == 2) { // layer_idx > 0 && && layer_idx < 39 + int &done = dones[layer_idx]; // 标志变量,用于表示数据是否已被修改 + cvs[layer_idx]->wait(locks[layer_idx], [&done] { return done; }); // 等待条件满足 + assert(dones[layer_idx]); + } + if (!experts_loaded(tokens_per_expert_vector)) { + load_experts(tokens_per_expert_vector); + } + assert(experts_loaded(tokens_per_expert_vector)); + expert_cache = moe_infer(hidden_states, tokens_per_expert, expert_weights, idxs); + if (mbm_idxs_size == 2 && tokens_per_expert_vector.size() == 2) { // layer_idx > 0 && && layer_idx < 39 + reset_syntax_mbm(layer_idx); } + if (layer_idx == 0) + mbm_idxs_size = tokens_per_expert_vector.size(); +#ifdef MTIME + end_infer_last = mllm_time_us(); #endif return {expert_cache}; } void load_experts(vector expert_idxs) { - if (Tensor::tensor_status == TENSOR_STATIC_READY) { #ifdef MTIME - auto start_infer = mllm_time_us(); + auto start_infer = mllm_time_us(); #endif - int result; - // #pragma omp parallel for num_threads(CPUBackend::cpu_threads) - for (int i = 0; i < expert_idxs.size(); ++i) { - if (expert_idxs.size() == 2) { - if (std::find(mbm_v[layer_idx].begin(), mbm_v[layer_idx].end(), expert_idxs[i]) != mbm_v[layer_idx].end()) { - // 在 mbm_v[layer_idx] 中找到了 expert_idxs[i] - if (experts[expert_idxs[i]].loaded()) { - continue; - } else { - std::cout << "[ERROR] experts load." << std::endl; - experts[expert_idxs[i]].load(); - continue; - } - } - if (mbm_v[layer_idx].size() >= mbm_num_max_experts) { - result = mbm_queue_remove(mbm_v[layer_idx], expert_idxs); - if (result != -1) { // mbm_v[layer_idx]不全是expert_idxs - experts[result].free(); - mbm_v[layer_idx].push_back(expert_idxs[i]); - // if (mbm_load_layer_idx != layer_idx) - // std::cout << layer_idx << " " << mbm_load_layer_idx << " : " << expert_idxs[i] << std::endl; - experts[expert_idxs[i]].load(); - } + int result; + // #pragma omp parallel for num_threads(CPUBackend::cpu_threads) + for (int i = 0; i < expert_idxs.size(); ++i) { + if (expert_idxs.size() == 2) { + if (std::find(mbm_v[layer_idx].begin(), mbm_v[layer_idx].end(), expert_idxs[i]) != mbm_v[layer_idx].end()) { + // 在 mbm_v[layer_idx] 中找到了 expert_idxs[i] + if (experts[expert_idxs[i]].loaded()) { + continue; } else { + std::cout << "[ERROR] experts load." << std::endl; + experts[expert_idxs[i]].load(); + continue; + } + } + if (mbm_v[layer_idx].size() >= mbm_num_max_experts) { + result = mbm_queue_remove(mbm_v[layer_idx], expert_idxs); + if (result != -1) { // mbm_v[layer_idx]不全是expert_idxs + experts[result].free(); mbm_v[layer_idx].push_back(expert_idxs[i]); + // if (mbm_load_layer_idx != layer_idx) + // std::cout << layer_idx << " " << mbm_load_layer_idx << " : " << expert_idxs[i] << std::endl; experts[expert_idxs[i]].load(); } - assert(experts[expert_idxs[i]].loaded()); } else { + mbm_v[layer_idx].push_back(expert_idxs[i]); experts[expert_idxs[i]].load(); } + assert(experts[expert_idxs[i]].loaded()); + } else { + experts[expert_idxs[i]].load(); } + } #ifdef MTIME - if (expert_idxs.size() == 2) { - auto end_infer = mllm_time_us(); - std::cout << "expert|| load time: " << (end_infer - start_infer) / 1000.0F << "ms" << std::endl; - } -#endif + if (expert_idxs.size() == 2) { + auto end_infer = mllm_time_us(); + std::cout << "expert|| load time: " << (end_infer - start_infer) / 1000.0F << "ms" << std::endl; } +#endif + // } } private: @@ -245,10 +228,8 @@ class MiniCPMMoE final : public Module { } } void free_experts(vector expert_idxs) { - if (Tensor::tensor_status == TENSOR_STATIC_READY) { - for (int i = 0; i < expert_idxs.size(); ++i) { - experts[expert_idxs[i]].free(); - } + for (int i = 0; i < expert_idxs.size(); ++i) { + experts[expert_idxs[i]].free(); } } Tensor moe_infer(Tensor &hidden_states, Tensor &tokens_per_expert, Tensor &expert_weights, Tensor &idxs) { @@ -287,10 +268,8 @@ class MiniCPMMoE final : public Module { // expert_cache.view(ANYDIM, seq, -1, -1); } #ifdef MTIME - if (Tensor::tensor_status == TENSOR_STATIC_READY && hidden_states.sequence() == 1) { - auto end_infer = mllm_time_us(); - std::cout << "expert|| exe time: " << (end_infer - start_infer) / 1000.0F << "ms" << std::endl; - } + auto end_infer = mllm_time_us(); + std::cout << "expert|| exe time: " << (end_infer - start_infer) / 1000.0F << "ms" << std::endl; #endif return expert_cache; } @@ -306,10 +285,12 @@ class MiniCPMDecoder final : public Module { public: MiniCPMDecoder() = default; MiniCPMDecoder(const MiniCPMConfig &config, const MiniCPMNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, - true, false, names, base_name + names._attn_base_name); + true, false, + config.attn_implementation, names, base_name + names._attn_base_name); moe = MiniCPMMoE(config, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); @@ -389,7 +370,7 @@ class MiniCPMForCausalLM final : public Module { } std::vector Forward(std::vector inputs, std::vector args) override { std::vector outputs; - if (Tensor::tensor_status == TENSOR_STATIC_READY && inputs[0].sequence() == 1) { + if (inputs[0].dimension() == 1) { omp_set_max_active_levels(2); // Enable OpenMP nesting #pragma omp parallel num_threads(2) if (omp_get_thread_num() == 0) { // 根据线程ID决定执行哪个函数 diff --git a/src/models/minicpm_moe/mbp/modeling_minicpm_moe_mbp.hpp b/src/models/minicpm_moe/mbp/modeling_minicpm_moe_mbp.hpp index 7739090b5..f2254cdf6 100644 --- a/src/models/minicpm_moe/mbp/modeling_minicpm_moe_mbp.hpp +++ b/src/models/minicpm_moe/mbp/modeling_minicpm_moe_mbp.hpp @@ -292,10 +292,12 @@ class MiniCPMDecoder final : public Module { public: MiniCPMDecoder() = default; MiniCPMDecoder(const MiniCPMConfig &config, const MiniCPMNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, - true, false, names, base_name + names._attn_base_name); + true, false, + config.attn_implementation, names, base_name + names._attn_base_name); moe = MiniCPMMoE(config, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); diff --git a/src/models/minicpm_moe/modeling_minicpm_moe.hpp b/src/models/minicpm_moe/modeling_minicpm_moe.hpp index 6ba25aa79..96b10f25a 100644 --- a/src/models/minicpm_moe/modeling_minicpm_moe.hpp +++ b/src/models/minicpm_moe/modeling_minicpm_moe.hpp @@ -106,10 +106,12 @@ class MiniCPMDecoder final : public Module { public: MiniCPMDecoder() = default; MiniCPMDecoder(const MiniCPMConfig &config, const MiniCPMNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, - true, false, names, base_name + names._attn_base_name); + true, false, + config.attn_implementation, names, base_name + names._attn_base_name); moe = MiniCPMMoE(config, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); diff --git a/src/models/mistral/modeling_mistral.hpp b/src/models/mistral/modeling_mistral.hpp index 37de46474..a04ae91ff 100644 --- a/src/models/mistral/modeling_mistral.hpp +++ b/src/models/mistral/modeling_mistral.hpp @@ -51,10 +51,12 @@ class MistralDecoder final : public Module { public: MistralDecoder() = default; MistralDecoder(const MistralConfig &config, const MistralNameConfig &names, const string &base_name) { - self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, + config.num_key_value_heads, config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, - true, false, names, base_name + names._attn_base_name); + true, false, + config.attn_implementation, names, base_name + names._attn_base_name); mlp = MistralMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); diff --git a/src/models/openelm/modeling_openelm.hpp b/src/models/openelm/modeling_openelm.hpp index 8700729d1..14904a7dd 100644 --- a/src/models/openelm/modeling_openelm.hpp +++ b/src/models/openelm/modeling_openelm.hpp @@ -51,6 +51,7 @@ class OpenELMMultiHeadCausalAttention final : public Module { Softmax softmax; int iter = 0; + string attn_impl; public: OpenELMMultiHeadCausalAttention() = default; @@ -60,6 +61,7 @@ class OpenELMMultiHeadCausalAttention final : public Module { q_heads_ = cfg.num_query_heads[layer_idx]; k_heads_ = cfg.num_kv_heads[layer_idx]; v_heads_ = cfg.num_kv_heads[layer_idx]; + attn_impl = cfg.attn_implementation; qkv_proj = Linear(cfg.model_dim, (q_heads_ + k_heads_ + v_heads_) * head_dim_, false, base_name + "qkv_proj"); q_rope = RoPE(cfg.RoPE_type, cfg.rope_freq_constant, cfg.rope_max_length, base_name + "q_rope"); @@ -70,8 +72,8 @@ class OpenELMMultiHeadCausalAttention final : public Module { out_proj = Linear(q_heads_ * head_dim_, cfg.model_dim, false, base_name + "out_proj"); - k_cache = KVCache(k_heads_, head_dim_, q_heads_ / k_heads_, cfg.cache_limit, base_name + "k_cache"); - v_cache = KVCache(v_heads_, head_dim_, q_heads_ / v_heads_, cfg.cache_limit, base_name + "v_cache"); + k_cache = KVCache(k_heads_, head_dim_, q_heads_ / k_heads_, cfg.cache_limit, (attn_impl == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(v_heads_, head_dim_, q_heads_ / v_heads_, cfg.cache_limit, (attn_impl == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -98,13 +100,16 @@ class OpenELMMultiHeadCausalAttention final : public Module { k = k_cache(k); v = v_cache(v); - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - - qk = qk / std::sqrt(head_dim_); - - qk = softmax(qk, k_cache.getCacheSeqLen()); - auto o = Tensor::mm(qk, v); + Tensor o; + if (attn_impl == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, true); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(head_dim_); + qk = softmax(qk, k_cache.getCacheSeqLen()); + o = Tensor::mm(qk, v); + } o = o.view(-1, 1, -1, q_heads_ * head_dim_); o = out_proj(o); diff --git a/src/models/opt/modeling_opt.hpp b/src/models/opt/modeling_opt.hpp index b40270924..ca1b5bbca 100644 --- a/src/models/opt/modeling_opt.hpp +++ b/src/models/opt/modeling_opt.hpp @@ -16,9 +16,13 @@ class OPTBlock final : public Module { public: OPTBlock() = default; - OPTBlock(int hidden_dim, int head_size, int ffn_hidden, int cache_limit, const optNameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - NONE, -1, -1, cache_limit, true, true, names, base_name + names._attn_base_name); + OPTBlock(int hidden_dim, int head_size, int ffn_hidden, int cache_limit, + string attn_implementation, const optNameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, head_size, + hidden_dim / head_size, SPLIT_NONE, false, false, + NONE, -1, -1, cache_limit, true, true, + attn_implementation, + names, base_name + names._attn_base_name); mlp = FeedForward(hidden_dim, ffn_hidden, "ReLU", true, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-05, base_name + names._attn_norm_name); @@ -47,13 +51,18 @@ class OPTModel final : public Module { public: explicit OPTModel(const OPTConfig &config) : - OPTModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.cache_limit, config.names_config, config.names_config.blk_name) { + OPTModel(config.vocab_size, config.hidden_dim, + config.head_size, config.ffn_hidden, config.block_num, config.cache_limit, + config.attn_implementation, + config.names_config, config.names_config.blk_name) { } - OPTModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, int cache_limit, const optNameConfig &names, const string &base_name) { + OPTModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, int cache_limit, + string attn_implementation, + const optNameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); pos_embedding = Embedding(2050, hidden_dim, names.pos_name); pos = Position("pos"); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, cache_limit, attn_implementation, names, base_name); norm = LayerNorm(hidden_dim, true, 1e-05, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/phi3/modeling_phi3.hpp b/src/models/phi3/modeling_phi3.hpp index b88aaca46..c5a2719cc 100644 --- a/src/models/phi3/modeling_phi3.hpp +++ b/src/models/phi3/modeling_phi3.hpp @@ -45,9 +45,8 @@ class Phi3Block final : public Module { public: Phi3Block() = default; - Phi3Block(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const Phi3NameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_HD, false, false, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); + Phi3Block(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const Phi3NameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_HD, false, false, RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, attn_implementation, names, base_name + names._attn_base_name); mlp = Phi3MLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -76,13 +75,13 @@ class Phi3Model final : public Module { public: explicit Phi3Model(const Phi3Config &config) : Phi3Model(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.attn_implementation, config.names_config, config.names_config.blk_name) { } - Phi3Model(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + Phi3Model(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const Phi3NameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/phi3v/modeling_phi3v.hpp b/src/models/phi3v/modeling_phi3v.hpp index 9efb8c22e..7502a9a9b 100644 --- a/src/models/phi3v/modeling_phi3v.hpp +++ b/src/models/phi3v/modeling_phi3v.hpp @@ -50,11 +50,12 @@ class Phi3VisionModel final : public Module { public: Phi3VisionModel() = default; Phi3VisionModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, int patch, int img_hw, int block_num, + string attn_implementation, const Phi3VNameConfig &names, const string &base_name) { embedding = Phi3VisionEmbedding(hidden_dim, patch, img_hw, names, base_name + names._embd_name); pre_layrnorm = LayerNorm(hidden_dim, true, 1e-5, base_name + names._vision_pre_layrnorm_name); clip_len_ = std::ceil(img_hw / patch) * std::ceil(img_hw / patch) + 1; - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, names, base_name + names._layer_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, attn_implementation, names, base_name + names._layer_name); } vector Forward(vector inputs, vector args) override { auto x = embedding(inputs)[0]; @@ -81,9 +82,9 @@ class Phi3Embedding final : public Module { public: Phi3Embedding() = default; - explicit Phi3Embedding(int vocab_size, int hidden_dim, int head_size, int ffn, int vision_hidden_dim, string &projection_cls, const Phi3VNameConfig &nameconfig, const string &base_name, const string &embd_name) { + explicit Phi3Embedding(int vocab_size, int hidden_dim, int head_size, int ffn, int vision_hidden_dim, string &projection_cls, const Phi3VNameConfig &nameconfig, string attn_implementation, const string &base_name, const string &embd_name) { embed_tokens = Embedding(vocab_size, hidden_dim, embd_name); - img_processor = Phi3VisionModel(vision_hidden_dim, 16, vision_hidden_dim * 4, "QuickGELU", 14, 336, 23, nameconfig, nameconfig.vison_model_name); + img_processor = Phi3VisionModel(vision_hidden_dim, 16, vision_hidden_dim * 4, "QuickGELU", 14, 336, 23, attn_implementation, nameconfig, nameconfig.vison_model_name); glb_GN = Parameter(1, 1, 1, vision_hidden_dim * 4, nameconfig._vision_model_prefix + nameconfig._glb_GN); sub_GN = Parameter(1, 1, 1, vision_hidden_dim * 4, nameconfig._vision_model_prefix + nameconfig._sub_GN); project_cls = projection_cls; @@ -150,15 +151,15 @@ class Phi3VModel final : public Module { public: explicit Phi3VModel(const Phi3VConfig &config) : - Phi3VModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.vision_hidden_dim, config.projection_cls, config.name_config, + Phi3VModel(config.vocab_size, config.hidden_dim, config.head_size, + config.num_key_value_heads, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.vision_hidden_dim, config.projection_cls, config.attn_implementation, config.name_config, config.names_config, config.names_config.blk_name) { } - Phi3VModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, int vision_hidden_dim, string projection_cls, const Phi3VNameConfig &visionconfig, - const Phi3NameConfig &names, const string &base_name) { + Phi3VModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, int vision_hidden_dim, string projection_cls, string attn_implementation, const Phi3VNameConfig &visionconfig, const Phi3NameConfig &names, const string &base_name) { norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); - vision_embed_tokens = Phi3Embedding(vocab_size, hidden_dim, head_size, ffn_hidden, vision_hidden_dim, projection_cls, visionconfig, base_name, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + vision_embed_tokens = Phi3Embedding(vocab_size, hidden_dim, head_size, ffn_hidden, vision_hidden_dim, projection_cls, visionconfig, attn_implementation, base_name, names.token_embd_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/phonelm/modeling_phonelm.hpp b/src/models/phonelm/modeling_phonelm.hpp index 31a8dfefa..ccb719a73 100644 --- a/src/models/phonelm/modeling_phonelm.hpp +++ b/src/models/phonelm/modeling_phonelm.hpp @@ -50,6 +50,7 @@ class PhoneLMAttention final : public Module { head_dim = config.hidden_size / num_heads; num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; + attn_impl = config.attn_implementation; // init layers q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); @@ -62,8 +63,8 @@ class PhoneLMAttention final : public Module { base_name + "q_rope"); k_rope = IRoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); - k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); + k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (attn_impl == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (attn_impl == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -84,15 +85,20 @@ class PhoneLMAttention final : public Module { k = k_cache(k); v = v_cache(v); } - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk / std::sqrt(head_dim); - if (k_cache.ready() && v_cache.ready()) { - qk = softmax(qk, k_cache.getCacheSeqLen()); - } else { - qk = softmax(qk); + Tensor o; + if (attn_impl == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, true); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(head_dim); + if (k_cache.ready() && v_cache.ready()) { + qk = softmax(qk, k_cache.getCacheSeqLen()); + } else { + qk = softmax(qk); + } + auto o = Tensor::mm(qk, v); } - auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, head_dim * num_heads); o = o_proj(o); return {o}; @@ -120,6 +126,7 @@ class PhoneLMAttention final : public Module { KVCache k_cache; KVCache v_cache; Softmax softmax; + string attn_impl; }; class PhoneLMDecoder final : public Module { diff --git a/src/models/phonelm/modeling_phonelm_npu.hpp b/src/models/phonelm/modeling_phonelm_npu.hpp index c13d0906f..e47f8b76b 100644 --- a/src/models/phonelm/modeling_phonelm_npu.hpp +++ b/src/models/phonelm/modeling_phonelm_npu.hpp @@ -10,8 +10,11 @@ using namespace mllm; +std::set phonelmShadowLayers = {0, 1, 3, 4}; + // NPU QKV part -class PhoneLMDecoderNPUPart1 final : public Module { +class PhoneLMDecoderNPUPart1 : public Module { +protected: int hidden_size; int num_heads; int head_dim; @@ -81,6 +84,64 @@ class PhoneLMDecoderNPUPart1 final : public Module { } }; +class PhoneLMDecoderNPUPart1WithRes final : public PhoneLMDecoderNPUPart1 { + Layer input_layernorm; + Layer pre_attn_quantize; + +public: + PhoneLMDecoderNPUPart1WithRes() = default; + + PhoneLMDecoderNPUPart1WithRes(const PhoneLMConfig &config, const PhoneLMNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + auto layer_base_name = base_name.substr(0, base_name.size() - 10); + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, layer_base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, layer_base_name + names._attn_base_name + names._q_proj_name + ".quantize"); + + pre_attn_view = View(-1, 1, -1, num_heads * head_dim, base_name + "ires_split-00_view_"); + + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); + + q_view = View(-1, num_heads, -1, head_dim, base_name + names._q_proj_name + "-00_view_"); + k_view = View(-1, num_heads, -1, head_dim, base_name + names._k_proj_name + "-00_view_"); + v_view = View(-1, num_heads, -1, head_dim, base_name + names._v_proj_name + "-00_view_"); + + q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize"); + k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false); + v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false); + + v_transpose = Transpose({0, 2, 3, 1}, base_name + names._v_proj_name + ".transpose"); + } + + vector Forward(vector inputs, vector args) override { + auto x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + x = pre_attn_view(x); + + auto query_states = q_proj(x); + auto key_states = k_proj(x); + auto value_states = v_proj(x); + + query_states = q_view(query_states); + key_states = k_view(key_states); + value_states = v_view(value_states); + + query_states = q_dequant(query_states); + key_states = k_dequant(key_states); + value_states = v_dequant(value_states); + + value_states = v_transpose(value_states); + return {query_states, key_states, value_states, inputs[0]}; + } +}; + // CPU QKV MM part class PhoneLMQKVmm final : public Module { IRoPE q_rope; @@ -140,7 +201,8 @@ class PhoneLMQKVmm final : public Module { }; // QNN mlp part -class PhoneLMDecoderNPUPart2 final : public Module { +class PhoneLMDecoderNPUPart2 : public Module { +protected: int hidden_size; int num_heads; int head_dim; @@ -250,39 +312,7 @@ class PhoneLMDecoderNPUPart2 final : public Module { } }; -class PhoneLMDecoderNPUPart2WithShadow final : public Module { - int hidden_size; - int num_heads; - int head_dim; - int num_key_value_heads; - int num_key_value_groups; - int intermediate_size; - - // NPU part2 of attention - Layer pre_oproj_view; - Layer out_proj; - Layer post_oproj_view; - Layer post_oproj_dequantize; - - // NPU mlp - Layer pre_mlp_quantize; - Layer pre_mlp_view; - Layer gate_proj; - Layer up_proj; - Layer post_up_proj_dequantize; - Layer post_gate_proj_dequantize; - Layer relu; - Layer post_attn_layernorm; - - Layer down_proj; - Layer pre_down_proj_quantize; - Layer post_down_proj_dequantize; - Layer post_mlp_view; - - Layer post_atten_res_add; - Layer post_mlp_res_add; - Layer mlp_mul; - +class PhoneLMDecoderNPUPart2WithShadow final : public PhoneLMDecoderNPUPart2 { public: PhoneLMDecoderNPUPart2WithShadow() = default; @@ -370,11 +400,17 @@ class PhoneLMNPU_CPUDecoder final : public Module { int num_key_value_heads; int num_key_value_groups; + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; + Layer input_layernorm; Layer pre_attn_quantize; - PhoneLMDecoderNPUPart1 part1; + unique_ptr part1; PhoneLMQKVmm qkv_mm; - PhoneLMDecoderNPUPart2 part2; + unique_ptr part2; public: PhoneLMNPU_CPUDecoder() = default; @@ -386,39 +422,65 @@ class PhoneLMNPU_CPUDecoder final : public Module { num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); - pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); - - part1 = PhoneLMDecoderNPUPart1(config, names, chunk_size, base_name + names._attn_base_name); - part1.to(MLLM_QNN); + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || phonelmShadowLayers.find(layer_idx - 1) != phonelmShadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } qkv_mm = PhoneLMQKVmm(config, names, chunk_size, base_name + names._attn_base_name); - qkv_mm.to(MLLM_CPU); - part2 = PhoneLMDecoderNPUPart2(config, names, chunk_size, base_name); - part2.to(MLLM_QNN); + part2 = make_unique(config, names, chunk_size, base_name); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); } vector Forward(vector inputs, vector args) override { - auto x = input_layernorm(inputs[0]); - x = pre_attn_quantize(x); - - if (x.device() != MLLM_QNN) { - x = Tensor::toQNN({x})[0]; + Tensor x, q, k, v, res; + if (layer_idx == 0 || phonelmShadowLayers.find(layer_idx - 1) != phonelmShadowLayers.end()) { + x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); } - auto q_k_v = part1({x}); // q,k,v - auto o_x = qkv_mm(q_k_v)[0]; + auto o_x = qkv_mm({q, k, v})[0]; - if (o_x.device() != MLLM_QNN) { - o_x = Tensor::toQNN({o_x})[0]; - } - if (inputs[0].device() != MLLM_QNN) { - inputs[0] = Tensor::toQNN({inputs[0]})[0]; + _SubgraphStart_2({o_x, res}); + + auto out_part2 = (*part2)({o_x, res}); + + if (layer_idx == num_layers - 1) { + _SubgraphEnd_2(out_part2); } - x = part2({o_x, inputs[0]})[0]; - return {x}; + return out_part2; } }; @@ -432,9 +494,15 @@ class PhoneLMNPU_CPUDecoderWithShadow final : public Module { Layer input_layernorm; Layer pre_attn_quantize; Layer shadow_linear; - PhoneLMDecoderNPUPart1 part1; + unique_ptr part1; PhoneLMQKVmm qkv_mm; - PhoneLMDecoderNPUPart2WithShadow part2; + unique_ptr part2; + + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; public: PhoneLMNPU_CPUDecoderWithShadow() = default; @@ -446,45 +514,69 @@ class PhoneLMNPU_CPUDecoderWithShadow final : public Module { num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); - pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); - - part1 = PhoneLMDecoderNPUPart1(config, names, chunk_size, base_name + names._attn_base_name); - part1.to(MLLM_QNN); + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || phonelmShadowLayers.find(layer_idx - 1) != phonelmShadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } qkv_mm = PhoneLMQKVmm(config, names, chunk_size, base_name + names._attn_base_name); - qkv_mm.to(MLLM_CPU); - part2 = PhoneLMDecoderNPUPart2WithShadow(config, names, chunk_size, base_name); - part2.to(MLLM_QNN); + part2 = make_unique(config, names, chunk_size, base_name); shadow_linear = ShadowLinear(config.intermediate_size, hidden_size, 1024, false, base_name + names._ffn_base_name + names._down_proj_name + ".shadow"); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); } vector Forward(vector inputs, vector args) override { - auto x = input_layernorm(inputs[0]); - x = pre_attn_quantize(x); - - if (x.device() != MLLM_QNN) { - x = Tensor::toQNN({x})[0]; + Tensor x, q, k, v, res; + if (layer_idx == 0 || phonelmShadowLayers.find(layer_idx - 1) != phonelmShadowLayers.end()) { + x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); } - auto q_k_v = part1({x}); // q,k,v - auto o_x = qkv_mm(q_k_v)[0]; + auto o_x = qkv_mm({q, k, v})[0]; + + _SubgraphStart_2({o_x, res}); + + auto decoder_out = (*part2)({o_x, res}); + decoder_out = Tensor::toCPU(decoder_out); + + _SubgraphEnd_2(decoder_out); - if (o_x.device() != MLLM_QNN) { - o_x = Tensor::toQNN({o_x})[0]; - } - if (inputs[0].device() != MLLM_QNN) { - inputs[0] = Tensor::toQNN({inputs[0]})[0]; - } - auto decoder_out = part2({o_x, inputs[0]}); - if (decoder_out[0].device() != MLLM_CPU) { - decoder_out = Tensor::toCPU(decoder_out); - } auto shadow_input_1 = decoder_out[0]; auto shadow_input_2 = decoder_out[1]; x = decoder_out[2]; + x = shadow_linear(shadow_input_1, shadow_input_2, x); return {x}; @@ -499,11 +591,10 @@ class PhoneLMModel_NPU final : public Module { static_assert(std::is_base_of::value, "SHADOW must be a subclass of Module"); listIdx = 0; vector> modules; - std::set shadowLayers = {0, 1, 3, 4}; - // for index in shadowLayers, create shadow decoder, for others, create normal decoder + // for index in phonelmShadowLayers, create shadow decoder, for others, create normal decoder for (int i = 0; i < n; i++) { auto new_args = change_last(args...); // 创建新的参数包,最后一个参数被修改为原来的值+ std::to_string(listIdx)+ "." - if (shadowLayers.find(listIdx) != shadowLayers.end()) { + if (phonelmShadowLayers.find(listIdx) != phonelmShadowLayers.end()) { modules.push_back(std::make_unique(std::apply([&](auto &&...args) { return SHADOW(std::forward(args)...); }, new_args))); } else { modules.push_back(std::make_unique(std::apply([&](auto &&...args) { return T1(std::forward(args)...); }, new_args))); diff --git a/src/models/qwen/configuration_qwen.hpp b/src/models/qwen/configuration_qwen.hpp index 446984385..686b48b46 100644 --- a/src/models/qwen/configuration_qwen.hpp +++ b/src/models/qwen/configuration_qwen.hpp @@ -131,6 +131,36 @@ struct QWenConfig : public TransformerConfig { sliding_window = 32768; vocab_size = 151936; tie_embedding_words = true; + } else if (billionsType == "1.5b-rotated") { + attention_dropout = 0.0; + std::string hidden_act = "silu"; + hidden_size = 1536; + intermediate_size = 8960; + max_position_embeddings = 32768; + max_window_layers = 28; + num_attention_heads = 12; + num_hidden_layers = 28; + num_key_value_heads = 2; + rms_norm_eps = 1e-6; + rope_theta = 1000000.0; + sliding_window = 32768; + vocab_size = 151936; + tie_embedding_words = false; + } else if (billionsType == "3b") { + attention_dropout = 0.0; + std::string hidden_act = "silu"; + hidden_size = 2048; + intermediate_size = 11008; + max_position_embeddings = 32768; + max_window_layers = 70; + num_attention_heads = 16; + num_hidden_layers = 36; + num_key_value_heads = 2; + rms_norm_eps = 1e-6; + rope_theta = 1000000.0; + sliding_window = 32768; + vocab_size = 151936; + tie_embedding_words = true; } else if (billionsType == "3b") { attention_dropout = 0.0; std::string hidden_act = "silu"; diff --git a/src/models/qwen/modeling_qwen.hpp b/src/models/qwen/modeling_qwen.hpp index e73c50a8c..90bdfefc8 100644 --- a/src/models/qwen/modeling_qwen.hpp +++ b/src/models/qwen/modeling_qwen.hpp @@ -61,7 +61,7 @@ class QWenAttention final : public Module { head_dim = config.hidden_size / num_heads; num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - + attn_impl = config.attn_implementation; // init layers q_proj = Linear(hidden_size, num_heads * head_dim, true, base_name + names._q_proj_name); k_proj = Linear(hidden_size, num_key_value_heads * head_dim, true, @@ -73,10 +73,8 @@ class QWenAttention final : public Module { base_name + "q_rope"); k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); - k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); - // mask = SlidingWindowMask(config.sliding_window, base_name + "mask"); - // mask = Causalmask(base_name + "mask"); + k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -98,15 +96,19 @@ class QWenAttention final : public Module { key_states = k_cache(key_states); value_states = v_cache(value_states); - // attention weight - auto atten_weight = - Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) - / std::sqrt(head_dim); - // atten_weight = mask(atten_weight, k_cache.getCacheSeqLen()); - atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); + Tensor atten_output; + if (attn_impl == "flash_attention_2") { + atten_output = Tensor::flash_attention2_forward(query_states, key_states, value_states, true); + } else { // eager implementation + auto atten_weight = + Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) + / std::sqrt(head_dim); + // atten_weight = mask(atten_weight, k_cache.getCacheSeqLen()); + atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); + auto atten_output = Tensor::mm(atten_weight, value_states); + } // attention output - auto atten_output = Tensor::mm(atten_weight, value_states); atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); atten_output = o_proj(atten_output); return {atten_output}; @@ -135,6 +137,7 @@ class QWenAttention final : public Module { KVCache v_cache; // Causalmask mask; Softmax softmax; + string attn_impl; }; // Copied from GemmaDecoder with Gemma->Qwen and set RmsNorm(without add_unit_offset) diff --git a/src/models/qwen/modeling_qwen_npu.hpp b/src/models/qwen/modeling_qwen_npu.hpp index 30ae8e1a3..e14dabaa9 100644 --- a/src/models/qwen/modeling_qwen_npu.hpp +++ b/src/models/qwen/modeling_qwen_npu.hpp @@ -1,5 +1,5 @@ -#ifndef MODELING_QWENNPU_HPP -#define MODELING_QWENNPU_HPP +#ifndef MODELING_QWENNPU_V2_HPP +#define MODELING_QWENNPU_V2_HPP #include "Backend.hpp" #include "Layer.hpp" @@ -7,11 +7,18 @@ #include "Tensor.hpp" #include "Types.hpp" #include "configuration_qwen.hpp" +#include using namespace mllm; +namespace v2 { + +// a 'just working' try +std::set shadowLayers = {1, 2, 4, 5, 26}; + // NPU QKV part -class QwenDecoderNPUPart1 final : public Module { +class QwenDecoderNPUPart1 : public Module { +protected: int hidden_size; int num_heads; int head_dim; @@ -45,17 +52,17 @@ class QwenDecoderNPUPart1 final : public Module { pre_attn_view = View(-1, 1, -1, num_heads * head_dim, base_name + "ires_split-00_view_"); - q_proj = Linear(hidden_size, num_heads * head_dim, true, base_name + names._q_proj_name); - k_proj = Linear(hidden_size, num_key_value_heads * head_dim, true, base_name + names._k_proj_name); - v_proj = Linear(hidden_size, num_key_value_heads * head_dim, true, base_name + names._v_proj_name); + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); q_view = View(-1, num_heads, -1, head_dim, base_name + names._q_proj_name + "-00_view_"); k_view = View(-1, num_key_value_heads, -1, head_dim, base_name + names._k_proj_name + "-00_view_"); v_view = View(-1, num_key_value_heads, -1, head_dim, base_name + names._v_proj_name + "-00_view_"); - q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize"); - k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false); - v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false); + q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize", true, MLLM_TYPE_I16); + k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false, MLLM_TYPE_I16); + v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false, MLLM_TYPE_I16); v_transpose = Transpose({0, 2, 3, 1}, base_name + names._v_proj_name + ".transpose"); } @@ -80,6 +87,64 @@ class QwenDecoderNPUPart1 final : public Module { } }; +class QwenDecoderNPUPart1WithRes final : public QwenDecoderNPUPart1 { + Layer input_layernorm; + Layer pre_attn_quantize; + +public: + QwenDecoderNPUPart1WithRes() = default; + QwenDecoderNPUPart1WithRes(const QWenConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // remove "self_attn." in base_name + auto layer_base_name = base_name.substr(0, base_name.size() - 10); + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, layer_base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, layer_base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + + pre_attn_view = View(-1, 1, -1, num_heads * head_dim, base_name + "ires_split-00_view_"); + + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); + + q_view = View(-1, num_heads, -1, head_dim, base_name + names._q_proj_name + "-00_view_"); + k_view = View(-1, num_key_value_heads, -1, head_dim, base_name + names._k_proj_name + "-00_view_"); + v_view = View(-1, num_key_value_heads, -1, head_dim, base_name + names._v_proj_name + "-00_view_"); + + q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize", true, MLLM_TYPE_I16); + k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false, MLLM_TYPE_I16); + v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false, MLLM_TYPE_I16); + + v_transpose = Transpose({0, 2, 3, 1}, base_name + names._v_proj_name + ".transpose"); + } + + vector Forward(vector inputs, vector args) override { + auto x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + x = pre_attn_view(x); + + auto query_states = q_proj(x); + auto key_states = k_proj(x); + auto value_states = v_proj(x); + + query_states = q_view(query_states); + key_states = k_view(key_states); + value_states = v_view(value_states); + + query_states = q_dequant(query_states); + key_states = k_dequant(key_states); + value_states = v_dequant(value_states); + + value_states = v_transpose(value_states); + return {query_states, key_states, value_states, inputs[0]}; + } +}; + // CPU QKV MM part class QwenQKVmm final : public Module { RoPE q_rope; @@ -99,6 +164,8 @@ class QwenQKVmm final : public Module { QwenQKVmm() = default; QwenQKVmm(const QWenConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; num_heads = config.num_attention_heads * config.hidden_size / config.num_attention_heads; q_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "q_rope"); @@ -107,6 +174,9 @@ class QwenQKVmm final : public Module { k_cache = KVCache(config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "k_cache", true); v_cache = KVCache(config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "v_cache", true); + // k_cache = KVCache(config.num_key_value_heads, head_dim, config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "k_cache", true); + // v_cache = KVCache(config.num_key_value_heads, head_dim, config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "v_cache", true); + softmax = Softmax(DIMENSION, true, base_name + "softmax"); o_quantize = Quantize(true, base_name + names._o_proj_name + ".quantize"); @@ -134,7 +204,8 @@ class QwenQKVmm final : public Module { }; // QNN mlp part -class QwenDecoderNPUPart2 final : public Module { +class QwenDecoderNPUPart2 : public Module { +protected: int hidden_size; int num_heads; int head_dim; @@ -243,39 +314,7 @@ class QwenDecoderNPUPart2 final : public Module { } }; -class QwenDecoderNPUPart2WithShadow final : public Module { - int hidden_size; - int num_heads; - int head_dim; - int num_key_value_heads; - int num_key_value_groups; - int intermediate_size; - - // NPU part2 of attention - Layer pre_oproj_view; - Layer out_proj; - Layer post_oproj_view; - Layer post_oproj_dequantize; - - // NPU mlp - Layer pre_mlp_quantize; - Layer pre_mlp_view; - Layer gate_proj; - Layer up_proj; - Layer post_up_proj_dequantize; - Layer post_gate_proj_dequantize; - Layer silu; - Layer post_attn_layernorm; - - Layer down_proj; - Layer pre_down_proj_quantize; - Layer post_down_proj_dequantize; - Layer post_mlp_view; - - Layer post_atten_res_add; - Layer post_mlp_res_add; - Layer mlp_mul; - +class QwenDecoderNPUPart2WithShadow final : public QwenDecoderNPUPart2 { public: QwenDecoderNPUPart2WithShadow() = default; QwenDecoderNPUPart2WithShadow(const QWenConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { @@ -351,6 +390,7 @@ class QwenDecoderNPUPart2WithShadow final : public Module { gate_out = post_mlp_view(gate_out); gate_out = post_mlp_res_add(gate_out, tmp); + return {shadow_input_1, shadow_input_2, gate_out}; } }; @@ -362,11 +402,17 @@ class QwenNPU_CPUDecoder final : public Module { int num_key_value_heads; int num_key_value_groups; + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; + Layer input_layernorm; Layer pre_attn_quantize; - QwenDecoderNPUPart1 part1; + unique_ptr part1; QwenQKVmm qkv_mm; - QwenDecoderNPUPart2 part2; + unique_ptr part2; public: QwenNPU_CPUDecoder() = default; @@ -377,36 +423,65 @@ class QwenNPU_CPUDecoder final : public Module { num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); - pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); - - part1 = QwenDecoderNPUPart1(config, names, chunk_size, base_name + names._attn_base_name); - part1.to(MLLM_QNN); + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || shadowLayers.find(layer_idx - 1) != shadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } qkv_mm = QwenQKVmm(config, names, chunk_size, base_name + names._attn_base_name); - qkv_mm.to(MLLM_CPU); - part2 = QwenDecoderNPUPart2(config, names, chunk_size, base_name); - part2.to(MLLM_QNN); + part2 = make_unique(config, names, chunk_size, base_name); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); } vector Forward(vector inputs, vector args) override { - auto x = input_layernorm(inputs[0]); - x = pre_attn_quantize(x); + Tensor x, q, k, v, res; + if (layer_idx == 0 || shadowLayers.find(layer_idx - 1) != shadowLayers.end()) { + x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); + } - x = Tensor::toQNN({x})[0]; - auto q_k_v = part1({x}); // q,k,v - q_k_v = Tensor::toCPU(q_k_v); + auto o_x = qkv_mm({q, k, v})[0]; - auto o_x = qkv_mm(q_k_v)[0]; + _SubgraphStart_2({o_x, res}); - auto qnn_tensor = Tensor::toQNN({o_x, inputs[0]}); - o_x = qnn_tensor[0]; - inputs[0] = qnn_tensor[1]; - x = part2({o_x, inputs[0]})[0]; - x = Tensor::toCPU({x})[0]; + auto out_part2 = (*part2)({o_x, res}); - return {x}; + if (layer_idx == num_layers - 1) { + _SubgraphEnd_2(out_part2); + } + + return out_part2; } }; @@ -420,9 +495,15 @@ class QwenNPU_CPUDecoderWithShadow final : public Module { Layer input_layernorm; Layer pre_attn_quantize; Layer shadow_linear; - QwenDecoderNPUPart1 part1; + unique_ptr part1; QwenQKVmm qkv_mm; - QwenDecoderNPUPart2WithShadow part2; + unique_ptr part2; + + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; public: QwenNPU_CPUDecoderWithShadow() = default; @@ -433,40 +514,69 @@ class QwenNPU_CPUDecoderWithShadow final : public Module { num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; - input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); - pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize"); - - part1 = QwenDecoderNPUPart1(config, names, chunk_size, base_name + names._attn_base_name); - part1.to(MLLM_QNN); + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || shadowLayers.find(layer_idx - 1) != shadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } qkv_mm = QwenQKVmm(config, names, chunk_size, base_name + names._attn_base_name); - qkv_mm.to(MLLM_CPU); - part2 = QwenDecoderNPUPart2WithShadow(config, names, chunk_size, base_name); - part2.to(MLLM_QNN); + part2 = make_unique(config, names, chunk_size, base_name); shadow_linear = ShadowLinear(config.intermediate_size, hidden_size, 1024, false, base_name + names._ffn_base_name + names._down_proj_name + ".shadow"); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); } vector Forward(vector inputs, vector args) override { - auto x = input_layernorm(inputs[0]); - x = pre_attn_quantize(x); + Tensor x, q, k, v, res; + if (layer_idx == 0 || shadowLayers.find(layer_idx - 1) != shadowLayers.end()) { + x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); + } - x = Tensor::toQNN({x})[0]; - auto q_k_v = part1({x}); // q,k,v - q_k_v = Tensor::toCPU(q_k_v); + auto o_x = qkv_mm({q, k, v})[0]; - auto o_x = qkv_mm(q_k_v)[0]; + _SubgraphStart_2({o_x, res}); - auto qnn_tensor = Tensor::toQNN({o_x, inputs[0]}); - o_x = qnn_tensor[0]; - inputs[0] = qnn_tensor[1]; - auto decoder_out = part2({o_x, inputs[0]}); + auto decoder_out = (*part2)({o_x, res}); decoder_out = Tensor::toCPU(decoder_out); + _SubgraphEnd_2(decoder_out); + auto shadow_input_1 = decoder_out[0]; auto shadow_input_2 = decoder_out[1]; x = decoder_out[2]; + x = shadow_linear(shadow_input_1, shadow_input_2, x); return {x}; @@ -481,7 +591,7 @@ class QWenModel_NPU final : public Module { static_assert(std::is_base_of::value, "SHADOW must be a subclass of Module"); listIdx = 0; vector> modules; - std::set shadowLayers = {1, 2, 26}; + // for index in shadowLayers, create shadow decoder, for others, create normal decoder for (int i = 0; i < n; i++) { auto new_args = change_last(args...); // 创建新的参数包,最后一个参数被修改为原来的值+ std::to_string(listIdx)+ "." @@ -500,7 +610,7 @@ class QWenModel_NPU final : public Module { QWenModel_NPU() = default; QWenModel_NPU(const QWenConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { // blocks = List(1, config, names, base_name); - blocks = ListWithShadow(24, config, names, chunk_size, base_name); + blocks = ListWithShadow(config.num_hidden_layers, config, names, chunk_size, base_name); norm = RMSNorm(config.hidden_size, config.rms_norm_eps, names.post_norm_name); } @@ -593,5 +703,6 @@ class QWenForCausalLM_NPU final : public Module { Layer lm_head_layer; QWenModel_NPU model; }; +} // namespace v2 -#endif //! MODELING_QWENNPU_HPP chunk_size, \ No newline at end of file +#endif //! MODELING_QWENNPU_V2_HPP chunk_size, \ No newline at end of file diff --git a/src/models/qwen2_vl/modeling_qwen2_vl.hpp b/src/models/qwen2_vl/modeling_qwen2_vl.hpp index bee84ccae..793b4a118 100644 --- a/src/models/qwen2_vl/modeling_qwen2_vl.hpp +++ b/src/models/qwen2_vl/modeling_qwen2_vl.hpp @@ -40,14 +40,16 @@ class VisionAttention final : public Module { int head_size_{}; int kv_head_size_{}; int attn_hidden_dim_{}; + string attn_impl; public: VisionAttention() = default; - VisionAttention(int hidden_dim, int head_size, int kv_head_size, int attn_hidden_dim, bool bias, + VisionAttention(int hidden_dim, int head_size, int kv_head_size, int attn_hidden_dim, bool bias, string attn_implementation, const TransformerNameConfig &names, const string &base_name) { attn_hidden_dim_ = attn_hidden_dim; head_size_ = head_size; kv_head_size_ = kv_head_size; + attn_impl = attn_implementation; qkv_proj = Linear(hidden_dim, head_size * attn_hidden_dim * 3, bias, base_name + names._qkv_proj_name); softmax = Softmax(DIMENSION, false, base_name + "softmax"); @@ -59,7 +61,6 @@ class VisionAttention final : public Module { auto seq_length = inputs[0].sequence(); Tensor q, k, v; auto qkv = qkv_proj(inputs[0]); - // auto qkv_sp = qkv.split({attn_hidden_dim_, attn_hidden_dim_, attn_hidden_dim_}, HD, head_size_); auto qkv_sp = qkv.split({attn_hidden_dim_ * head_size_, attn_hidden_dim_ * head_size_, attn_hidden_dim_ * head_size_}, DIMENSION); q = qkv_sp[0]; k = qkv_sp[1]; @@ -69,12 +70,17 @@ class VisionAttention final : public Module { v = v.view(-1, head_size_, -1, attn_hidden_dim_); q = Tensor::apply_rotary_pos_emb_vision(q, rotary_pos_emb); k = Tensor::apply_rotary_pos_emb_vision(k, rotary_pos_emb); - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk / std::sqrt(attn_hidden_dim_); - // mask - qk = softmax(qk); - auto o = Tensor::mm(qk, v); + Tensor o; + if (attn_impl == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, false); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(attn_hidden_dim_); + // mask + qk = softmax(qk); + o = Tensor::mm(qk, v); + } o = o.view(-1, 1, -1, attn_hidden_dim_ * head_size_); o = o_proj(o); return {o}; @@ -109,8 +115,8 @@ class VisionBlock final : public Module { public: VisionBlock() = default; - VisionBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, const ViTNameConfig &names, const string &base_name) { - attention = VisionAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, true, names, base_name + names._attn_base_name); + VisionBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, string attn_implementation, const ViTNameConfig &names, const string &base_name) { + attention = VisionAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, true, attn_implementation, names, base_name + names._attn_base_name); mlp = VisionMLP(hidden_dim, ffn_hidden, act_fn_type, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); norm2 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._ffn_norm_name); @@ -160,10 +166,10 @@ class Qwen2VisionModel final : public Module { public: Qwen2VisionModel() = default; - Qwen2VisionModel(int hidden_dim, int vision_embed_dim, int head_size, int mlp_hidden_dim, const string &act_fn_type, int patch, int img_hw, int block_num, int spatial_merge_size, const Qwen2VLNameConfig &names, const string &base_name) { + Qwen2VisionModel(int hidden_dim, int vision_embed_dim, int head_size, int mlp_hidden_dim, const string &act_fn_type, int patch, int img_hw, int block_num, int spatial_merge_size, string attn_implementation, const Qwen2VLNameConfig &names, const string &base_name) { patch_embed = Qwen2PatchEmbed(vision_embed_dim, patch, img_hw, names, base_name + names.patch_embed_name); rot_pos_emb = VisionRoPE((vision_embed_dim / head_size) / 2, spatial_merge_size, base_name + ".rot_pos_emb"); - blocks = List(block_num, vision_embed_dim, head_size, mlp_hidden_dim, act_fn_type, names, base_name + names._layer_name); + blocks = List(block_num, vision_embed_dim, head_size, mlp_hidden_dim, act_fn_type, attn_implementation, names, base_name + names._layer_name); patch_merger = PatchMerger(hidden_dim, vision_embed_dim, spatial_merge_size, names, base_name + names._merger_name); } vector Forward(vector inputs, vector args) override { @@ -219,6 +225,8 @@ class QWen2Attention final : public Module { head_dim = config.hidden_size / num_heads; num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; + name = base_name; + attn_impl = config.attn_implementation; // init layers q_proj = Linear(hidden_size, num_heads * head_dim, true, base_name + names._q_proj_name); @@ -229,14 +237,13 @@ class QWen2Attention final : public Module { o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name); q_rope = MultimodalRoPE(config.rope_theta, config.max_position_embeddings, config.mrope_section, base_name + "q_rope"); k_rope = MultimodalRoPE(config.rope_theta, config.max_position_embeddings, config.mrope_section, base_name + "k_rope"); - k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); + k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (config.attn_implementation == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } std::vector Forward(std::vector inputs, std::vector args) override { auto position_ids = inputs[1]; - auto query_states = q_proj(inputs[0]); auto key_states = k_proj(inputs[0]); auto value_states = v_proj(inputs[0]); @@ -247,11 +254,17 @@ class QWen2Attention final : public Module { key_states = k_rope(key_states, position_ids); key_states = k_cache(key_states); value_states = v_cache(value_states); - auto atten_weight = - Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) - / std::sqrt(head_dim); - atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); - auto atten_output = Tensor::mm(atten_weight, value_states); + + Tensor atten_output; + if (attn_impl == "flash_attention_2") { + atten_output = Tensor::flash_attention2_forward(query_states, key_states, value_states, true); + } else { // eager implementation + auto atten_weight = + Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) + / std::sqrt(head_dim); + atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); + atten_output = Tensor::mm(atten_weight, value_states); + } atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); atten_output = o_proj(atten_output); return {atten_output}; @@ -279,6 +292,8 @@ class QWen2Attention final : public Module { KVCache k_cache; KVCache v_cache; Softmax softmax; + string name; + string attn_impl; }; // Copied from GemmaDecoder with Gemma->Qwen and set RmsNorm(without add_unit_offset) @@ -349,7 +364,7 @@ class Qwen2VLModel final : public Module { vision_start_token_id = config.vision_start_token_id; embed_tokens = Embedding(vocab_size, hidden_dim, qwen_names.token_embd_name); - visual = Qwen2VisionModel(hidden_dim, vision_embed_dim, 16, vision_embed_dim * 4, "QuickGELU", 14, 336, 32, spatial_merge_size, vision_names, vision_names.vison_model_name); + visual = Qwen2VisionModel(hidden_dim, vision_embed_dim, 16, vision_embed_dim * 4, "QuickGELU", 14, 336, 32, spatial_merge_size, config.attn_implementation, vision_names, vision_names.vison_model_name); blocks = List(config.num_hidden_layers, config, qwen_names, qwen_names.blk_name); norm = RMSNorm(hidden_dim, 1e-6, qwen_names.post_norm_name); @@ -575,4 +590,4 @@ class Qwen2VLModel final : public Module { return {position_ids, mrope_position_deltas}; } }; -#endif // MODELING_PHI3_HPP \ No newline at end of file +#endif // MODELING_QWEN2VL_HPP \ No newline at end of file diff --git a/src/models/qwen2_vl/modeling_qwen2_vl_npu.hpp b/src/models/qwen2_vl/modeling_qwen2_vl_npu.hpp new file mode 100644 index 000000000..e5025d65c --- /dev/null +++ b/src/models/qwen2_vl/modeling_qwen2_vl_npu.hpp @@ -0,0 +1,978 @@ +#ifndef MODELING_QWEN2VL_NPU_HPP +#define MODELING_QWEN2VL_NPU_HPP + +#include "Layer.hpp" +#include "Module.hpp" +#include "Tensor.hpp" +#include "Timing.hpp" +#include "Types.hpp" +#include "configuration_qwen2_vl.hpp" +#include "models/qwen2_vl/modeling_qwen2_vl.hpp" +#include +#include +#include +#include + +using namespace mllm; + +// current version of showui/qwen2-vl don't need shadow layers +std::set qwenvlShadowLayers = {100}; + +// NPU QKV part +class QwenDecoderNPUPart1 : public Module { +protected: + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + + // it is for speed up the QNN linear implemented by conv, TODO: should integrate into QNNLinear + Layer pre_attn_view; + + Layer q_proj; + Layer k_proj; + Layer v_proj; + + Layer q_view; + Layer k_view; + Layer v_view; + + Layer q_dequant; + Layer k_dequant; + Layer v_dequant; + Layer v_transpose; + +public: + QwenDecoderNPUPart1() = default; + QwenDecoderNPUPart1(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + pre_attn_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, num_heads * head_dim, base_name + "ires_split-00_view_"); + + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); + + q_view = View(1, num_heads, chunk_size, head_dim, base_name + names._q_proj_name + "-00_view_"); + k_view = View(1, num_key_value_heads, chunk_size, head_dim, base_name + names._k_proj_name + "-00_view_"); + v_view = View(1, num_key_value_heads, chunk_size, head_dim, base_name + names._v_proj_name + "-00_view_"); + + q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize", true, MLLM_TYPE_I16); + k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false, MLLM_TYPE_I16); + v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false, MLLM_TYPE_I16); + + v_transpose = Transpose({0, 2, 3, 1}, base_name + names._v_proj_name + ".transpose"); + } + + vector Forward(vector inputs, vector args) override { + auto x = pre_attn_view(inputs[0]); + + auto query_states = q_proj(x); + auto key_states = k_proj(x); + auto value_states = v_proj(x); + + query_states = q_view(query_states); + key_states = k_view(key_states); + value_states = v_view(value_states); + + // return {query_states, key_states, value_states}; + + query_states = q_dequant(query_states); + key_states = k_dequant(key_states); + value_states = v_dequant(value_states); + + value_states = v_transpose(value_states); + return {query_states, key_states, value_states}; + } +}; + +class QwenDecoderNPUPart1WithRes final : public QwenDecoderNPUPart1 { + Layer input_layernorm; + Layer pre_attn_quantize; + +public: + QwenDecoderNPUPart1WithRes() = default; + QwenDecoderNPUPart1WithRes(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // remove "self_attn." in base_name + auto layer_base_name = base_name.substr(0, base_name.size() - 10); + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, layer_base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, layer_base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + + pre_attn_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, num_heads * head_dim, base_name + "ires_split-00_view_"); + + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); + + q_view = View(1, num_heads, chunk_size, head_dim, base_name + names._q_proj_name + "-00_view_"); + k_view = View(1, num_key_value_heads, chunk_size, head_dim, base_name + names._k_proj_name + "-00_view_"); + v_view = View(1, num_key_value_heads, chunk_size, head_dim, base_name + names._v_proj_name + "-00_view_"); + + q_dequant = Dequantize(true, base_name + names._q_proj_name + ".dequantize", true, MLLM_TYPE_I16); + k_dequant = Dequantize(true, base_name + names._k_proj_name + ".dequantize", false, MLLM_TYPE_I16); + v_dequant = Dequantize(true, base_name + names._v_proj_name + ".dequantize", false, MLLM_TYPE_I16); + + v_transpose = Transpose({0, 2, 3, 1}, base_name + names._v_proj_name + ".transpose"); + } + + vector Forward(vector inputs, vector args) override { + auto x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + x = pre_attn_view(x); + + auto query_states = q_proj(x); + auto key_states = k_proj(x); + auto value_states = v_proj(x); + + query_states = q_view(query_states); + key_states = k_view(key_states); + value_states = v_view(value_states); + + query_states = q_dequant(query_states); + key_states = k_dequant(key_states); + value_states = v_dequant(value_states); + + value_states = v_transpose(value_states); + return {query_states, key_states, value_states, inputs[0]}; + } +}; + +// CPU QKV MM part +class QwenQKVmm final : public Module { + MultimodalRoPE q_rope; + MultimodalRoPE k_rope; + KVCache k_cache; + KVCache v_cache; + Softmax softmax; + Layer o_quantize; + + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + +public: + QwenQKVmm() = default; + QwenQKVmm(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + + q_rope = MultimodalRoPE(config.rope_theta, config.max_position_embeddings, config.mrope_section, base_name + "q_rope"); + k_rope = MultimodalRoPE(config.rope_theta, config.max_position_embeddings, config.mrope_section, base_name + "k_rope"); + + k_cache = KVCache(config.num_key_value_heads, head_dim, config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "k_cache", true); + v_cache = KVCache(config.num_key_value_heads, head_dim, config.num_attention_heads / config.num_key_value_heads, config.cache_limit, base_name + "v_cache", true); + + softmax = Softmax(DIMENSION, true, base_name + "softmax"); + + o_quantize = Quantize(true, base_name + names._o_proj_name + ".quantize"); + } + + vector Forward(vector inputs, vector args) override { + // TODO: remove it + // auto qkv_start = mllm_time_ms(); + auto position_ids = inputs[3]; + + auto q = inputs[0]; + auto k = inputs[1]; + auto v = inputs[2]; + + q = q_rope(q, position_ids); + k = k_rope(k, position_ids); + + k = k_cache(k); + v = v_cache(v); + + auto qk = Tensor::mm(q, k.transpose(Chl::SEQUENCE, Chl::DIMENSION)); + qk = qk / std::sqrt(head_dim); + qk = softmax(qk); + auto o = Tensor::mm(qk, v); + + o = o_quantize(o); + + // TODO: remove it + // auto qkv_end = mllm_time_ms(); + // std::cout << "QKV mm time: " << qkv_end - qkv_start << "ms" << std::endl; + + return {o}; + } +}; + +// QNN mlp part +class QwenDecoderNPUPart2 : public Module { +protected: + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + int intermediate_size; + + // NPU part2 of attention + Layer pre_oproj_view; + Layer out_proj; + Layer post_oproj_view; + Layer post_oproj_dequantize; + + // NPU mlp + Layer pre_mlp_quantize; + Layer pre_mlp_view; + Layer gate_proj; + Layer up_proj; + Layer post_up_proj_dequantize; + Layer post_gate_proj_dequantize; + Layer silu; + Layer post_attn_layernorm; + + Layer down_proj; + Layer pre_down_proj_quantize; + Layer post_down_proj_dequantize; + Layer post_mlp_view; + + Layer post_atten_res_add; + Layer post_mlp_res_add; + Layer mlp_mul; + +public: + QwenDecoderNPUPart2() = default; + QwenDecoderNPUPart2(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + intermediate_size = config.intermediate_size; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // for QNN linear speed up + pre_oproj_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_"); + out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name); + post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize"); + post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_"); + post_atten_res_add = Add(base_name + names._attn_base_name + "post_atten_add"); + + post_attn_layernorm = + RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); + + auto mlp_base_name = base_name + names._ffn_base_name; + pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize"); + pre_mlp_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_"); + gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name); + silu = SiLU(mlp_base_name + "act"); + up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name); + post_up_proj_dequantize = Dequantize(true, mlp_base_name + names._up_proj_name + ".dequantize"); + post_gate_proj_dequantize = Dequantize(true, mlp_base_name + names._gate_proj_name + ".dequantize"); + + down_proj = Linear(intermediate_size, hidden_size, false, mlp_base_name + names._down_proj_name); + pre_down_proj_quantize = Quantize(true, mlp_base_name + names._down_proj_name + ".quantize"); + post_down_proj_dequantize = Dequantize(true, mlp_base_name + names._down_proj_name + ".dequantize"); + post_mlp_view = View(1, 1, chunk_size, hidden_size, mlp_base_name + names._down_proj_name + ".dequantize-00_view_"); + + mlp_mul = Mul(mlp_base_name + "mul"); + post_mlp_res_add = Add(mlp_base_name + "res_add"); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto atten_output = inputs[0]; + auto res = inputs[1]; + + atten_output = pre_oproj_view(atten_output); + atten_output = out_proj(atten_output); + atten_output = post_oproj_dequantize(atten_output); + auto float_oproj = post_oproj_view(atten_output); + + auto tmp = post_atten_res_add(float_oproj, res); + + auto x = post_attn_layernorm(tmp); + + x = pre_mlp_quantize(x); + // reshape to 32,2 + x = pre_mlp_view(x); + + auto gate_out = gate_proj(x); + auto up_out = up_proj(x); + + gate_out = post_gate_proj_dequantize(gate_out); + auto silu_out = silu(gate_out); + + up_out = post_up_proj_dequantize(up_out); + gate_out = mlp_mul(silu_out, up_out); + + gate_out = pre_down_proj_quantize(gate_out); + gate_out = down_proj(gate_out); + gate_out = post_down_proj_dequantize(gate_out); + + // reshape to 64,1 + auto float_gate_out = post_mlp_view(gate_out); + + gate_out = post_mlp_res_add(float_gate_out, tmp); + return {gate_out, float_oproj, silu_out, float_gate_out}; + } +}; + +class QwenDecoderNPUPart2WithShadow final : public QwenDecoderNPUPart2 { +public: + QwenDecoderNPUPart2WithShadow() = default; + QwenDecoderNPUPart2WithShadow(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + intermediate_size = config.intermediate_size; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // for QNN linear speed up + pre_oproj_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_"); + out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name); + post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize"); + post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_"); + post_atten_res_add = Add(base_name + names._attn_base_name + "post_atten_add"); + + post_attn_layernorm = + RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); + + auto mlp_base_name = base_name + names._ffn_base_name; + pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize"); + pre_mlp_view = View(1, utils::closestFactors(chunk_size).first, utils::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_"); + gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name); + silu = SiLU(mlp_base_name + "act"); + up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name); + post_up_proj_dequantize = Dequantize(true, mlp_base_name + names._up_proj_name + ".dequantize"); + post_gate_proj_dequantize = Dequantize(true, mlp_base_name + names._gate_proj_name + ".dequantize"); + + down_proj = Linear(intermediate_size, hidden_size, false, mlp_base_name + names._down_proj_name); + pre_down_proj_quantize = Quantize(true, mlp_base_name + names._down_proj_name + ".quantize"); + post_down_proj_dequantize = Dequantize(true, mlp_base_name + names._down_proj_name + ".dequantize"); + post_mlp_view = View(1, 1, chunk_size, hidden_size, mlp_base_name + names._down_proj_name + ".dequantize-00_view_"); + + mlp_mul = Mul(mlp_base_name + "mul"); + post_mlp_res_add = Add(mlp_base_name + "res_add"); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto atten_output = inputs[0]; + auto res = inputs[1]; + + atten_output = pre_oproj_view(atten_output); + atten_output = out_proj(atten_output); + atten_output = post_oproj_dequantize(atten_output); + atten_output = post_oproj_view(atten_output); + + auto tmp = post_atten_res_add(atten_output, res); + + auto x = post_attn_layernorm(tmp); + + x = pre_mlp_quantize(x); + // reshape to 32,2 + x = pre_mlp_view(x); + + auto gate_out = gate_proj(x); + auto up_out = up_proj(x); + + gate_out = post_gate_proj_dequantize(gate_out); + gate_out = silu(gate_out); + + up_out = post_up_proj_dequantize(up_out); + gate_out = mlp_mul(gate_out, up_out); + + auto shadow_input_1 = gate_out; + + gate_out = pre_down_proj_quantize(gate_out); + gate_out = down_proj(gate_out); + auto shadow_input_2 = gate_out; + gate_out = post_down_proj_dequantize(gate_out); + + // reshape to 64,1 + gate_out = post_mlp_view(gate_out); + + gate_out = post_mlp_res_add(gate_out, tmp); + + return {shadow_input_1, shadow_input_2, gate_out}; + } +}; + +class QwenNPU_CPUDecoder final : public Module { + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; + + Layer input_layernorm; + Layer pre_attn_quantize; + unique_ptr part1; + QwenQKVmm qkv_mm; + unique_ptr part2; + +public: + QwenNPU_CPUDecoder() = default; + QwenNPU_CPUDecoder(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || qwenvlShadowLayers.find(layer_idx - 1) != qwenvlShadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } + + qkv_mm = QwenQKVmm(config, names, chunk_size, base_name + names._attn_base_name); + + part2 = make_unique(config, names, chunk_size, base_name); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); + } + + vector Forward(vector inputs, vector args) override { + auto position_ids = inputs[1]; + + Tensor x, q, k, v, res; + if (layer_idx == 0 || qwenvlShadowLayers.find(layer_idx - 1) != qwenvlShadowLayers.end()) { + x = input_layernorm(inputs[0]); + + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); + } + + auto o_x = qkv_mm({q, k, v, position_ids})[0]; + + _SubgraphStart_2({o_x, res}); + + auto out_part2 = (*part2)({o_x, res}); + + if (layer_idx == num_layers - 1) { + _SubgraphEnd_2(out_part2); + } + + return out_part2; + } +}; + +class QwenNPU_CPUDecoderWithShadow final : public Module { + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + + Layer input_layernorm; + Layer pre_attn_quantize; + Layer shadow_linear; + unique_ptr part1; + QwenQKVmm qkv_mm; + unique_ptr part2; + + int layer_idx; + int num_layers; + + SubgraphStart _SubgraphStart_1, _SubgraphStart_2; + SubgraphFinalize _SubgraphEnd_1, _SubgraphEnd_2; + +public: + QwenNPU_CPUDecoderWithShadow() = default; + QwenNPU_CPUDecoderWithShadow(const Qwen2VLConfig &config, const QWenNameConfig &names, int chunk_size, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // extract layer index from base_name like "model.layers.10." + std::regex re(R"(\d+)"); + std::smatch match; + std::regex_search(base_name, match, re); + layer_idx = std::stoi(match[0]); + num_layers = config.num_hidden_layers; + + if (layer_idx == 0 || qwenvlShadowLayers.find(layer_idx - 1) != qwenvlShadowLayers.end()) { + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + pre_attn_quantize = Quantize(true, base_name + names._attn_base_name + names._q_proj_name + ".quantize", MLLM_TYPE_I16); + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } else { + part1 = make_unique(config, names, chunk_size, base_name + names._attn_base_name); + } + + qkv_mm = QwenQKVmm(config, names, chunk_size, base_name + names._attn_base_name); + + part2 = make_unique(config, names, chunk_size, base_name); + + shadow_linear = ShadowLinear(config.intermediate_size, hidden_size, 1024, false, base_name + names._ffn_base_name + names._down_proj_name + ".shadow"); + + _SubgraphStart_1 = SubgraphStart(base_name + "subgraph_start1"); + _SubgraphEnd_1 = SubgraphFinalize(base_name + "subgraph_end1"); + _SubgraphStart_2 = SubgraphStart(base_name + "subgraph_start2"); + _SubgraphEnd_2 = SubgraphFinalize(base_name + "subgraph_end2"); + } + + vector Forward(vector inputs, vector args) override { + auto position_ids = inputs[1]; + + Tensor x, q, k, v, res; + if (layer_idx == 0 || qwenvlShadowLayers.find(layer_idx - 1) != qwenvlShadowLayers.end()) { + x = input_layernorm(inputs[0]); + x = pre_attn_quantize(x); + + _SubgraphStart_1({x}); + + auto q_k_v = (*part1)({x}); // q,k,v + q = q_k_v[0]; + k = q_k_v[1]; + v = q_k_v[2]; + res = inputs[0]; + _SubgraphEnd_1(q_k_v); + } else { + auto q_k_v_res = (*part1)(inputs); // q,k,v,res + q = q_k_v_res[0]; + k = q_k_v_res[1]; + v = q_k_v_res[2]; + res = q_k_v_res[3]; + _SubgraphEnd_1(q_k_v_res); + } + + auto o_x = qkv_mm({q, k, v, position_ids})[0]; + + _SubgraphStart_2({o_x, res}); + + auto decoder_out = (*part2)({o_x, res}); + decoder_out = Tensor::toCPU(decoder_out); + + _SubgraphEnd_2(decoder_out); + + auto shadow_input_1 = decoder_out[0]; + auto shadow_input_2 = decoder_out[1]; + x = decoder_out[2]; + + x = shadow_linear(shadow_input_1, shadow_input_2, x); + + return {x}; + } +}; + +class Qwen2VL_ImagePatchAndEmbedding final : public Module { + Qwen2VisionModel visual; + Layer embed_tokens; + + Layer norm; + Parameter lm_head; + Layer lm_head_layer; + + bool tie_embedding_words; + + int64_t spatial_merge_size; + int64_t image_token_id; + int64_t video_token_id; + int64_t vision_start_token_id; + +public: + explicit Qwen2VL_ImagePatchAndEmbedding(const Qwen2VLConfig &config) { + auto vocab_size = config.vocab_size; + auto hidden_dim = config.hidden_size; + auto head_size = config.num_attention_heads; + auto ffn_hidden = config.intermediate_size; + auto projection_cls = config.projection_cls; + auto vision_embed_dim = config.vision_embed_dim; + image_token_id = config.image_token_id; + auto vision_names = config.vision_names_config; + auto qwen_names = config.names_config; + tie_embedding_words = config.tie_embedding_words; + spatial_merge_size = config.spatial_merge_size; + image_token_id = config.image_token_id; + video_token_id = config.video_token_id; + vision_start_token_id = config.vision_start_token_id; + + embed_tokens = Embedding(vocab_size, hidden_dim, qwen_names.token_embd_name); + visual = Qwen2VisionModel(hidden_dim, vision_embed_dim, 16, vision_embed_dim * 4, "QuickGELU", 14, 336, 32, spatial_merge_size, config.attn_implementation, vision_names, vision_names.vison_model_name); + } + + vector Forward(vector inputs, vector args) override { + auto hidden_states = embed_tokens({inputs[0]}); + + auto image_embeds = visual({inputs[1], inputs[2]})[0]; + auto n_image_features = image_embeds.sequence(); + auto where_idx = inputs[0].where(image_token_id, SEQUENCE); + hidden_states = hidden_states.index_put(image_embeds, where_idx, false); + + return {hidden_states}; + } + + // changed from get_position_ids in CPU Qwen2VL, enable padding + // when prefilling, padding_to should be the max length of the input + // when decoding, real_seq should be the real length of the input, thus get the correct position_ids for decoding + void get_position_ids(vector &inputs, int padding_to = 0, int real_seq = 0) { + if (inputs[0].sequence() > 1) { + Tensor video_grid_thw(0, 0, 0, 0, MLLM_CPU, true); + auto rope_indices = get_rope_index_cpp(inputs[0], inputs[2], video_grid_thw, padding_to); + auto position = rope_indices[0]; + if (inputs.size() == 4) { + inputs[3] = position; + } else { + inputs.push_back(position); + } + } else { + auto &position_ids = inputs[3]; + auto last_pos = real_seq == 0 ? position_ids.dataAt(0, 0, 0, position_ids.dimension() - 1) : real_seq - 1; + position_ids.reshape(position_ids.batch(), 1, position_ids.sequence(), 1); + for (int b = 0; b < position_ids.batch(); b++) { + for (int s = 0; s < position_ids.sequence(); s++) { + position_ids.setDataAt(b, 0, s, 0, last_pos + 1); + } + } + } + } + +private: + vector get_rope_index_cpp( + Tensor input_ids, + Tensor image_grid_thw, + Tensor video_grid_thw, + int padding_to = 0) { + vector> attention_mask; + auto attention_mask_shape = input_ids.sequence(); + for (int b = 0; b < input_ids.batch(); b++) { + attention_mask.emplace_back(attention_mask_shape, 1); + } + const size_t batch_size = input_ids.batch(); // input_ids.size(); + + // NOTE: changed from original + const size_t seq_len = batch_size > 0 ? (padding_to > input_ids.sequence() ? padding_to : input_ids.sequence()) : 0; // batch_size > 0 ? input_ids[0].size() : 0; + + Tensor position_ids(3, 1, batch_size, seq_len, Backend::global_backends[MLLM_CPU], true); + Tensor mrope_position_deltas(1, 1, 1, batch_size, Backend::global_backends[MLLM_CPU], true); + bool has_vision = (image_grid_thw.sequence() > 0) || (video_grid_thw.sequence() > 0); // image_grid_thw || video_grid_thw; + if (!has_vision) { + // Pure text case + for (size_t i = 0; i < batch_size; ++i) { + const auto &mask = !attention_mask.empty() ? attention_mask[i] : vector(seq_len, 1); + vector positions; + int64_t pos = 0; + for (size_t j = 0; j < seq_len; ++j) { + if (mask[j] == 1) { + positions.push_back(pos++); + } else { + positions.push_back(1); // Will be overwritten by mask + } + } + for (int dim = 0; dim < 3; ++dim) { + for (size_t j = 0; j < seq_len; ++j) { + position_ids.setDataAt(dim, 0, i, j, (float)(mask[j] == 1 ? positions[j] : 1)); + } + } + int64_t max_pos = pos - 1; + mrope_position_deltas.setDataAt(0, 0, 0, i, (float)((max_pos + 1) - static_cast(input_ids.sequence()))); + } + position_ids.setName("position_ids"); + mrope_position_deltas.setName("mrope_position_deltas"); + return {position_ids, mrope_position_deltas}; + } + // Process vision cases + size_t image_idx = 0, video_idx = 0; + for (size_t i = 0; i < batch_size; ++i) { + const auto &mask = !attention_mask.empty() ? attention_mask[i] : vector(seq_len, 1); + // Extract valid tokens + vector valid_tokens; + for (size_t j = 0; j < input_ids.sequence(); ++j) { + if (mask[j] == 1) valid_tokens.push_back((int)input_ids.dataAt(i, 0, j, 0)); + } + // Find vision start positions + vector vision_starts; + vector vision_types; + for (size_t j = 0; j < valid_tokens.size(); ++j) { + if (valid_tokens[j] == vision_start_token_id && j + 1 < valid_tokens.size()) { + vision_starts.push_back(j); + vision_types.push_back(valid_tokens[j + 1]); + } + } + int64_t image_count = count(vision_types.begin(), vision_types.end(), image_token_id); + int64_t video_count = vision_types.size() - image_count; + vector> llm_positions(3); + size_t st = 0; + int64_t current_max = 0; + int64_t remain_images = image_count; + int64_t remain_videos = video_count; + // Process each vision segment + for (size_t vs = 0; vs < vision_starts.size(); ++vs) { + // Find next vision token + size_t ed_image = valid_tokens.size(); + size_t ed_video = valid_tokens.size(); + if (remain_images > 0) { + auto it = find(valid_tokens.begin() + st, valid_tokens.end(), image_token_id); + if (it != valid_tokens.end()) ed_image = it - valid_tokens.begin(); + } + if (remain_videos > 0) { + auto it = find(valid_tokens.begin() + st, valid_tokens.end(), video_token_id); + if (it != valid_tokens.end()) ed_video = it - valid_tokens.begin(); + } + size_t ed = min(ed_image, ed_video); + if (ed == valid_tokens.size()) break; + // Get grid parameters + int64_t t, h, w; + bool is_image = (ed == ed_image); + if (is_image) { + t = (int64_t)image_grid_thw.dataAt(0, 0, image_idx, 0); + h = (int64_t)image_grid_thw.dataAt(0, 0, image_idx, 1); + w = (int64_t)image_grid_thw.dataAt(0, 0, image_idx, 2); + image_idx++; + remain_images--; + } else { + t = (int64_t)video_grid_thw.dataAt(0, 0, video_idx, 0); + h = (int64_t)video_grid_thw.dataAt(0, 0, video_idx, 1); + w = (int64_t)video_grid_thw.dataAt(0, 0, video_idx, 2); + video_idx++; + remain_videos--; + } + // Calculate grid dimensions + int64_t llm_grid_t = t; + int64_t llm_grid_h = h / spatial_merge_size; + int64_t llm_grid_w = w / spatial_merge_size; + // Process text segment + size_t text_len = ed - st; + if (text_len > 0) { + int64_t start_idx = current_max; + for (int64_t k = 0; k < text_len; ++k) { + for (int dim = 0; dim < 3; ++dim) { + llm_positions[dim].push_back(start_idx + k); + } + } + current_max += text_len; + } + for (int64_t ti = 0; ti < llm_grid_t; ++ti) { + for (int64_t hi = 0; hi < llm_grid_h; ++hi) { + for (int64_t wi = 0; wi < llm_grid_w; ++wi) { + llm_positions[0].push_back(current_max + ti); + llm_positions[1].push_back(current_max + hi); + llm_positions[2].push_back(current_max + wi); + } + } + } + current_max = std::max({llm_positions[0][llm_positions[0].size() - 1], + llm_positions[1][llm_positions[1].size() - 1], + llm_positions[2][llm_positions[2].size() - 1]}); + st = ed + llm_grid_t * llm_grid_h * llm_grid_w; + } + // Process remaining text + if (st < valid_tokens.size()) { + size_t text_len = valid_tokens.size() - st; + int64_t st_idx = current_max + 1; + for (int64_t k = 0; k < text_len; ++k) { + for (int dim = 0; dim < 3; ++dim) { + llm_positions[dim].push_back(st_idx + k); + } + } + current_max += text_len; + } + // Fill position_ids with valid positions + size_t valid_idx = 0; + for (size_t j = 0; j < seq_len; ++j) { + if (mask[j] == 1) { + if (valid_idx < llm_positions[0].size()) { + position_ids.setDataAt(0, 0, i, j, (float)llm_positions[0][valid_idx]); + position_ids.setDataAt(1, 0, i, j, (float)llm_positions[1][valid_idx]); + position_ids.setDataAt(2, 0, i, j, (float)llm_positions[2][valid_idx]); + valid_idx++; + } + } + } + // Calculate delta + int64_t max_pos = 0; + for (const auto &dim : llm_positions) { + for (auto val : dim) { + max_pos = max(max_pos, val); + } + } + mrope_position_deltas.setDataAt(0, 0, 0, i, (float)((max_pos + 1) - static_cast(input_ids.sequence()))); + } + position_ids.setName("position_ids"); + mrope_position_deltas.setName("mrope_position_deltas"); + return {position_ids, mrope_position_deltas}; + } +}; + +class Qwen2VL_PrefillBody final : public Module { + std::vector> blocks; + Layer norm; + Parameter lm_head; + Layer lm_head_layer; + int num_layer; + + bool tie_embedding_words; + + template + static vector> ListWithShadow(int n, Args &&...args) { + static_assert(std::is_base_of::value, "T1 must be a subclass of Module"); + static_assert(std::is_base_of::value, "SHADOW must be a subclass of Module"); + listIdx = 0; + vector> modules; + + // for index in shadowLayers, create shadow decoder, for others, create normal decoder + for (int i = 0; i < n; i++) { + auto new_args = change_last(args...); // 创建新的参数包,最后一个参数被修改为原来的值+ std::to_string(listIdx)+ "." + if (qwenvlShadowLayers.find(listIdx) != qwenvlShadowLayers.end()) { + modules.push_back(std::make_unique(std::apply([&](auto &&...args) { return SHADOW(std::forward(args)...); }, new_args))); + } else { + modules.push_back(std::make_unique(std::apply([&](auto &&...args) { return T1(std::forward(args)...); }, new_args))); + } + listIdx++; + } + listIdx = 0; + return modules; + } + +public: + explicit Qwen2VL_PrefillBody(const Qwen2VLConfig &config, int chunk_size) { + // Module::initBackend(MLLM_QNN); + auto vocab_size = config.vocab_size; + auto hidden_dim = config.hidden_size; + auto head_size = config.num_attention_heads; + auto qwen_names = config.names_config; + tie_embedding_words = config.tie_embedding_words; + + num_layer = config.num_hidden_layers; + + blocks = ListWithShadow(config.num_hidden_layers, config, qwen_names, chunk_size, qwen_names.blk_name); + norm = RMSNorm(hidden_dim, 1e-6, qwen_names.post_norm_name); + if (tie_embedding_words) { + lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size, qwen_names.token_embd_name + ".weight"); + } else { + lm_head_layer = HeadLinear(config.hidden_size, config.vocab_size, false, qwen_names.lm_head_name); + } + } + + vector Forward(vector inputs, vector args) override { + auto hidden_states = inputs[0]; + auto position_ids = inputs[1]; + + for (auto &block : blocks) { + hidden_states = (*block)({hidden_states, position_ids})[0]; + } + + hidden_states = norm(hidden_states); + + if (tie_embedding_words) { + hidden_states = Tensor::mm(hidden_states, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION)); + } else { + hidden_states = lm_head_layer(hidden_states); + } + return {hidden_states}; + } +}; + +// CPU decoding model with only the LLM backbone +class Qwen2VL_Decoding_Model final : public Module { + Layer embed_tokens; + + vector blocks; + Layer norm; + Parameter lm_head; + Layer lm_head_layer; + + bool tie_embedding_words; + + int64_t spatial_merge_size; + int64_t image_token_id; + int64_t video_token_id; + int64_t vision_start_token_id; + +public: + explicit Qwen2VL_Decoding_Model(const Qwen2VLConfig &config) { + auto vocab_size = config.vocab_size; + auto hidden_dim = config.hidden_size; + auto head_size = config.num_attention_heads; + auto ffn_hidden = config.intermediate_size; + auto projection_cls = config.projection_cls; + auto vision_embed_dim = config.vision_embed_dim; + image_token_id = config.image_token_id; + auto vision_names = config.vision_names_config; + auto qwen_names = config.names_config; + tie_embedding_words = config.tie_embedding_words; + spatial_merge_size = config.spatial_merge_size; + image_token_id = config.image_token_id; + video_token_id = config.video_token_id; + vision_start_token_id = config.vision_start_token_id; + + embed_tokens = Embedding(vocab_size, hidden_dim, qwen_names.token_embd_name); + + blocks = List(config.num_hidden_layers, config, qwen_names, qwen_names.blk_name); + norm = RMSNorm(hidden_dim, 1e-6, qwen_names.post_norm_name); + if (tie_embedding_words) { + lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size, qwen_names.token_embd_name + ".weight"); + } else { + lm_head_layer = Linear(config.hidden_size, config.vocab_size, false, qwen_names.lm_head_name); + } + } + vector Forward(vector inputs, vector args) override { + auto position_ids = inputs[3]; + + auto hidden_states = embed_tokens({inputs[0]}); + + for (auto &block : blocks) { + hidden_states = block({hidden_states, position_ids})[0]; + } + hidden_states = norm(hidden_states); + if (tie_embedding_words) { + hidden_states = Tensor::mm(hidden_states, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION)); + } else { + hidden_states = lm_head_layer(hidden_states); + } + return {hidden_states}; + } + void clear_kvcache() override { + for (auto &block : blocks) { + auto kvcahce = block.get_attention().get_cache(); + for (auto &cache : kvcahce) { + cache->clearCache(); + } + } + } +}; + +#endif // MODELING_QWEN2VL_NPU_HPP \ No newline at end of file diff --git a/src/models/qwen2_vl/processing_qwen2_vl.hpp b/src/models/qwen2_vl/processing_qwen2_vl.hpp index 3a70e972d..bed3638be 100644 --- a/src/models/qwen2_vl/processing_qwen2_vl.hpp +++ b/src/models/qwen2_vl/processing_qwen2_vl.hpp @@ -191,6 +191,12 @@ class Qwen2VLImageProcessor { vector> input_ids_; pair>, vector> preprocess_images(const uint8_t *image, const size_t &image_length) { auto imageinfos = vector(); + // int width, height, channels; + // auto data = stbi_load_from_memory(image, image_length, &width, &height, &channels, 3); + // if (data == nullptr) { + // MLLM_LOG_ERROR_STREAM << "Error: Failed to load image from memory." << std::endl; + // exit(-1); + // } int width, height, channels; auto data = stbi_load_from_memory(image, image_length, &width, &height, &channels, 0); if (data == nullptr) { @@ -431,12 +437,12 @@ class Qwen2VLProcessor final : public PreProcessor { return tokenizer->detokenize(tokens); } - std::pair detokenize(Tensor &result) { + std::pair detokenize(Tensor &result, int seq = 0) { assert(result.batch() == 1 && "Batch size of result is not 1. Which is not supported for now."); assert(result.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]"); vector scores; int _dims = result.dimension(); - int _seq = result.sequence() - 1; + int _seq = seq == 0 ? result.sequence() - 1 : seq - 1; for (int i = 0; i < _dims; ++i) { auto value = result.dataAt(0, 0, _seq, i); scores.push_back(value); diff --git a/src/models/qwen2_vl/vtp/modeling_qwen2_vl.hpp b/src/models/qwen2_vl/vtp/modeling_qwen2_vl.hpp index 7243ca144..5d3c179e2 100644 --- a/src/models/qwen2_vl/vtp/modeling_qwen2_vl.hpp +++ b/src/models/qwen2_vl/vtp/modeling_qwen2_vl.hpp @@ -4,14 +4,21 @@ #ifndef MODELING_QWEN2VL_HPP #define MODELING_QWEN2VL_HPP +// #define VTP +#include +#define NDC + #include "Layer.hpp" #include "Module.hpp" #include "Tensor.hpp" #include "Types.hpp" #include "../configuration_qwen2_vl.hpp" -// #include "models/qwen/modeling_qwen.hpp" -// #include +#if defined(VTP) #include "vtp_tools.hpp" +#endif +#ifdef NDC +#include "ndc_tools.hpp" +#endif #include #include @@ -231,7 +238,8 @@ class QWen2Attention final : public Module { k_rope = MultimodalRoPE(config.rope_theta, config.max_position_embeddings, config.mrope_section, base_name + "k_rope"); k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); - softmax = Softmax(DIMENSION, true, base_name + "softmax"); + mask = Causalmask(base_name + "mask"); + softmax = Softmax(DIMENSION, false, base_name + "softmax"); } std::vector Forward(std::vector inputs, std::vector args) override { @@ -246,24 +254,43 @@ class QWen2Attention final : public Module { value_states = value_states.view(-1, num_key_value_heads, -1, head_dim); query_states = q_rope(query_states, position_ids); key_states = k_rope(key_states, position_ids); - key_states = k_cache(key_states); - value_states = v_cache(value_states); + auto key_cache_states = k_cache(key_states); + auto value_cache_states = v_cache(value_states); +#if defined(NDC) + // ====================================================================================== + WHERE_TOKEN_PRUNING.get_kvcache(key_cache_states, value_cache_states, key_states, value_states, + layer_index, k_cache.getCacheSeqLen()); + //====================================================================================== +#endif auto atten_weight = - Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) + Tensor::mm(query_states, key_cache_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim); +#if defined(NDC) + if (WHERE_TOKEN_PRUNING.causal_masks.find(layer_index) != WHERE_TOKEN_PRUNING.causal_masks.end() + && atten_weight.sequence() > 1) { + atten_weight = atten_weight + WHERE_TOKEN_PRUNING.causal_masks[layer_index]; + WHERE_TOKEN_PRUNING.causal_masks.erase(layer_index); + } else { +#endif + atten_weight = mask(atten_weight, k_cache.getCacheSeqLen()); +#if defined(NDC) + } +#endif atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); - auto atten_output = Tensor::mm(atten_weight, value_states); + auto atten_output = Tensor::mm(atten_weight, value_cache_states); atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); atten_output = o_proj(atten_output); +#if defined(VTP) //====================================================================================== // pruning stage if (WHERE_TOKEN_PRUNING.is_prefill()) { WHERE_TOKEN_PRUNING.set_prefill_layer(layer_index); WHERE_TOKEN_PRUNING.update_attn_acc_score(atten_weight); - atten_output = WHERE_TOKEN_PRUNING.prunning_attn_output(atten_output); + atten_output = WHERE_TOKEN_PRUNING.prunning_attn_output(atten_output, layer_index); } //====================================================================================== - return {atten_output}; +#endif + return {atten_output, atten_weight}; } vector get_cache() { @@ -288,6 +315,7 @@ class QWen2Attention final : public Module { KVCache k_cache; KVCache v_cache; Softmax softmax; + Causalmask mask; }; // Copied from GemmaDecoder with Gemma->Qwen and set RmsNorm(without add_unit_offset) @@ -308,17 +336,26 @@ class QWen2Decoder final : public Module { auto position_ids = inputs[1]; auto residual = inputs[0]; auto x = input_layernorm(residual); - x = self_atten({x, position_ids}, layer_index)[0]; + auto xs = self_atten({x, position_ids}, layer_index); + x = xs[0]; + auto atten_weight = xs[1]; +#if defined(VTP) //====================================================================================== // pruning stage if (WHERE_TOKEN_PRUNING.is_prefill()) { residual = WHERE_TOKEN_PRUNING.pruning_(residual); } //====================================================================================== +#endif auto tmp = x + residual; x = post_attention_layernorm(tmp); x = mlp({x})[0]; x = x + tmp; +#if defined(NDC) + //====================================================================================== + WHERE_TOKEN_PRUNING.update_hidden_pos(x, atten_weight, layer_index); + //====================================================================================== +#endif return {x}; } QWen2Attention &get_attention() { @@ -342,6 +379,8 @@ class Qwen2VLModel final : public Module { Layer lm_head_layer; bool tie_embedding_words; + int num_hidden_layers; + int num_attention_heads; int64_t spatial_merge_size; int64_t image_token_id; @@ -361,6 +400,8 @@ class Qwen2VLModel final : public Module { auto qwen_names = config.names_config; tie_embedding_words = config.tie_embedding_words; spatial_merge_size = config.spatial_merge_size; + num_hidden_layers = config.num_hidden_layers; + num_attention_heads = config.num_attention_heads; image_token_id = config.image_token_id; video_token_id = config.video_token_id; vision_start_token_id = config.vision_start_token_id; @@ -377,12 +418,11 @@ class Qwen2VLModel final : public Module { } } vector Forward(vector inputs, vector args) override { - WHERE_TOKEN_PRUNING.init(); - if (inputs[0].sequence() <= 1) { - WHERE_TOKEN_PRUNING.prefill_stage = false; - } else { - WHERE_TOKEN_PRUNING.prefill_stage = true; - } +#if defined(VTP) || defined(NDC) + // ====================================================================================== + WHERE_TOKEN_PRUNING.init(inputs[0], num_hidden_layers, num_attention_heads); + // ====================================================================================== +#endif auto position_ids = inputs[3]; bool have_img = inputs[1].batch() > 0; auto hidden_states = embed_tokens({inputs[0]}); @@ -390,16 +430,27 @@ class Qwen2VLModel final : public Module { auto image_embeds = visual({inputs[1], inputs[2]})[0]; auto n_image_features = image_embeds.sequence(); auto where_idx = inputs[0].where(image_token_id, SEQUENCE); - // ======================================================================================================== +#if defined(VTP) || defined(NDC) + // ====================================================================================== // Pruning Stage 1 Start - if (WHERE_TOKEN_PRUNING.is_prefill()) { - WHERE_TOKEN_PRUNING.set_vision_token(where_idx, hidden_states, image_embeds); - } - // ======================================================================================================== + WHERE_TOKEN_PRUNING.set_vision_token(where_idx, hidden_states, image_embeds); + // ====================================================================================== +#endif hidden_states = hidden_states.index_put(image_embeds, where_idx, false); } +#if defined(NDC) + // ====================================================================================== + // if (WHERE_TOKEN_PRUNING.is_prefill()) { + auto past_kv_seq_len = blocks[0].get_attention().get_cache()[0]->getCacheSeqLen(); + if (past_kv_seq_len != -1) { + WHERE_TOKEN_PRUNING.ndc_prepare(hidden_states, position_ids, past_kv_seq_len); + } + // } + // ====================================================================================== +#endif int layer_index = 0; for (auto &block : blocks) { +#if defined(VTP) //====================================================================================== // pruning stage if (WHERE_TOKEN_PRUNING.is_prefill()) { @@ -407,7 +458,17 @@ class Qwen2VLModel final : public Module { position_ids = WHERE_TOKEN_PRUNING.pruning_pos(position_ids, DIMENSION); } //====================================================================================== +#endif hidden_states = block({hidden_states, position_ids}, layer_index)[0]; +#if defined(NDC) + // ====================================================================================== + // change position_ids + auto kv_seq_len = blocks[layer_index + 1].get_attention().get_cache()[0]->getCacheSeqLen(); + if (kv_seq_len != -1) { + hidden_states = WHERE_TOKEN_PRUNING.prepare_next_layer(layer_index, position_ids, hidden_states, kv_seq_len); + } + // ====================================================================================== +#endif layer_index++; } hidden_states = norm(hidden_states); @@ -419,12 +480,14 @@ class Qwen2VLModel final : public Module { } else { hidden_states = lm_head_layer(hidden_states); } +#if defined(VTP) //====================================================================================== // pruning stage if (WHERE_TOKEN_PRUNING.is_prefill() && (Tensor::tensor_status == TENSOR_STATIC_READY)) { WHERE_TOKEN_PRUNING.prefill_stage = false; } //====================================================================================== +#endif return {hidden_states}; } void clear_kvcache() override { diff --git a/src/models/qwen2_vl/vtp/ndc_tools.hpp b/src/models/qwen2_vl/vtp/ndc_tools.hpp new file mode 100644 index 000000000..6ada028b3 --- /dev/null +++ b/src/models/qwen2_vl/vtp/ndc_tools.hpp @@ -0,0 +1,515 @@ +// +// Created by Rongjie Yi on 25-5-29. +// +#ifndef NDC_TOOLS_HPP +#define NDC_TOOLS_HPP + +#include "Module.hpp" +#include "Tensor.hpp" +#include "Types.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mllm; + +class DelayComputeKVCache { +public: + vector> kv_true_token_appds; + vector hidden_states_cache; + vector hidden_states_filled; + DelayComputeKVCache() { + } + void init_cache_list(int layers) { + hidden_states_cache.resize(layers); + hidden_states_filled.resize(layers); + kv_true_token_appds.resize(layers); + } + void update_hidden_states(Tensor hidden_states, int layer_idx, int original_hs_length, vector pos, bool is_prefill) { + auto b = hidden_states.batch(); + auto d = hidden_states.dimension(); + if (hidden_states_cache[layer_idx].name().empty()) { //=如果hidden_states_cache[layer_idx]为空Tensor + hidden_states_cache[layer_idx] = Tensor(b, 1, original_hs_length, d, MLLM_CPU, true); + hidden_states_cache[layer_idx].setName("hidden_states_cache_" + std::to_string(layer_idx)); + hidden_states_filled[layer_idx] = Tensor(b, 1, original_hs_length, 1, MLLM_CPU, true); + hidden_states_filled[layer_idx].setName("hidden_states_fille_" + std::to_string(layer_idx)); + } + for (int bb = 0; bb < b; ++bb) { + for (int i = 0; i < pos.size(); ++i) { + auto p = pos[i]; + memcpy(hidden_states_cache[layer_idx].ptrAt(bb, 0, p, 0), + hidden_states.ptrAt(bb, 0, i, 0), + sizeof(float) * d); + hidden_states_filled[layer_idx].setDataAt(bb, 0, p, 0, 1.0f); + } + } + } + Tensor get_hidden_states(int layer_idx, vector pos) { + return hidden_states_cache[layer_idx].clip(pos, SEQUENCE); + } + Tensor reset_hidden_states(Tensor hidden_states, int layer_idx, vector pos) { + assert(hidden_states.batch() == 1); + vector hidden_states_last; + hidden_states_last.resize(hidden_states.dimension()); + memcpy(hidden_states_last.data(), + hidden_states.ptrAt(0, 0, hidden_states.sequence() - 1, 0), + sizeof(float) * hidden_states.dimension()); + hidden_states.reshape(1, 1, pos.size() + 1, hidden_states.dimension()); + hidden_states.alloc(); + for (int i = 0; i < pos.size(); ++i) { + int p = pos[i]; + memcpy(hidden_states.ptrAt(0, 0, i, 0), + hidden_states_cache[layer_idx].ptrAt(0, 0, p, 0), + sizeof(float) * hidden_states.dimension()); + } + memcpy(hidden_states.ptrAt(0, 0, hidden_states.sequence() - 1, 0), + hidden_states_last.data(), + sizeof(float) * hidden_states.dimension()); + return hidden_states; + } + vector kv_not_filled_pos(int layer_idx, int original_kv_length) { + auto filled_token = kv_true_token_appds[layer_idx]; + vector not_filled_pos; + for (int i = 0; i < original_kv_length; ++i) { + if (std::find(filled_token.begin(), filled_token.end(), i) == filled_token.end()) { + not_filled_pos.push_back(i); + } + } + return not_filled_pos; + } + template + static void reorder_cache(Tensor &k_cache, Tensor &v_cache, + const vector &indices, + int pos_first, int cache_sequence) { + const int num_heads = v_cache.head(); + const int k_per_head = k_cache.dimension(); + const int v_per_head = v_cache.dimension(); + const int k_size = num_heads * k_per_head; + const int v_size = v_per_head; + // 1. 分配临时内存 + vector> k_cache_data(cache_sequence - pos_first); + vector>> v_cache_data(num_heads); + for (int i = pos_first; i < cache_sequence; i++) { + k_cache_data[i - pos_first].resize(k_size); + } + for (int h = 0; h < num_heads; ++h) { + v_cache_data[h].resize(cache_sequence - pos_first); + for (int i = pos_first; i < cache_sequence; i++) { + v_cache_data[h][i - pos_first].resize(v_size); + } + } + // 2. 拷贝数据到临时内存 + for (int i = pos_first; i < cache_sequence; i++) { + // K_cache拷贝(全部heads) + memcpy(k_cache_data[i - pos_first].data(), + k_cache.ptrAt(0, 0, i, 0), + sizeof(T) * k_size); + // V_cache拷贝(每个head分开) +#pragma omp parallel for num_threads(CPUBackend::cpu_threads) + for (int h = 0; h < num_heads; ++h) { + memcpy(v_cache_data[h][i - pos_first].data(), + v_cache.ptrAt(0, h, i, 0), + sizeof(T) * v_size); + } + } + // 3. 根据索引重新排序 + for (size_t idx : indices) { + if (idx >= (size_t)pos_first && idx < (size_t)cache_sequence) { + const int temp_idx = idx - pos_first; + // 写回K_cache + memcpy(k_cache.ptrAt(0, 0, idx, 0), + k_cache_data[temp_idx].data(), + sizeof(T) * k_size); + // 写回V_cache +#pragma omp parallel for num_threads(CPUBackend::cpu_threads) + for (int h = 0; h < num_heads; ++h) { + memcpy(v_cache.ptrAt(0, h, idx, 0), + v_cache_data[h][temp_idx].data(), + sizeof(T) * v_size); + } + } + } + } + + void update_kv_cache(Tensor &k_cache, Tensor &v_cache, Tensor &k_state, Tensor &v_state, int cache_sequence, int layer_idx, + bool is_prefill, string update_mode, vector pos = {}, int original_kv_length = -1) { + if (update_mode == "insert") { + assert(k_cache.masterTensor() == k_state.masterTensor()); + assert(v_cache.masterTensor() == v_state.masterTensor()); + // pos代表现在的token列表:{0,1,2,3,5,6,8,9}, 8个token及其列表 + if (is_prefill) { + kv_true_token_appds[layer_idx] = pos; // 记录当前token的列表 + assert(kv_true_token_appds[layer_idx].size() == cache_sequence); + } else { + auto new_token_pos = kv_true_token_appds[layer_idx][kv_true_token_appds[layer_idx].size() - 1] + 1; + assert(kv_true_token_appds[layer_idx].size() + 1 + pos.size() == cache_sequence); + kv_true_token_appds[layer_idx].insert(kv_true_token_appds[layer_idx].end(), pos.begin(), pos.end()); + kv_true_token_appds[layer_idx].push_back(new_token_pos); // 添加新的token位置 + auto &cur_pos = kv_true_token_appds[layer_idx]; + // for k_cache; + if (pos.size() > 1) { + auto pos_first = pos[0]; + assert(v_cache.ctype() == BHDS); + // 创建并初始化索引数组 + std::vector indices(cur_pos.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return cur_pos[a] < cur_pos[b]; + }); + assert(k_cache.batch() == 1); + if (k_cache.dtype() == MLLM_TYPE_F16) { + reorder_cache(k_cache, v_cache, indices, pos_first, cache_sequence); + } else { + reorder_cache(k_cache, v_cache, indices, pos_first, cache_sequence); + } + std::sort(kv_true_token_appds[layer_idx].begin(), kv_true_token_appds[layer_idx].end()); + } + } + } + } +}; + +class NdcContext { + int first_img_token_pos = 0; + int last_img_token_pos = 0; + int last_img_token_pos_l = 0; + DelayComputeKVCache kvcache_ctx; + int cur_step = -1; + vector> chosen_pos_in_each; + vector> chosen_pos_to_delay_compute; + Tensor global_position_ids; + int num_hidden_layers = 0; + int num_head = 0; + int original_kv_length = 0; + int chunk_size = 4; + map pruning_place_cfg = {{3, 0.2}, {9, 0.2}, {12, 0.2}, {15, 0.2}, {18, 0.2}, {26, 0.2}}; + // map pruning_place_cfg = {{3, 0.2}, {9, 0.2}}; + +public: + map causal_masks; // layer_idx -> causal_mask + bool prefill_stage = true; + void init(Tensor input_ids, int num_layers, int num_attention_heads) { + if (Module::llm_model_ptr->doLoad) { return; } + num_hidden_layers = num_layers; + num_head = num_attention_heads; + if (kvcache_ctx.hidden_states_cache.empty()) { + chosen_pos_in_each.resize(num_hidden_layers, {}); + kvcache_ctx.init_cache_list(num_hidden_layers); // Initialize with 1 layer, can be adjusted as needed + } + + if (input_ids.sequence() <= 1) { + prefill_stage = false; + } else { + prefill_stage = true; + } + } + bool is_prefill() { + return prefill_stage && cur_step == 0; + } + void set_vision_token(Tensor where_idx, Tensor hidden_states, Tensor image_embeds) { + if (Module::llm_model_ptr->doLoad) { return; } + first_img_token_pos = int(where_idx.dataAt(0, 0, 0, 0)); + last_img_token_pos = int(where_idx.dataAt(0, 0, 0, where_idx.dimension() - 1)) + 1; + } + + void ndc_prepare(Tensor hidden_states, Tensor position_ids, int past_kv_seq_len) { + if (Module::llm_model_ptr->doLoad) { return; } + cur_step += 1; + if (cur_step == 0) { + global_position_ids = position_ids; + original_kv_length = hidden_states.sequence(); + } + chosen_pos_in_each.resize(num_hidden_layers, {}); + int new_seq_len = hidden_states.sequence() + past_kv_seq_len; + chosen_pos_in_each[0].clear(); + for (int i = 0; i < new_seq_len; ++i) { + chosen_pos_in_each[0].push_back(i); + } + if (!is_prefill()) { + chosen_pos_to_delay_compute.resize(num_hidden_layers, {}); + } + } + void get_kvcache(Tensor &k_cache, Tensor &v_cache, Tensor &k_state, Tensor &v_state, int layer_idx, int cache_sequence) { + if (Module::llm_model_ptr->doLoad) { return; } + if (is_prefill()) { + auto chosen_pos = chosen_pos_in_each[layer_idx]; + kvcache_ctx.update_kv_cache(k_cache, v_cache, k_state, v_state, cache_sequence, layer_idx, + is_prefill(), "insert", + chosen_pos, original_kv_length); + } else { + auto chosen_pos = chosen_pos_in_each[layer_idx]; + kvcache_ctx.update_kv_cache(k_cache, v_cache, k_state, v_state, cache_sequence, layer_idx, + is_prefill(), "insert", + chosen_pos_to_delay_compute[layer_idx], original_kv_length); + } + } + + void topk_partial_sort(const vector &scores, int k, + vector &topk_values, vector &topk_indices) { + if (k <= 0 || scores.empty()) { + topk_values.clear(); + topk_indices.clear(); + return; + } + k = std::min(k, static_cast(scores.size())); + // 创建索引向量 + vector indices(scores.size()); + for (int i = 0; i < scores.size(); i++) { + indices[i] = i; + } + // 部分排序 - 将前k个最大的元素移动到前部 + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&scores](int a, int b) { + return scores[a] > scores[b]; // 降序排序 + }); + // 提取结果 + topk_values.resize(k); + topk_indices.resize(k); + for (int i = 0; i < k; i++) { + topk_indices[i] = indices[i]; + topk_values[i] = scores[indices[i]]; + } + } + + vector select_high_score_visual_token_prefill(Tensor attn, int layer_idx, int chunk_size = 4) { + auto cur_chosen_pos = chosen_pos_in_each[layer_idx]; + // attention_score_analyze_prefill start + attn = attn.mean(HEAD); // 1,t,1,t + int visual_start_in_selected = -1; + int visual_end_in_selected = -1; + for (int i = 0; i < cur_chosen_pos.size(); ++i) { + auto pos = cur_chosen_pos[i]; + if (pos == first_img_token_pos - 1) { // 0 && visual_end_in_selected > 0) { + break; + } + } + int attn_seq_start = visual_end_in_selected + 1; // +1 for the end image token + int attn_seq_end = attn.sequence(); // exclusive + int attn_dim_start = visual_start_in_selected + 1; // +1 for the first image token + int attn_dim_end = visual_end_in_selected; // exclusive + vector attn_score; // 1,1,1,visual_end_in_selected - visual_start_in_selected + 1 + for (int j = attn_dim_start; j < attn_dim_end; ++j) { + float data = 0.0f; + for (int i = attn_seq_start; i < attn_seq_end; ++i) { + data += attn.dataAt(0, 0, i, j); + } + // data /= (attn_seq_end - attn_seq_start); + attn_score.push_back(data); + } + auto v_s = visual_start_in_selected; + auto v_e = visual_end_in_selected; + // attention_score_analyze_prefill end + auto pruning_rate = pruning_place_cfg[layer_idx]; + auto cur_visual_token_length = attn_score.size(); + auto keep_ratio = 1 - pruning_rate; + int k_initial = static_cast(std::ceil(cur_visual_token_length * keep_ratio)); + int k_final = (k_initial / chunk_size) * chunk_size; + k_final = std::min(k_final, static_cast(cur_visual_token_length)); // 确保不超过当前实际长度 + vector topk_vals; + vector topk_indices; + topk_partial_sort(attn_score, k_final, topk_vals, topk_indices); // torch.topk(attn_score, k_final) + vector final_token_chosen; + vector cur_chosen_pos_p1(cur_chosen_pos.begin(), cur_chosen_pos.begin() + v_s + 1); + vector cur_chosen_pos_p2(cur_chosen_pos.begin() + v_s + 1, cur_chosen_pos.begin() + v_e); + vector cur_chosen_pos_p3(cur_chosen_pos.begin() + v_e, cur_chosen_pos.end()); + final_token_chosen = cur_chosen_pos_p1; + for (auto item : topk_indices) { + final_token_chosen.push_back(cur_chosen_pos_p2[item]); + } + final_token_chosen.insert(final_token_chosen.end(), cur_chosen_pos_p3.begin(), cur_chosen_pos_p3.end()); + std::sort(final_token_chosen.begin(), final_token_chosen.end()); + return final_token_chosen; + } + vector select_high_score_visual_token_decode(Tensor attn, int layer_idx, int chunk_size = 4) { + auto cur_chosen_pos = chosen_pos_in_each[layer_idx]; + // attention_score_analyze_decode start + attn = attn.mean(HEAD); // 1,t,1,t TODO + if (attn.sequence() != 1) { + attn = attn.clip({}, {}, {-1}, {}); // 1,1,1,t + } + auto cur_chosen_tokens = chosen_pos_in_each[layer_idx]; + int visual_start_in_selected = -1; + int visual_end_in_selected = -1; + for (int i = 0; i < cur_chosen_pos.size(); ++i) { + auto pos = cur_chosen_pos[i]; + if (pos == first_img_token_pos - 1) { // 0 && visual_end_in_selected > 0) { + break; + } + } + int attn_dim_start = visual_start_in_selected + 1; // +1 for the first image token + int attn_dim_end = visual_end_in_selected; // exclusive + vector attn_score; // 1,1,1,visual_end_in_selected - visual_start_in_selected + 1 + for (int j = attn_dim_start; j < attn_dim_end; ++j) { + float data = attn.dataAt(0, 0, 0, j); + attn_score.push_back(data); + } + auto v_s = visual_start_in_selected; + auto v_e = visual_end_in_selected; + // attention_score_analyze_decode end + auto pruning_rate = pruning_place_cfg[layer_idx]; + auto cur_visual_token_length = attn_score.size(); + auto keep_ratio = 1 - pruning_rate; + int k_initial = static_cast(std::ceil(cur_visual_token_length * keep_ratio)); + int k_final = (k_initial / chunk_size) * chunk_size; + k_final = std::min(k_final, static_cast(cur_visual_token_length)); // 确保不超过当前实际长度 + vector topk_vals; + vector topk_indices; + topk_partial_sort(attn_score, k_final, topk_vals, topk_indices); // torch.topk(attn_score, k_final) + vector final_token_chosen; + vector cur_chosen_pos_p1(cur_chosen_pos.begin(), cur_chosen_pos.begin() + v_s + 1); + vector cur_chosen_pos_p2(cur_chosen_pos.begin() + v_s + 1, cur_chosen_pos.begin() + v_e); + vector cur_chosen_pos_p3(cur_chosen_pos.begin() + v_e, cur_chosen_pos.end()); + final_token_chosen = cur_chosen_pos_p1; + for (auto item : topk_indices) { + final_token_chosen.push_back(cur_chosen_pos_p2[item]); + } + final_token_chosen.insert(final_token_chosen.end(), cur_chosen_pos_p3.begin(), cur_chosen_pos_p3.end()); + std::sort(final_token_chosen.begin(), final_token_chosen.end()); + return final_token_chosen; + } + + void update_hidden_pos(Tensor hidden_states, Tensor attn_weight, int layer_idx) { + if (Module::llm_model_ptr->doLoad) { return; } + if (is_prefill()) { + auto chs_pos = chosen_pos_in_each[layer_idx]; + if (pruning_place_cfg.find(layer_idx) != pruning_place_cfg.end()) { + kvcache_ctx.update_hidden_states(hidden_states, layer_idx, original_kv_length, chs_pos, is_prefill()); + chosen_pos_in_each[layer_idx + 1] = select_high_score_visual_token_prefill(attn_weight, layer_idx, chunk_size); + } else { + if (layer_idx + 1 < num_hidden_layers) { + chosen_pos_in_each[layer_idx + 1] = chosen_pos_in_each[layer_idx]; + } + } + } else { + auto chs_pos = chosen_pos_to_delay_compute[layer_idx]; + if (pruning_place_cfg.find(layer_idx) != pruning_place_cfg.end()) { + kvcache_ctx.update_hidden_states(hidden_states, layer_idx, original_kv_length, chs_pos, is_prefill()); + chosen_pos_in_each[layer_idx + 1] = select_high_score_visual_token_decode(attn_weight, layer_idx, chunk_size); + } else { + if (layer_idx + 1 < num_hidden_layers) { + chosen_pos_in_each[layer_idx + 1] = chosen_pos_in_each[layer_idx]; + } + } + } + } + + Tensor prepare_next_layer(int layer_idx, Tensor &position_ids, Tensor &hidden_states, int kv_seq_len) { + if (Module::llm_model_ptr->doLoad) { return hidden_states; } + if (is_prefill()) { + if (pruning_place_cfg.find(layer_idx) != pruning_place_cfg.end()) { + auto this_layer_pos = chosen_pos_in_each[layer_idx]; + auto next_layer_pos = chosen_pos_in_each[layer_idx + 1]; + position_ids = global_position_ids.clip(next_layer_pos, DIMENSION); + std::vector mapping_this_2_next_pos; + for (size_t idx = 0; idx < this_layer_pos.size(); ++idx) { + int value = this_layer_pos[idx]; + if (std::find(next_layer_pos.begin(), next_layer_pos.end(), value) != next_layer_pos.end()) { + mapping_this_2_next_pos.push_back(idx); + } + } + assert(mapping_this_2_next_pos.size() == next_layer_pos.size()); + hidden_states = hidden_states.clip(mapping_this_2_next_pos, SEQUENCE); + } else { + if (layer_idx + 1 < num_hidden_layers) { + auto this_layer_pos = chosen_pos_in_each[layer_idx]; + auto next_layer_pos = chosen_pos_in_each[layer_idx + 1]; + assert(this_layer_pos.size() == next_layer_pos.size()); + assert(std::equal(this_layer_pos.begin(), this_layer_pos.end(), next_layer_pos.begin())); + } + } + } else { + if (pruning_place_cfg.find(layer_idx) != pruning_place_cfg.end()) { + auto this_layer_pos = chosen_pos_in_each[layer_idx]; + auto next_layer_pos = chosen_pos_in_each[layer_idx + 1]; + auto next_layer_kv_cache_not_filled_pos = kvcache_ctx.kv_not_filled_pos(layer_idx + 1, original_kv_length); + std::vector need_to_delay_compute_in_next_layer_pos; + for (int item : next_layer_pos) { + if (std::find(next_layer_kv_cache_not_filled_pos.begin(), + next_layer_kv_cache_not_filled_pos.end(), + item) + != next_layer_kv_cache_not_filled_pos.end()) { + need_to_delay_compute_in_next_layer_pos.push_back(item); + } + } + std::sort(need_to_delay_compute_in_next_layer_pos.begin(), + need_to_delay_compute_in_next_layer_pos.end()); + chosen_pos_to_delay_compute[layer_idx + 1] = need_to_delay_compute_in_next_layer_pos; + if (!need_to_delay_compute_in_next_layer_pos.empty()) { + position_ids = Tensor::cat( + {global_position_ids.clip(need_to_delay_compute_in_next_layer_pos, DIMENSION), + position_ids.clip({}, {}, {}, {-1})}, + DIMENSION); + hidden_states = kvcache_ctx.reset_hidden_states(hidden_states, layer_idx, need_to_delay_compute_in_next_layer_pos); + // mask + int seq = chosen_pos_to_delay_compute[layer_idx + 1].size(); + int dim = kv_seq_len + hidden_states.sequence(); + auto &delay_compute_vec = chosen_pos_to_delay_compute[layer_idx + 1]; + auto &in_each_vec = chosen_pos_in_each[layer_idx + 1]; + Tensor causal_mask(1, num_head, seq + 1, dim, MLLM_CPU, true); + causal_mask.setName("causal_mask_" + std::to_string(layer_idx + 1)); + float min_val = std::numeric_limits::lowest(); + for (int q_side_idx = 0; q_side_idx < seq; ++q_side_idx) { + // 获取当前查询位置对应的值 + int target_value = delay_compute_vec[q_side_idx]; + // 在in_each_vec中查找target_value的位置 + auto it = std::find(in_each_vec.begin(), in_each_vec.end(), target_value); + // 确保找到目标值 + if (it == in_each_vec.end()) { + // 处理未找到的情况 - 可选择报错或跳过 + std::cerr << "Error: target_value not found in chosen_pos_in_each" << std::endl; + continue; // 跳过当前迭代 + } + // 计算在向量中的索引位置 + int start_index = std::distance(in_each_vec.begin(), it) + 1; + // 设置从start_index到末尾的所有元素为min_val + for (int h = 0; h < num_head; h++) { + for (int j = 0; j < start_index; ++j) { + causal_mask.setDataAt(0, h, q_side_idx, j, 0); + } + for (int j = start_index; j < dim; ++j) { + causal_mask.setDataAt(0, h, q_side_idx, j, -INFINITY); + } + } + } + + for (int h = 0; h < num_head; h++) { + memset(causal_mask.ptrAt(0, h, causal_mask.sequence() - 1, 0), + 0, causal_mask.dimension() * sizeof(float)); + } + causal_masks[layer_idx + 1] = causal_mask; + } else { + hidden_states = hidden_states.clip({}, {}, {-1}, {}); + position_ids = position_ids.clip({}, {}, {}, {-1}); + } + } else { + if (layer_idx + 1 < num_hidden_layers) { + auto this_layer_pos = chosen_pos_in_each[layer_idx]; + auto next_layer_pos = chosen_pos_in_each[layer_idx + 1]; + chosen_pos_to_delay_compute[layer_idx + 1] = chosen_pos_to_delay_compute[layer_idx]; + } + } + } + return hidden_states; + } +}; +NdcContext WHERE_TOKEN_PRUNING; + +#endif // NDC_TOOLS_HPP \ No newline at end of file diff --git a/src/models/qwen2_vl/vtp/vtp_tools.hpp b/src/models/qwen2_vl/vtp/vtp_tools.hpp index 168a9cdf0..dde6c315b 100644 --- a/src/models/qwen2_vl/vtp/vtp_tools.hpp +++ b/src/models/qwen2_vl/vtp/vtp_tools.hpp @@ -23,14 +23,16 @@ using namespace mllm; class VtpContext { public: - void init() { + void init(Tensor input_ids, int num_hidden_layers) { + if (input_ids.sequence() <= 1) { + prefill_stage = false; + } else { + prefill_stage = true; + } if (global_selected.backend() == nullptr) global_selected = Tensor(1, 1, 1, 1, MLLM_CPU); } void set_vision_token(Tensor where_idx, Tensor hidden_states, Tensor image_embeds) { - // if (Module::llm_model_ptr->doLoad) { - // return; - // } no_visual_token_len = hidden_states.sequence() - image_embeds.sequence(); global_selected.reshape(1, 1, 1, hidden_states.sequence()); // pre_visual_token_len); global_selected.alloc(); @@ -46,9 +48,6 @@ class VtpContext { no_visual_token_len = hidden_states.sequence() - pre_visual_token_len; } bool is_prefill() { - // if (Module::llm_model_ptr->doLoad) { - // return false; - // } return prefill_stage; } void set_prefill_layer(int layer_idx_) { @@ -142,7 +141,7 @@ class VtpContext { } } } - Tensor prunning_attn_output(Tensor attn_output) { + Tensor prunning_attn_output(Tensor attn_output, int layer_idx) { if (layer_idx == 0) { return attn_output; } @@ -198,7 +197,15 @@ class VtpContext { int HEAD_TOP_K = 3; float ATTN_ACC_ALPHA = 0.2; - map pruning_setting = {{3, 0.5}}; //{{3, 0.5}}; + map pruning_setting = {{3, 0.5}, {8, 0.8}}; + // map pruning_setting = {{3, 0.2}, {9, 0.2}, {12, 0.4}, {18, 0.4}, {21, 0.8}, {26, 0.8}}; + // map pruning_setting = {{3, 0.2}, {9, 0.2}, {12, 0.2}, {18, 0.5}, {21, 0.5}, {26, 0.5}}; + // map pruning_setting = {{3, 0.8}, {9, 0.8}, {12, 0.8}, {18, 0.8}, {21, 0.8}, {26, 0.8}}; + // map pruning_setting = {{3, 0.5}}; + // map pruning_setting = {{3, 0.8}}; + + // 3, 9, 12, 18, 21, 26 + // 0.2, 0.2, 0.4, 0.4, 0.8, 0.8 private: // 实现 topk 功能 diff --git a/src/models/qwen3/modeling_qwen3.hpp b/src/models/qwen3/modeling_qwen3.hpp index 83e304518..9da5c49e9 100644 --- a/src/models/qwen3/modeling_qwen3.hpp +++ b/src/models/qwen3/modeling_qwen3.hpp @@ -60,6 +60,7 @@ class QWen3Attention final : public Module { num_key_value_heads = config.num_key_value_heads; num_key_value_groups = num_heads / num_key_value_heads; rms_norm_eps = config.rms_norm_eps; + attn_impl = config.attn_implementation; // init layers q_proj = Linear(hidden_size, num_heads * head_dim, config.attention_bias, base_name + names._q_proj_name); k_proj = Linear(hidden_size, num_key_value_heads * head_dim, config.attention_bias, @@ -77,8 +78,8 @@ class QWen3Attention final : public Module { base_name + "q_rope"); k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); - k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache"); + k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (attn_impl == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, (attn_impl == "flash_attention_2"), base_name + "v_cache"); softmax = Softmax(DIMENSION, true, base_name + "softmax"); } @@ -104,14 +105,19 @@ class QWen3Attention final : public Module { key_states = k_cache(key_states); value_states = v_cache(value_states); - // attention weight - auto atten_weight = - Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) - / std::sqrt(head_dim); - atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); + Tensor atten_output; + if (attn_impl == "flash_attention_2") { + atten_output = Tensor::flash_attention2_forward(query_states, key_states, value_states, true); + } else { // eager implementation + // attention weight + auto atten_weight = + Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) + / std::sqrt(head_dim); + atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); - // attention output - auto atten_output = Tensor::mm(atten_weight, value_states); + // attention output + atten_output = Tensor::mm(atten_weight, value_states); + } atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); atten_output = o_proj(atten_output); return {atten_output}; @@ -143,6 +149,7 @@ class QWen3Attention final : public Module { KVCache v_cache; // Causalmask mask; Softmax softmax; + string attn_impl; }; class QWen3Decoder final : public Module { diff --git a/src/models/smollm/modeling_smollm.hpp b/src/models/smollm/modeling_smollm.hpp index 86d8d6ca8..c09621a42 100644 --- a/src/models/smollm/modeling_smollm.hpp +++ b/src/models/smollm/modeling_smollm.hpp @@ -50,9 +50,8 @@ class SmolLMBlock final : public Module { public: SmolLMBlock() = default; - SmolLMBlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const SmolLMNameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); + SmolLMBlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const SmolLMNameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, attn_implementation, names, base_name + names._attn_base_name); mlp = SmolLMMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -81,13 +80,13 @@ class SmolLMModel final : public Module { public: explicit SmolLMModel(const SmolLMConfig &config) : SmolLMModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.attn_implementation, config.names_config, config.names_config.blk_name) { } - SmolLMModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + SmolLMModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const SmolLMNameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Parameter(1, vocab_size, 1, hidden_dim, names.token_embd_name + ".weight"); diff --git a/src/models/stablelm/modeling_stablelm.hpp b/src/models/stablelm/modeling_stablelm.hpp index 8033689e8..6a0acc582 100644 --- a/src/models/stablelm/modeling_stablelm.hpp +++ b/src/models/stablelm/modeling_stablelm.hpp @@ -22,15 +22,17 @@ class StableLMMultiHeadAttention final : public Module { int kv_head_size_{}; int attn_hidden_dim_{}; Chl split_chl_{}; + string attn_impl; public: StableLMMultiHeadAttention() = default; StableLMMultiHeadAttention(int hidden_dim, int head_size, int kv_head_size, int attn_hidden_dim, - RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias, + RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias, string attn_implementation, const TransformerNameConfig &names, const string &base_name) { attn_hidden_dim_ = attn_hidden_dim; head_size_ = head_size; kv_head_size_ = kv_head_size; + attn_impl = attn_implementation; q_proj = Linear(hidden_dim, head_size * attn_hidden_dim, bias, base_name + names._q_proj_name); k_proj = Linear(hidden_dim, kv_head_size * attn_hidden_dim, bias, base_name + names._k_proj_name); v_proj = Linear(hidden_dim, kv_head_size * attn_hidden_dim, bias, base_name + names._v_proj_name); @@ -39,8 +41,8 @@ class StableLMMultiHeadAttention final : public Module { k_rope = RoPE(RoPE_type, 10000, 0.25, 4096, base_name + "k_rope"); } if (cache_limit > 0) { - k_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, base_name + "k_cache"); - v_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, base_name + "v_cache"); + k_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, (attn_impl == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, (attn_impl == "flash_attention_2"), base_name + "v_cache"); } softmax = Softmax(DIMENSION, do_mask, base_name + "softmax"); o_proj = Linear(head_size * attn_hidden_dim, hidden_dim, false, base_name + names._o_proj_name); @@ -61,11 +63,17 @@ class StableLMMultiHeadAttention final : public Module { k = k_cache(k); v = v_cache(v); } - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk / std::sqrt(attn_hidden_dim_); - qk = softmax(qk, k_cache.getCacheSeqLen()); - auto o = Tensor::mm(qk, v); + + Tensor o; + if (attn_impl == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, true); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(attn_hidden_dim_); + qk = softmax(qk, k_cache.getCacheSeqLen()); + o = Tensor::mm(qk, v); + } o = o.view(-1, 1, -1, attn_hidden_dim_ * head_size_); o = o_proj(o); return {o}; @@ -104,9 +112,10 @@ class StableLMBlock final : public Module { public: StableLMBlock() = default; - StableLMBlock(int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, const stablelmNameConfig &names, const string &base_name) { + StableLMBlock(int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, string attn_implementation, const stablelmNameConfig &names, const string &base_name) { attention = StableLMMultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, - RoPE_type, cache_limit, true, true, names, base_name + names._attn_base_name); + RoPE_type, cache_limit, true, true, attn_implementation, + names, base_name + names._attn_base_name); mlp = StableLMMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-5, base_name + names._attn_norm_name); norm2 = LayerNorm(hidden_dim, true, 1e-5, base_name + names._ffn_norm_name); @@ -130,13 +139,11 @@ class StableLMModel final : public Module { public: explicit StableLMModel(const StableLMConfig &config) : - StableLMModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, - config.names_config, config.names_config.blk_name) { + StableLMModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, config.attn_implementation, config.names_config, config.names_config.blk_name) { } - StableLMModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, - const stablelmNameConfig &names, const string &base_name) { + StableLMModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, string attn_implementation, const stablelmNameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, attn_implementation, names, base_name); norm = LayerNorm(hidden_dim, true, 1e-5, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/tinyllama/modeling_tinyllama.hpp b/src/models/tinyllama/modeling_tinyllama.hpp index 3d54d6cf4..69facdf51 100644 --- a/src/models/tinyllama/modeling_tinyllama.hpp +++ b/src/models/tinyllama/modeling_tinyllama.hpp @@ -20,9 +20,12 @@ class TinyLLaMABlock final : public Module { public: TinyLLaMABlock() = default; - TinyLLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); + TinyLLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const LLaMANameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, + hidden_dim / head_size, SPLIT_NONE, false, false, + RoPE_type, rope_theta, max_position_embeddings, + cache_limit, true, false, + attn_implementation, names, base_name + names._attn_base_name); mlp = LLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -46,15 +49,15 @@ class TinyLLaMAModel final : public Module { public: explicit TinyLLaMAModel(const TinyLLaMAConfig &config) : - TinyLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.kv_head_size, config.ffn_hidden, config.block_num, - config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, + TinyLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.kv_head_size, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.attn_implementation, config.names_config, config.names_config.blk_name) { } - TinyLLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, - RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, + TinyLLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, string attn_implementation, const LLaMANameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, attn_implementation, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/transformer/configuration_transformer.hpp b/src/models/transformer/configuration_transformer.hpp index 042e0f97f..c425bb768 100644 --- a/src/models/transformer/configuration_transformer.hpp +++ b/src/models/transformer/configuration_transformer.hpp @@ -37,5 +37,6 @@ class TransformerConfig { public: TransformerConfig() { } + string attn_implementation = "flash_attention_2"; // Options: "flash_attention_2", "eager" }; #endif // CONFIGURATION_TRANSFORMER_HPP diff --git a/src/models/transformer/modeling_transformer.hpp b/src/models/transformer/modeling_transformer.hpp index dbc601ce2..b243f46d5 100644 --- a/src/models/transformer/modeling_transformer.hpp +++ b/src/models/transformer/modeling_transformer.hpp @@ -31,6 +31,8 @@ class MultiHeadAttention final : public Module { int kv_head_size_{}; int attn_hidden_dim_{}; Chl split_chl_{}; + bool causal_mask = true; + string attn_implementation_ = "flash_attention_2"; // Options: "flash_attention_2", "eager" public: MultiHeadAttention() = default; @@ -38,10 +40,13 @@ class MultiHeadAttention final : public Module { AttnQKVSplitType do_qkv_proj, bool post_qkv_norm, bool bias_kv_cat, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, bool do_mask, bool bias, + string attn_implementation, const TransformerNameConfig &names, const string &base_name) { attn_hidden_dim_ = attn_hidden_dim; head_size_ = head_size; kv_head_size_ = kv_head_size; + causal_mask = do_mask; + attn_implementation_ = attn_implementation; if (do_qkv_proj > 0) { qkv_proj = Linear(hidden_dim, head_size * attn_hidden_dim * 3, bias, base_name + names._qkv_proj_name); split_chl_ = (Chl)do_qkv_proj; @@ -59,8 +64,12 @@ class MultiHeadAttention final : public Module { k_rope = RoPE(RoPE_type, rope_theta, max_position_embeddings, base_name + "k_rope"); } if (cache_limit > 0) { - k_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, base_name + "k_cache"); - v_cache = KVCache(kv_head_size, attn_hidden_dim, head_size / kv_head_size, cache_limit, base_name + "v_cache"); + k_cache = KVCache(kv_head_size, attn_hidden_dim, + head_size / kv_head_size, cache_limit, + (attn_implementation_ == "flash_attention_2"), base_name + "k_cache"); + v_cache = KVCache(kv_head_size, attn_hidden_dim, + head_size / kv_head_size, cache_limit, + (attn_implementation_ == "flash_attention_2"), base_name + "v_cache"); } softmax = Softmax(DIMENSION, do_mask, base_name + "softmax"); o_proj = Linear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name); @@ -101,15 +110,20 @@ class MultiHeadAttention final : public Module { k = k_cache(k); v = v_cache(v); } - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk / std::sqrt(attn_hidden_dim_); - if (k_cache.ready() && v_cache.ready()) { - qk = softmax(qk, k_cache.getCacheSeqLen()); - } else { - qk = softmax(qk); + Tensor o; + if (attn_implementation_ == "flash_attention_2") { + o = Tensor::flash_attention2_forward(q, k, v, causal_mask); + } else { // eager implementation + k = k.transpose(SEQUENCE, DIMENSION); + auto qk = Tensor::mm(q, k); + qk = qk / std::sqrt(attn_hidden_dim_); + if (k_cache.ready() && v_cache.ready()) { + qk = softmax(qk, k_cache.getCacheSeqLen()); + } else { + qk = softmax(qk); + } + o = Tensor::mm(qk, v); } - auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, attn_hidden_dim_ * head_size_); o = o_proj(o); return {o}; diff --git a/src/models/vit/modeling_vit.hpp b/src/models/vit/modeling_vit.hpp index a2a97915d..a26634132 100644 --- a/src/models/vit/modeling_vit.hpp +++ b/src/models/vit/modeling_vit.hpp @@ -37,10 +37,11 @@ class ViTBlock final : public Module { public: ViTBlock() = default; - ViTBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, const ViTNameConfig &names, const string &base_name) { + ViTBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, + string attn_implementation, const ViTNameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, RoPEType::NONE, - -1, -1, 0, false, true, + -1, -1, 0, false, true, attn_implementation, names, base_name + names._attn_base_name); mlp = ViTMLP(hidden_dim, ffn_hidden, act_fn_type, names, base_name + names._ffn_base_name); down_proj = Linear(ffn_hidden, hidden_dim, true, base_name + names._down_proj_name); @@ -89,13 +90,11 @@ class ViTModel final : public Module { public: explicit ViTModel(const ViTConfig &config) : - ViTModel(config.hidden_dim, config.head_size, config.ffn_hidden, config.act_fn_type, config.patch, config.img_hw, config.block_num, config.class_size, - config.names_config, config.names_config.vison_model_name) { + ViTModel(config.hidden_dim, config.head_size, config.ffn_hidden, config.act_fn_type, config.patch, config.img_hw, config.block_num, config.class_size, config.attn_implementation, config.names_config, config.names_config.vison_model_name) { } - ViTModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, int patch, int img_hw, int block_num, int class_size, - const ViTNameConfig &names, const string &base_name) { + ViTModel(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, int patch, int img_hw, int block_num, int class_size, string attn_implementation, const ViTNameConfig &names, const string &base_name) { embedding = ViTEmbedding(hidden_dim, patch, img_hw, names, base_name + names._embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, names, base_name + names._layer_name); + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, act_fn_type, attn_implementation, names, base_name + names._layer_name); norm = LayerNorm(hidden_dim, true, 1e-6, base_name + names._post_norm_name); lm_head = Linear(hidden_dim, class_size, false, names.lm_head_name); } diff --git a/tools/convertor/profiling_activation/export_int8_model.py b/tools/convertor/profiling_activation/export_int8_model.py index 825cc69aa..b5c51065d 100644 --- a/tools/convertor/profiling_activation/export_int8_model.py +++ b/tools/convertor/profiling_activation/export_int8_model.py @@ -58,7 +58,14 @@ def quantize_weight_per_tensor_absmax(w, n_bits=8): act_dict = args.scale_file.name t01m_clip_threshold = args.t01m_clip_threshold - model = AutoModelForCausalLM.from_pretrained(model_name) + if args.model_type != "qwen2vl": + model = AutoModelForCausalLM.from_pretrained(model_name) + else: + from transformers import Qwen2VLForConditionalGeneration + model = Qwen2VLForConditionalGeneration.from_pretrained( + model_name + ) + act_dict = json.load(open(act_dict)) act_scales, clip_top, return_dict = get_clip_and_scale( @@ -77,6 +84,8 @@ def quantize_weight_per_tensor_absmax(w, n_bits=8): q_model = quantize_llama_like(model, act_scales, layer_clip=clip_top) elif args.model_type == "qwen2" or args.model_type == "qwen1": q_model = quantize_qwen2_like(model, act_scales, layer_clip=clip_top) + elif args.model_type == "qwen2vl": + q_model = quantize_qwen2vl_like(model, act_scales, layer_clip=clip_top) elif args.model_type == "gemma": q_model = quantize_gemma_like(model, act_scales, layer_clip=clip_top) elif args.model_type == "phi": diff --git a/tools/convertor/profiling_activation/utils/quantization_simulation.py b/tools/convertor/profiling_activation/utils/quantization_simulation.py index 716d2c7f3..eeb5367f0 100644 --- a/tools/convertor/profiling_activation/utils/quantization_simulation.py +++ b/tools/convertor/profiling_activation/utils/quantization_simulation.py @@ -374,6 +374,68 @@ def quantize_qwen2_like( ) return model +def quantize_qwen2vl_like( + model, + decoder_scales, + weight_quant="per_tensor", + act_quant="per_tensor", + quantize_bmm_input=False, + layer_clip={}, +): + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLSdpaAttention, + Qwen2MLP, + ) + + for name, m in model.model.named_modules(): + if isinstance(m, Qwen2MLP): + m.gate_proj = W8A8LinearStatic.from_float( + m.gate_proj, + decoder_scales["model." + name + ".gate_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".gate_proj"], + ) + m.up_proj = W8A8LinearStatic.from_float( + m.up_proj, + decoder_scales["model." + name + ".up_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".up_proj"], + ) + m.down_proj = W8A8LinearStatic.from_float( + m.down_proj, + decoder_scales["model." + name + ".down_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".down_proj"], + ) + elif isinstance(m, Qwen2VLSdpaAttention): + # Here we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj + m.q_proj = W8A8LinearStatic.from_float( + m.q_proj, + decoder_scales["model." + name + ".q_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".q_proj"], + ) + m.k_proj = W8A8LinearStatic.from_float( + m.k_proj, + decoder_scales["model." + name + ".k_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".k_proj"], + ) + m.v_proj = W8A8LinearStatic.from_float( + m.v_proj, + decoder_scales["model." + name + ".v_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".v_proj"], + ) + + m.o_proj = W8A8LinearStatic.from_float( + m.o_proj, + decoder_scales["model." + name + ".o_proj"], + weight_quant_type=weight_quant, + clip_top=layer_clip["model." + name + ".o_proj"], + ) + return model + def quantize_gemma_like( model, diff --git a/tools/jni/LibHelper.cpp b/tools/jni/LibHelper.cpp index 4c5819360..cfa08047a 100644 --- a/tools/jni/LibHelper.cpp +++ b/tools/jni/LibHelper.cpp @@ -32,7 +32,7 @@ using namespace mllm; #ifdef USE_QNN #include "models/qwen/modeling_qwen_npu.hpp" #include "models/phonelm/modeling_phonelm_npu.hpp" - +#include "models/qwen2_vl/modeling_qwen2_vl_npu.hpp" #endif inline bool exists_test(const std::string &name) { std::ifstream f(name.c_str()); @@ -55,7 +55,18 @@ unsigned int LibHelper::postProcessing(shared_ptr result, shared_ptr(qwconfig, chunk_size); + prefill_module_ = make_shared(qwconfig, chunk_size); prefill_module_->load(qnn_weights_path); auto tokenizer = dynamic_pointer_cast(tokenizer_); @@ -120,7 +131,23 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st break; case QWEN2VL: processor_ = new Qwen2VLProcessor(vocab_path, merge_path); - module_ = make_shared(qwvlconfig); + LOGI("Init Qwen2VLProcessor: %d", backend_type); +#ifdef USE_QNN + if (backend_type == MLLMBackendType::QNN) { + int chunk_size = 256; + prefill_module_ = make_shared(qwvlconfig, chunk_size); + prefill_module_->load(qnn_weights_path); + prefill_embedding_ = make_shared(qwvlconfig); + prefill_embedding_->load(weights_path); + qwvlconfig.attn_implementation = "eager"; + module_ = make_shared(qwvlconfig); + } else { +#endif + module_ = make_shared(qwvlconfig); + +#ifdef USE_QNN + } +#endif break; case Bert: tokenizer_ = make_shared(vocab_path, true); @@ -318,26 +345,148 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u } module_->clear_kvcache(); } else if (model_ == QWEN2VL) { - auto model = dynamic_cast(module_.get()); auto processor = dynamic_cast(processor_); input_str = "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element, which is a relative coordinate on the screenshot, scaled from 0 to 1.<|vision_start|><|image_pad|><|vision_end|>" + input_str; input_str = processor->tokenizer->apply_chat_template(input_str); auto input_tensors = processor->process(input_str, {image}, {image_length}); LOGE("Instruct: %s", input_str.c_str()); LOGE("Tokens: %d", input_tensors[0].sequence()); - for (int step = 0; step < 100; step++) { - model->get_position_ids(input_tensors); - auto result = (*model)(input_tensors); - auto outputs = processor->detokenize(result[0]); - auto out_string = outputs.first; - auto out_token = outputs.second; - auto [end, string] = processor->tokenizer->postprocess(out_string); - output_string_ += string; - callback_(output_string_, !end, {}); - if (!end) { break; } + +#ifdef USE_QNN + if (backend_ == MLLMBackendType::QNN) { + int chunk_size = 256; + + const int real_seq_length = input_tensors[0].sequence(); + const int num_iter = (real_seq_length + chunk_size - 1) / chunk_size; + auto model = dynamic_cast(module_.get()); + auto prefill_embedding = dynamic_cast(prefill_embedding_.get()); + // padding the position_ids to total chunk length(example: 256*2) for CPUMultimodalRoPEPipeline + LOGE("before get_position_ids"); + prefill_embedding->get_position_ids(input_tensors, chunk_size * num_iter); + LOGE("after get_position_ids"); + + // warm up (still need a warm up as the setup stage is not omitted now) + auto merged_embd_warmup_tensor = Tensor(Backend::global_backends[MLLM_QNN]); + merged_embd_warmup_tensor.reshape(1, 1, chunk_size, 1536); + merged_embd_warmup_tensor.setTtype(INPUT_TENSOR); + merged_embd_warmup_tensor.alloc(); + merged_embd_warmup_tensor.setTtype(INPUT_TENSOR); + input_tensors.back().setTtype(INPUT_TENSOR); + vector prefill_input = {merged_embd_warmup_tensor, input_tensors.back()}; + (*prefill_module_)(prefill_input); + LOGE("after warm up"); + + Module::isFirstChunk = false; + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + + // set total seq length for HeadLinear execute, which can not get the real seq length from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length); + // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts + static_cast(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size); + + for (auto &t : input_tensors) { + t.setTtype(INPUT_TENSOR); + } + + // 1. get the vit embedding using CPU + auto merged_embd = (*prefill_embedding)(input_tensors); + LOGE("after vit embedding"); + + // free prefill embedding tensor, approximately free 1GB for 59ms + auto begin_free = mllm_time_ms(); + auto &embedding_act = prefill_embedding->activation_tensors; + // go through the activation tensors to get the merged_embd + for (auto iter = embedding_act.begin(); iter != embedding_act.end(); ++iter) { + // std::cout << iter->first << std::endl; + if (iter->first.find("input") != std::string::npos || iter->first.find("index_put") != std::string::npos) { + continue; + } + iter->second->free(); + } + auto end_free = mllm_time_ms(); + LOGE("after free"); + + // 2. QNN LLM Prefill + unsigned int out_token = 0; + for (auto i = 0; i < num_iter; ++i) { + // copy the data from merged_embd[0] to merged_embd_warmup_tensor + auto source = merged_embd[0].ptrAt(0, 0, chunk_size * i, 0); + auto dest = prefill_input[0].hostPtr(); + if (i == 0) { + memcpy(dest, source, prefill_input[0].cntSize()); + } + { + memcpy(dest, source, (merged_embd[0].sequence() % chunk_size) * merged_embd[0].dimension() * sizeof(float)); + } + + auto result = (*prefill_module_)(prefill_input); + + if (i == 0) { // turn off switching to avoid RoPE h_cnt_ reset to curSequenceLength in next chunk + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + } + + if (i == 1) { + auto outputs = processor->detokenize(result[0], real_seq_length % chunk_size); + auto out_string = outputs.first; + out_token = outputs.second; + // auto [not_end, output_string] = processor->tokenizer->postprocess(out_string); + // std::cout << output_string << std::flush; + auto [end, string] = processor->tokenizer->postprocess(out_string); + output_string_ += string; + callback_(output_string_, !end, {}); + } + } + chatPostProcessing(out_token, input_tensors[0], {&input_tensors[1], &input_tensors[2]}); + + static_cast(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length); + static_cast(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE); + static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + + // 3. CPU LLM Decoding + for (auto &t : input_tensors) { // set to INPUT_TENSOR to let decoding module update act + t.setTtype(INPUT_TENSOR); + } + + const int last_position_id = input_tensors[3].dataAt(0, 0, 0, real_seq_length - 1); + for (int step = 0; step < 100; step++) { + // use the last position id(no padding position) in decoding + prefill_embedding->get_position_ids(input_tensors, 0, last_position_id + 1 + step); + + auto result = (*model)(input_tensors); + auto outputs = processor->detokenize(result[0]); + auto out_string = outputs.first; + auto out_token = outputs.second; + auto [end, string] = processor->tokenizer->postprocess(out_string); + output_string_ += string; + callback_(output_string_, !end, {}); + if (!end) { break; } + chatPostProcessing(out_token, input_tensors[0], {&input_tensors[1], &input_tensors[2]}); + if (step == 0) static_cast(Backend::global_backends[MLLM_CPU])->toggleSwitching(); + } + + std::cout << std::endl; + } else { +#endif + auto model = dynamic_cast(module_.get()); + for (int step = 0; step < 100; step++) { + model->get_position_ids(input_tensors); + auto result = (*model)(input_tensors); + auto outputs = processor->detokenize(result[0]); + auto out_string = outputs.first; + auto out_token = outputs.second; + auto [end, string] = processor->tokenizer->postprocess(out_string); + output_string_ += string; + callback_(output_string_, !end, {}); + if (!end) { break; } + chatPostProcessing(out_token, input_tensors[0], {&input_tensors[1], &input_tensors[2]}); + } + module_->clear_kvcache(); +#ifdef USE_QNN } - module_->clear_kvcache(); +#endif } else if (model_ == Bert) { LOGE("Bert model is not supported in this version."); } else if (model_ == PhoneLM) { diff --git a/tools/jni/LibHelper.hpp b/tools/jni/LibHelper.hpp index 98791ae3d..97431528a 100644 --- a/tools/jni/LibHelper.hpp +++ b/tools/jni/LibHelper.hpp @@ -47,6 +47,7 @@ class LibHelper { PreProcessor *processor_; std::shared_ptr module_; std::shared_ptr prefill_module_; + std::shared_ptr prefill_embedding_; // Tokenizer *tokenizer_ = nullptr; unsigned int eos_id_ = 2;