diff --git a/.gitignore b/.gitignore index 22e2a9a6f..7397d6ecc 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ tasks/mllmteam* build*/ install*/ mllm-sdk-*/ +mllm-install-*/ # Pymllm related stubs/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 298b412c0..a19e80df3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -262,6 +262,7 @@ add_subdirectory(third_party/fmt) add_subdirectory(third_party/xxHash) set(FLATBUFFERS_BUILD_TESTS OFF) add_subdirectory(third_party/flatbuffers EXCLUDE_FROM_ALL) +set_target_properties(flatbuffers PROPERTIES POSITION_INDEPENDENT_CODE ON) add_subdirectory(mllm) if(MLLM_ENABLE_TEST) @@ -332,6 +333,13 @@ install( ARCHIVE DESTINATION lib RUNTIME DESTINATION bin) +install( + TARGETS flatbuffers + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) + if(MLLM_BUILD_SDK_C_BINDING) install( TARGETS MllmSdkC diff --git "a/Icon\r" "b/Icon\r" new file mode 100644 index 000000000..e69de29bb diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 180c3cbe6..a2426f229 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(qwen2vl) add_subdirectory(qwen2vl_tracer) add_subdirectory(qwen2_5vl) add_subdirectory(qwen2_5vl_tracer) +add_subdirectory(qwen2_5omni) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(minicpm4) diff --git a/examples/qwen2_5omni/CMakeLists.txt b/examples/qwen2_5omni/CMakeLists.txt new file mode 100644 index 000000000..479c3a635 --- /dev/null +++ b/examples/qwen2_5omni/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(mllm-qwen2_5-omni-text-runner text_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-text-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-text-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner image_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-runner audio_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/audio_infer.cpp b/examples/qwen2_5omni/audio_infer.cpp new file mode 100644 index 000000000..d159c2b3e --- /dev/null +++ b/examples/qwen2_5omni/audio_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using mllm::Argparse;
+
+MLLM_MAIN({
+  mllm::Logger::level() = mllm::LogLevel::kError;
+
+  auto& help = Argparse::add("-h|--help").help("Show help message");
+  auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true);
+  auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true);
+  auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true);
+  auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true);
+
+  Argparse::parse(argc, argv);
+
+  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
+  if (model_version.get() == "v1") {
+    file_version = mllm::ModelFileVersion::kV1;
+  } else if (model_version.get() == "v2") {
+    file_version = mllm::ModelFileVersion::kV2;
+  }
+
+  if (help.isSet()) {
+    Argparse::printHelp();
+    mllm::shutdownContext();
+    return 0;
+  }
+
+  {
+    auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get());
+    auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get());
+    auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg);
+
+    auto param = mllm::load(model_path.get(), file_version);
+    qwen2_5omni.thinker_.load(param);
+
+    fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Audio CLI ");
+    fmt::print("Enter 'exit' or 'quit' to end the session\n\n");
+
+    std::string audio_path;
+    std::string prompt_text;
+
+    fmt::print("Audio path (or 'exit/quit'): ");
+    std::getline(std::cin, audio_path);
+    if (audio_path == "exit" || audio_path == "quit") { return 0; }
+
+    fmt::print("Prompt text: ");
+    std::getline(std::cin, prompt_text);
+    if (prompt_text.empty()) { prompt_text = "Please describe the audio."; }
+
+    try {
+      fmt::print("Processing...\n");
+      auto inputs = qwen2_5omni_tokenizer.convertAudioMessage({.prompt = prompt_text, .audio_file_path = audio_path});
+
+      fmt::print("\nResponse: ");
+      qwen2_5omni.streamGenerate(inputs,
+                                 {
+                                     {"do_sample", mllm::AnyValue(false)},
+                                     {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)},
+                                 },
+                                 [&](int64_t token_id) {
+                                   auto str = qwen2_5omni_tokenizer.detokenize(token_id);
+                                   std::wcout << str << std::flush;
+                                 });
+
+      fmt::print("\n{}\n", std::string(60, '-'));
+    } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); }
+
+    qwen2_5omni.perfSummary();
+  }
+
+  mllm::print("\n");
+  mllm::memoryReport();
+})
diff --git a/examples/qwen2_5omni/config_qwen2_5omni_7B.json b/examples/qwen2_5omni/config_qwen2_5omni_7B.json
new file mode 100644
index 000000000..8f27b94b9
--- /dev/null
+++ b/examples/qwen2_5omni/config_qwen2_5omni_7B.json
@@ -0,0 +1,495 @@
+{
+  "architectures": [
+    "Qwen2_5OmniModel"
+  ],
+  "enable_audio_output": true,
+  "enable_talker": true,
+  "model_type": "qwen2_5_omni",
+  "talker_config": {
+    "_attn_implementation_autoset": true,
+    "_name_or_path": "Qwen2.5-Omni-7B/talker",
+    "architectures": [
+      "Qwen2OmniTalkerForConditionalGeneration"
+    ],
+    "attention_dropout": 0.0,
+    "audio_end_token_id": 151648,
+    "audio_start_token_id": 151647,
+    "audio_token_index": 151646,
+    "embedding_size": 3584,
+    "head_dim": 128,
+    "hidden_act": "silu",
+    "hidden_size": 896,
+    "image_token_index": 151655,
+    "init_std": 0.02,
+    "initializer_range": 0.02,
+    "intermediate_size": 18944,
+    "max_position_embeddings": 32768,
+    "max_window_layers": 28,
+
"model_type": "qwen2_5_omni_talker", + "num_attention_heads": 12, + "num_hidden_layers": 24, + "num_key_value_heads": 4, + "position_id_per_seconds": 25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "seconds_per_chunk": 2, + "sliding_window": 32768, + "spatial_merge_size": 2, + "torch_dtype": "bfloat16", + "tts_codec_end_token_id": 8294, + "tts_codec_mask_token_id": 8296, + "tts_codec_pad_token_id": 8292, + "tts_codec_start_token_id": 8293, + "tts_text_end_token_id": 151861, + "tts_text_pad_token_id": 151859, + "tts_text_start_token_id": 151860, + "use_cache": true, + "use_sliding_window": false, + "video_token_index": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vocab_size": 8448 + }, + "thinker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/thinker", + "architectures": [ + "Qwen2OmniNaViTThinkerForConditionalGeneration" + ], + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "activation_dropout": 0.0, + "activation_function": "gelu", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "d_model": 1280, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "init_std": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen2_5_omni_audio_encoder", + "n_window": 100, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": false, + "output_dim": 3584, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "scale_embedding": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "text_config": { + "model_type": "qwen2_5_omni_text", + "hidden_act": "silu", + "hidden_size": 3584, + "init_std": 0.02, + "intermediate_size": 18944, + "vocab_size": 152064, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "use_cache": true, + 
"rope_theta": 1000000.0, + "use_sliding_window": false, + "sliding_window": 32768, + "attention_dropout": 0.0, + "tie_word_embeddings": false + }, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "bos_token_id": 151644, + "eos_token_id": 151645, + "ignore_index": -100, + "image_token_index": 151655, + "init_std": 0.02, + "model_type": "qwen2_5_omni_thinker", + "pad_token_id": 151643, + "position_id_per_seconds": 25, + "seconds_per_chunk": 2, + "torch_dtype": "bfloat16", + "user_token_id": 872, + "video_token_index": 151656, + "vision_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 32, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "embed_dim": 1280, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "in_channels": 3, + "in_chans": 3, + "init_std": 0.02, + "intermediate_size": 3420, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen2_5_omni_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 16, + "num_return_sequences": 1, + "out_hidden_size": 3584, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "tokens_per_second": 25, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654 + }, + "token2wav_config": { + "_attn_implementation_autoset": true, + "bigvgan_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + 
"length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_bigvgan", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 11, + 7, + 4, + 4, + 4, + 4 + ], + "upsample_rates": [ + 5, + 3, + 2, + 2, + 2, + 2 + ], + "use_bfloat16": false, + "use_bias_at_final": false + }, + "dit_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 22, + "dim": 1024, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.1, + "early_stopping": false, + "emb_dim": 512, + "enc_attention_channels": 64, + "enc_channels": [ + 256, + 256, + 256, + 256, + 768 + ], + "enc_dilations": [ + 1, + 2, + 3, + 4, + 1 + ], + "enc_dim": 128, + "enc_emb_dim": 192, + "enc_global_context": true, + "enc_kernel_sizes": [ + 5, + 3, + 3, + 3, + 1 + ], + "enc_lin_neurons": 192, + "enc_res2net_scale": 2, + "enc_se_channels": 64, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "ff_mult": 2, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "head_dim": 64, + "heads": 16, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_dit", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_embeds": 8193, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repeats": 2, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "model_type": "qwen2_5_omni_token2wav" + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0" +} diff --git a/examples/qwen2_5omni/image_infer.cpp b/examples/qwen2_5omni/image_infer.cpp new file mode 
100644
index 000000000..3c0bf214b
--- /dev/null
+++ b/examples/qwen2_5omni/image_infer.cpp
@@ -0,0 +1,84 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using mllm::Argparse;
+
+MLLM_MAIN({
+  mllm::Logger::level() = mllm::LogLevel::kError;
+
+  auto& help = Argparse::add("-h|--help").help("Show help message");
+  auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true);
+  auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true);
+  auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true);
+  auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true);
+
+  Argparse::parse(argc, argv);
+
+  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
+  if (model_version.get() == "v1") {
+    file_version = mllm::ModelFileVersion::kV1;
+  } else if (model_version.get() == "v2") {
+    file_version = mllm::ModelFileVersion::kV2;
+  }
+
+  if (help.isSet()) {
+    Argparse::printHelp();
+    mllm::shutdownContext();
+    return 0;
+  }
+
+  {
+    auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get());
+    auto qwen2_5omni_tokenizer =
+        mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size);
+    auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg);
+
+    auto param = mllm::load(model_path.get(), file_version);
+    qwen2_5omni.thinker_.load(param);
+
+    fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI ");
+    fmt::print("Enter 'exit' or 'quit' to end the session\n\n");
+
+    std::string image_path;
+    std::string prompt_text;
+
+    fmt::print("Image path (or 'exit/quit'): ");
+    std::getline(std::cin, image_path);
+    if (image_path == "exit" || image_path == "quit") { return 0; }
+
+    fmt::print("Prompt text: ");
+    std::getline(std::cin, prompt_text);
+
+    try {
+      fmt::print("Processing...\n");
+      auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path});
+
+      fmt::print("\nResponse: ");
+      qwen2_5omni.streamGenerate(inputs,
+                                 {
+                                     {"do_sample", mllm::AnyValue(false)},
+                                     {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)},
+                                 },
+                                 [&](int64_t token_id) {
+                                   auto str = qwen2_5omni_tokenizer.detokenize(token_id);
+                                   std::wcout << str << std::flush;
+                                 });
+
+      fmt::print("\n{}\n", std::string(60, '-'));
+    } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); }
+
+    qwen2_5omni.perfSummary();
+  }
+
+  mllm::print("\n");
+  mllm::memoryReport();
+})
diff --git a/examples/qwen2_5omni/text_infer.cpp b/examples/qwen2_5omni/text_infer.cpp
new file mode 100644
index 000000000..299a0e07d
--- /dev/null
+++ b/examples/qwen2_5omni/text_infer.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Text CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("šŸ’¬ Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + if (prompt_text == "exit" || prompt_text == "quit") { return 0; } + + try { + fmt::print("šŸ”„ Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\nšŸ¤– Response: "); + for (auto& step : qwen2_5omni.chat(inputs)) { + std::wcout << qwen2_5omni_tokenizer.detokenize(step.cur_token_id) << std::flush; + } + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nāŒ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index 9eed37267..a2d054bad 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -15,14 +15,6 @@ namespace mllm::models::qwen3 { -Tensor rotateHalf(Tensor x) { // NOLINT - // X is [x, x, x, D] - auto D = x.size(-1); - auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); - auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); - return nn::functional::concat({-x2, x1}, -1); -} - namespace ptq { Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { @@ -112,6 +104,14 @@ Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch } // namespace ptq +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + using vi32 = std::vector; #define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 @@ 
-232,14 +232,16 @@ class Qwen3Attention final : public nn::Module { // [B, H, S, D] auto cos = llm_embedding_cos.unsqueeze(1); auto sin = llm_embedding_sin.unsqueeze(1); - query_states = ptq::QDQ(this, - ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(query_states) * sin, "q_rope_mul_1_output_qdq"), - "q_rope_add_0_output_qdq"); - key_states = ptq::QDQ(this, - ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(key_states) * sin, "k_rope_mul_1_output_qdq"), - "k_rope_add_0_output_qdq"); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); // De-quantization and quantization again key_states = key_states.to(kFloat32); @@ -272,7 +274,9 @@ class Qwen3Attention final : public nn::Module { auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); auto minus_value = Tensor::constant(-20, kFloat32); minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); - attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_min.addConstant(minus_value)); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 9df6b7741..06fa5aab2 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -56,6 +56,17 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App endif() endif() +# FIXME: @oreomaker Need to remove comma features in slice! +# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) +# This flag is only available in Clang 13+ and GCC 10+ +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "10.0") + target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) + endif() +endif() + # ONLY APPLE CAN DO ! # Processing OpenMP if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP) diff --git a/mllm/backends/cpu/kernels/common/fill-inl.hpp b/mllm/backends/cpu/kernels/common/fill-inl.hpp new file mode 100644 index 000000000..4c799daf6 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/fill-inl.hpp @@ -0,0 +1,363 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +// NOTE: Do NOT use #pragma once here! +// Highway's foreach_target.h mechanism requires -inl.hpp files to be included +// multiple times, once for each target architecture (AVX3_DL, AVX10_2, etc.). 
+ +#include +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_zeros_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec zero = hn::Zero(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(zero, d, dst + idx); } + + if (idx < count) { hn::StoreN(zero, d, dst + idx, count - idx); } +} + +// Specialization for types not supported by Highway SIMD, use memset +template +HWY_INLINE void fill_zeros_scalar(T* HWY_RESTRICT dst, size_t count) { + if constexpr (std::is_trivial_v) { + std::memset(dst, 0, count * sizeof(T)); + } else { + T zero_val{}; + for (size_t i = 0; i < count; ++i) { dst[i] = zero_val; } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_ones_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec one = hn::Set(d, static_cast(1)); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(one, d, dst + idx); } + + if (idx < count) { hn::StoreN(one, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE 
HWY_MAYBE_UNUSED void fill_ones_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_value_impl(T* HWY_RESTRICT dst, size_t count, T value) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec v = hn::Set(d, value); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(v, d, dst + idx); } + + if (idx < count) { hn::StoreN(v, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size, mllm_fp64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_int32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_uint32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_int64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_uint64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_int16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_uint16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_int8_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_uint8_t value) { + fill_value_impl(dst, size, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange (start, end, step) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_arange_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if (step == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + // Calculate the actual number of elements to fill + size_t n = 0; + if ((step > 0 && start < end) || (step < 0 && 
start > end)) { + mllm_fp32_t n_float = (end - start) / step; + if (n_float > 0) { + n = static_cast(std::ceil(n_float)); + if (step > 0) { + if (start + (n - 1) * step >= end) --n; + } else { + if (start + (n - 1) * step <= end) --n; + } + n = std::min(n, count); + } + } + + // Use SIMD for float types where we can vectorize the computation + if constexpr (std::is_same_v) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // Create increment vector: [0, 1, 2, 3, ...] * step + const hn::Vec step_vec = hn::Set(d, step); + const hn::Vec n_step_vec = hn::Set(d, step * static_cast(N)); + + // Create base offsets [0, 1, 2, 3, ...] + hn::Vec base = hn::Iota(d, 0); + base = hn::Mul(base, step_vec); + hn::Vec current_start = hn::Add(hn::Set(d, start), base); + + size_t idx = 0; + for (; idx + N <= n; idx += N) { + hn::StoreU(current_start, d, dst + idx); + current_start = hn::Add(current_start, n_step_vec); + } + + // Handle remaining elements + for (; idx < n; ++idx) { dst[idx] = static_cast(start + idx * step); } + } else { + // Scalar fallback for other types + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random (using LCG random number generator) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_random_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + 
const uint64_t modulus = 1ULL << 31; // 2^31 + const mllm_fp32_t range = end - start; + + if (range == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + uint64_t state = seed; + state = (multiplier * state + increment) % modulus; + + for (size_t i = 0; i < count; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 1ad3cee93..7e81adfdf 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -17,6 +17,7 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" +#include "mllm/backends/cpu/kernels/common/fill-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { @@ -69,11 +70,188 @@ HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 // GELU //===----------------------------------------------------------------------===// // HWY_EXPORT(gelu_fp32); -// +// // HWY_DLLEXPORT void call_gelu_fp32(mllm_fp32_t* out, const mllm_fp32_t* in, size_t n) { // HWY_DYNAMIC_DISPATCH(gelu_fp32)(out, in, n); // } +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// 
+HWY_EXPORT(fill_zeros_fp32); +HWY_EXPORT(fill_zeros_fp64); +HWY_EXPORT(fill_zeros_i32); +HWY_EXPORT(fill_zeros_u32); +HWY_EXPORT(fill_zeros_i64); +HWY_EXPORT(fill_zeros_u64); +HWY_EXPORT(fill_zeros_i16); +HWY_EXPORT(fill_zeros_u16); +HWY_EXPORT(fill_zeros_i8); +HWY_EXPORT(fill_zeros_u8); + +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_ones_fp32); +HWY_EXPORT(fill_ones_fp64); +HWY_EXPORT(fill_ones_i32); +HWY_EXPORT(fill_ones_u32); +HWY_EXPORT(fill_ones_i64); +HWY_EXPORT(fill_ones_u64); +HWY_EXPORT(fill_ones_i16); +HWY_EXPORT(fill_ones_u16); +HWY_EXPORT(fill_ones_i8); +HWY_EXPORT(fill_ones_u8); + +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_value_fp32); +HWY_EXPORT(fill_value_fp64); +HWY_EXPORT(fill_value_i32); +HWY_EXPORT(fill_value_u32); +HWY_EXPORT(fill_value_i64); +HWY_EXPORT(fill_value_u64); +HWY_EXPORT(fill_value_i16); +HWY_EXPORT(fill_value_u16); 
+HWY_EXPORT(fill_value_i8); +HWY_EXPORT(fill_value_u8); + +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i8)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u8)(dst, n, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_arange_fp32); +HWY_EXPORT(fill_arange_i32); +HWY_EXPORT(fill_arange_u32); +HWY_EXPORT(fill_arange_i64); +HWY_EXPORT(fill_arange_u64); +HWY_EXPORT(fill_arange_i16); +HWY_EXPORT(fill_arange_u16); +HWY_EXPORT(fill_arange_i8); +HWY_EXPORT(fill_arange_u8); + +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_fp32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i8)(dst, n, start, 
end, step); +} +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u8)(dst, n, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_random_fp32); +HWY_EXPORT(fill_random_i32); +HWY_EXPORT(fill_random_u32); +HWY_EXPORT(fill_random_i64); +HWY_EXPORT(fill_random_u64); +HWY_EXPORT(fill_random_i16); +HWY_EXPORT(fill_random_u16); +HWY_EXPORT(fill_random_i8); +HWY_EXPORT(fill_random_u8); + +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_fp32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i8)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index eb100ac43..4df34db0e 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -7,6 +7,7 @@ #include "mllm/utils/CPUArchHelper.hpp" #if !(defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)) +#include #include "mllm/core/DataTypes.hpp" // Platform-specific definitions used for declaring an interface, independent of @@ -30,6 +31,222 @@ HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n); 
+HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value); +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value); +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value); +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value); +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value); +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value); +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value); +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value); +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value); +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value); + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void 
call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step);
+HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step);
+HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step);
+
+//===----------------------------------------------------------------------===//
+// Fill Random
+//===----------------------------------------------------------------------===//
+HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+
+//===----------------------------------------------------------------------===//
+// Template wrapper for generic fill operations
+//===----------------------------------------------------------------------===//
+template <typename T>
+inline void fill_zeros_anytype(T* dst, size_t n) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_zeros_fp32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_zeros_fp64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_zeros_i32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_zeros_u32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_zeros_i64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_zeros_u64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_zeros_i16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_zeros_u16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_zeros_i8(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_zeros_u8(dst, n);
+  } else {
+    // Fallback for unsupported types
+    std::memset(dst, 0, n * sizeof(T));
+  }
+}
+
+template <typename T>
+inline void fill_ones_anytype(T* dst, size_t n) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_ones_fp32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_ones_fp64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_ones_i32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_ones_u32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_ones_i64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_ones_u64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_ones_i16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_ones_u16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_ones_i8(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_ones_u8(dst, n);
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(1); }
+  }
+}
+
+template <typename T>
+inline void fill_value_anytype(T* dst, size_t n, mllm_fp32_t value) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_value_fp32(dst, n, value);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_value_fp64(dst, n, static_cast<mllm_fp64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_value_i32(dst, n, static_cast<mllm_int32_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_value_u32(dst, n, static_cast<mllm_uint32_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_value_i64(dst, n, static_cast<mllm_int64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_value_u64(dst, n, static_cast<mllm_uint64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_value_i16(dst, n, static_cast<mllm_int16_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_value_u16(dst, n, static_cast<mllm_uint16_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_value_i8(dst, n, static_cast<mllm_int8_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_value_u8(dst, n, static_cast<mllm_uint8_t>(value));
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(value); }
+  }
+}
+
+template <typename T>
+inline void fill_arange_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_arange_fp32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_arange_i32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_arange_u32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_arange_i64(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_arange_u64(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_arange_i16(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_arange_u16(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_arange_i8(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_arange_u8(dst, n, start, end, step);
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(start + i * step); }
+  }
+}
+
+template <typename T>
+inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_random_fp32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_random_i32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_random_u32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_random_i64(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_random_u64(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_random_i16(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_random_u16(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_random_i8(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_random_u8(dst, n, start, end, seed);
+  } else {
+    // Fallback using LCG
+    const uint64_t multiplier = 1103515245ULL;
+    const uint64_t increment = 12345ULL;
+    const uint64_t modulus = 1ULL << 31;
+    const mllm_fp32_t range = end - start;
+    uint64_t state = seed;
+    for (size_t i = 0; i < n; ++i) {
+      state = (multiplier * state + increment) % modulus;
+      const mllm_fp32_t random_value = static_cast<mllm_fp32_t>(state) / static_cast<mllm_fp32_t>(modulus - 1);
+      dst[i] = static_cast<T>(start + random_value * range);
+    }
+  }
+}
+    }
+  }
+}
+}  // namespace
mllm::cpu::common #endif diff --git a/mllm/backends/cpu/ops/FillOp.cpp b/mllm/backends/cpu/ops/FillOp.cpp index e4d935f51..cf5cee47e 100644 --- a/mllm/backends/cpu/ops/FillOp.cpp +++ b/mllm/backends/cpu/ops/FillOp.cpp @@ -21,7 +21,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_zeros(dst.ptr(), dst.numel(), threads); + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros(dst.ptr(), dst.numel(), threads); #endif @@ -29,7 +29,8 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + std::memset(dst.ptr(), 0, dst.numel() * sizeof(mllm_fp16_t)); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -37,7 +38,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -45,7 +46,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -53,7 +54,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -61,7 +62,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -69,7 +70,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -77,7 +78,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -85,7 +86,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) 
arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -93,7 +94,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -110,7 +111,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_ones(dst.ptr(), dst.numel(), threads); + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones(dst.ptr(), dst.numel(), threads); #endif @@ -118,7 +119,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(1.0f); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -126,7 +129,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -134,7 +137,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -142,7 +145,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -150,7 +153,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -158,7 +161,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -166,7 +169,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -174,7 +177,7 @@ void CPUFillOp::forward(const std::vector& 
inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -182,7 +185,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -199,7 +202,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -207,7 +210,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.start + i * options_.step); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -215,7 +220,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -224,7 +229,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -233,7 +238,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -242,7 +247,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, 
threads); @@ -251,7 +256,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -260,7 +265,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -269,7 +274,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -278,7 +283,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -295,7 +300,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -303,7 +308,18 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = options_.end - options_.start; + uint64_t state = options_.seed; + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + ptr[i] = static_cast(options_.start + random_value * range); + } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -311,7 +327,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); 
#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -319,7 +335,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -327,7 +343,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -335,7 +351,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -343,7 +359,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -351,7 +367,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -359,7 +375,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -367,7 +383,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -383,7 +399,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) 
- x86::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -391,7 +407,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.value); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_fp16(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -399,7 +417,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -407,7 +425,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -415,7 +433,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -423,7 +441,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -431,7 +449,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -439,7 +457,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -447,7 +465,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) 
arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -455,7 +473,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif diff --git a/mllm/backends/qnn/CMakeLists.txt b/mllm/backends/qnn/CMakeLists.txt index 0ad833792..83b4a43f9 100644 --- a/mllm/backends/qnn/CMakeLists.txt +++ b/mllm/backends/qnn/CMakeLists.txt @@ -44,3 +44,10 @@ get_property(current_includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INC message(STATUS "MLLM_QNN INCLUDES: ${current_includes}") #print include directories target_link_libraries(MllmQNNBackend PUBLIC MllmRT) + +install( + TARGETS MllmQNNBackend + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) diff --git a/mllm/backends/qnn/QNNBackend.cpp b/mllm/backends/qnn/QNNBackend.cpp index 54da97c9d..05ebedfcb 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -29,15 +29,28 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { QNNViewOpFactory, QNNRMSNormOpFactory, QNNTransposeOpFactory, QNNX2XOpFactory, QNNCastTypeOpFactory, QNNParamOpFactory, QNNSiLUOpFactory, QNNEmbeddingOpFactory>(); - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_ERROR; // default QNN log level + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE; // default QNN log level profilingLevel_ = ProfilingLevel::OFF; debug_ = false; // when set true, NATIVE tensor will be regared as APP_READ tensor - loadQNNSymbol(); - loadQNNSystemSymbol(); + if (!loadQNNSymbol()) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to load QNN symbols"); + } else { + MLLM_INFO("QNN symbols loaded successfully"); + } + + if (!loadQNNSystemSymbol()) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to load QNN System symbols"); + } else { + MLLM_INFO("QNN System symbols loaded successfully"); + } runtime_ = QNNRuntime::create(profilingLevel_, qnnLogLevel); - if (!runtime_) { MLLM_ERROR_EXIT(1, "Failed to create QNN Runtime"); } + if (!runtime_) { + MLLM_ERROR_EXIT(ExitCode::kQnnError, "Failed to create QNN Runtime"); + } else { + MLLM_INFO("QNN Runtime created successfully"); + } // check QNN capability, detect QNN features for future use char* backendBuildId{nullptr}; @@ -59,6 +72,7 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { perf_ = QNNPerf::create(&runtime_->qnnInterface); perf_->setPowerConfigBurst(); perf_->setRpcLatencyAndPolling(); + MLLM_INFO("QNN Perf created successfully"); } QNNPerf::QNNPerf(const QNN_INTERFACE_VER_TYPE* qnnInterface) { @@ -204,11 +218,13 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ // Create Log Qnn_LogHandle_t logHandle = nullptr; { - QnnLog_Callback_t logCallback = &__mllmQnnLoggerCallback; + QnnLog_Callback_t logCallback = __mllmQnnLoggerCallback; if ((QNN_GET_ERROR_CODE(qnnInterface.logCreate(logCallback, qnnLogLevel, &logHandle)) != QNN_SUCCESS) || (logHandle == nullptr)) { MLLM_ERROR("Failed to initialize logging in the backend."); return nullptr; + } else { + MLLM_INFO("Logging initialized successfully"); } } @@ -220,6 +236,8 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ || (backendHandle == nullptr)) { MLLM_ERROR("Failed to 
create the backend."); return nullptr; + } else { + MLLM_INFO("Backend created successfully"); } } @@ -227,16 +245,13 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ Qnn_DeviceHandle_t deviceHandle = nullptr; { // Check whether the device API is supported. - if (nullptr != qnnInterface.propertyHasCapability) { - auto qnnStatus = qnnInterface.propertyHasCapability(QNN_PROPERTY_GROUP_DEVICE); - if (QNN_PROPERTY_NOT_SUPPORTED == qnnStatus) { - MLLM_WARN("Device property is not supported"); - return nullptr; - } - if (QNN_PROPERTY_ERROR_UNKNOWN_KEY == qnnStatus) { - MLLM_ERROR("Device property is not known to backend"); + if (nullptr != qnnInterface.deviceCreate) { + auto status = qnnInterface.deviceCreate(logHandle, nullptr, &deviceHandle); + if (QNN_SUCCESS != status) { + MLLM_ERROR("Failed to create device, error: {}", (int)status); return nullptr; } + MLLM_INFO("Device created successfully"); } } @@ -269,9 +284,7 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ std::string target; }; - std::vector opPackages = { - {.path = "libQnnLLaMAPackage_CPU.so", .interfaceProvider = "LLaMAPackageInterfaceProvider", .target = "CPU"}, - {.path = "libQnnLLaMAPackage_HTP.so", .interfaceProvider = "LLaMAPackageInterfaceProvider", .target = "HTP"}}; + std::vector opPackages = {}; for (const auto& pkg : opPackages) { if (!qnnInterface.backendRegisterOpPackage) { @@ -298,6 +311,8 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ != QnnSystemInterface_getProviders((const QnnSystemInterface_t***)&systemInterfaceProviders, &numProviders)) { MLLM_ERROR("Failed to get system interface providers."); return nullptr; + } else { + MLLM_INFO("System interface providers found: {}", numProviders); } if (0 == numProviders) { MLLM_ERROR("Failed to get interface providers: 0 interface providers."); @@ -305,11 +320,17 @@ QNNRuntime* QNNRuntime::initRuntime(ProfilingLevel profilingLevel, QnnLog_Level_ } bool foundValidSystemInterface = false; for (size_t pIdx = 0; pIdx < numProviders; pIdx++) { - foundValidSystemInterface = true; if (QNN_SYSTEM_API_VERSION_MAJOR == systemInterfaceProviders[pIdx]->systemApiVersion.major && QNN_SYSTEM_API_VERSION_MINOR <= systemInterfaceProviders[pIdx]->systemApiVersion.minor) { qnnSystemInterface = systemInterfaceProviders[pIdx]->QNN_SYSTEM_INTERFACE_VER_NAME; + foundValidSystemInterface = true; break; + } else { + // Print system interface provider and self version + MLLM_WARN("System interface provider: {} version: {}", systemInterfaceProviders[pIdx]->systemApiVersion.major, + systemInterfaceProviders[pIdx]->systemApiVersion.minor); + MLLM_WARN("Self version: {} {}", QNN_SYSTEM_API_VERSION_MAJOR, QNN_SYSTEM_API_VERSION_MINOR); + MLLM_WARN("Unable to find a valid system interface."); } } if (!foundValidSystemInterface) { @@ -334,7 +355,14 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte std::vector>& qnnModels, QnnContext_Config_t** contextConfig) { // Read the binary from qnn_context.bin and get the size in byte std::ifstream file(contextBinaryPath, std::ios::binary | std::ios::ate); + if (!file.is_open() || !file.good()) { + MLLM_ERROR("Could not open context binary file: {}", contextBinaryPath); + return false; + } else { + MLLM_INFO("Context binary file opened successfully: {}", contextBinaryPath); + } std::streamsize size = file.tellg(); + MLLM_INFO("Context binary file size: {} MB", size / 1024 / 1024); file.seekg(0, std::ios::beg); auto 
binaryBuffer = std::make_unique(size); @@ -344,17 +372,27 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte // inspect binary info QnnSystemContext_Handle_t sysCtxHandle{nullptr}; + if (!qnnSystemInterface.systemContextCreate) { + MLLM_ERROR("systemContextCreate is nullptr."); + return false; + } if (QNN_SUCCESS != qnnSystemInterface.systemContextCreate(&sysCtxHandle)) { MLLM_ERROR("Could not create system handle."); return false; + } else { + MLLM_INFO("System context created successfully"); } + const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr}; Qnn_ContextBinarySize_t binaryInfoSize{0}; + if (QNN_SUCCESS != qnnSystemInterface.systemContextGetBinaryInfo(sysCtxHandle, static_cast(binaryBuffer.get()), size, &binaryInfo, &binaryInfoSize)) { MLLM_ERROR("Failed to get context binary info"); return false; + } else { + MLLM_INFO("Context binary info retrieved successfully"); } // Extract graph metadata to create QNNModels instead of GraphInfo_t @@ -365,13 +403,24 @@ bool QNNRuntime::retrieveContext(const std::string& contextBinaryPath, Qnn_Conte MLLM_ERROR("Failed to copy metadata."); return false; } - qnnSystemInterface.systemContextFree(sysCtxHandle); + if (QNN_SUCCESS != qnnSystemInterface.systemContextFree(sysCtxHandle)) { + MLLM_ERROR("Could not free system context."); + return false; + } else { + MLLM_INFO("System context freed successfully"); + } sysCtxHandle = nullptr; // Create context from binary Qnn_ContextBinarySize_t writtenSize = 0; - qnnInterface.contextCreateFromBinary(backendHandle, deviceHandle, (const QnnContext_Config_t**)contextConfig, - binaryBuffer.get(), size, &context, profileHandle); + if (QNN_CONTEXT_NO_ERROR + != qnnInterface.contextCreateFromBinary(backendHandle, deviceHandle, (const QnnContext_Config_t**)contextConfig, + binaryBuffer.get(), size, &context, profileHandle)) { + MLLM_ERROR("Could not create context from binary. Mostly due to binary's qnn version mismatch with backend's qnn version."); + return false; + } else { + MLLM_INFO("Context created from binary successfully"); + } // Create QNNModels for each graph and initialize from context qnnModels.clear(); diff --git a/mllm/backends/qnn/QNNBackend.hpp b/mllm/backends/qnn/QNNBackend.hpp index 49669c7c1..78953f32d 100644 --- a/mllm/backends/qnn/QNNBackend.hpp +++ b/mllm/backends/qnn/QNNBackend.hpp @@ -45,7 +45,7 @@ class QNNRuntime { ~QNNRuntime(); static std::unique_ptr create(ProfilingLevel profilingLevel = ProfilingLevel::OFF, - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_WARN) { + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE) { return std::unique_ptr(initRuntime(profilingLevel, qnnLogLevel)); } diff --git a/mllm/backends/qnn/Register.cpp b/mllm/backends/qnn/Register.cpp index 158294f35..88185921e 100644 --- a/mllm/backends/qnn/Register.cpp +++ b/mllm/backends/qnn/Register.cpp @@ -21,9 +21,18 @@ void initQnnBackend(const std::string& context_path) { // 1. 
Register backend auto backend = std::make_shared(); if (std::filesystem::exists(context_path)) { - if (!backend->loadContext(context_path)) { MLLM_ERROR_EXIT(1, "Failed to load QNN context from {}", context_path); } + MLLM_INFO("QNN context path exists: {}", context_path); + if (!backend->loadContext(context_path)) { + MLLM_ERROR_EXIT(1, "Failed to load QNN context from {}", context_path); + } else { + MLLM_INFO("QNN context loaded successfully from {}", context_path); + } } else { - if (!backend->createContext()) { MLLM_ERROR_EXIT(1, "Failed to create QNN context"); } + if (!backend->createContext()) { + MLLM_ERROR_EXIT(1, "Failed to create QNN context"); + } else { + MLLM_INFO("QNN context created successfully"); + } } ctx.registerBackend(backend); @@ -33,6 +42,8 @@ void initQnnBackend(const std::string& context_path) { .really_large_tensor_threshold = 0, .using_buddy_mem_pool = false, }); + MLLM_INFO("QNN memory manager registered"); + // 3. Initialize dispatcher manager ctx.dispatcherManager()->registerDispatcher( createQNNDispatcher(ctx.dispatcherManager()->getExecutor(), qnn::QNNDispatcherOptions())); diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index b2b04fd78..a79047e78 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -107,6 +107,7 @@ std::string QnnAOTNodeTensor::parseQnnTensorNameFromIR(const ir::tensor::TensorV Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::tensor::TensorValue::ptr_t& v) { Qnn_QuantizeParams_t ret = QNN_QUANTIZE_PARAMS_INIT; + MLLM_RT_ASSERT(v); MLLM_RT_ASSERT(v->getAttr("quant_recipe")); auto quant_spec = v->getAttr("quant_recipe")->cast_()->spec_; @@ -120,6 +121,9 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten auto cfg = std::static_pointer_cast(quant_spec); ret.encodingDefinition = QNN_DEFINITION_DEFINED; ret.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + if (!cfg->scale || !cfg->zero_point) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "AsymPerTensor quant recipe has no scale or zero point. tensor: {}", v->name()); + } ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item(), .offset = cfg->zero_point.item()}; break; } @@ -127,6 +131,9 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten auto cfg = std::static_pointer_cast(quant_spec); ret.encodingDefinition = QNN_DEFINITION_DEFINED; ret.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + if (!cfg->scale) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "SymPerTensor quant recipe has no scale. 
tensor: {}", v->name()); + } ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item(), .offset = 0}; break; } @@ -429,6 +436,12 @@ void QnnAOTEnv::_setup(const std::string& path) { } std::shared_ptr QnnAOTEnv::createContext(const std::string& name, bool weights_sharing) { + // Check if context with this name already exists + if (contexts_.count(name) > 0) { + MLLM_WARN("Context '{}' already exists, reusing the existing context", name); + return contexts_[name]; + } + std::shared_ptr context = std::make_shared(); context->name_ = name; diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 90ee4ad72..18bbb505c 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -369,8 +369,7 @@ bool LLMQuantRecipeNegPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeNegPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - return shareQuantSpecSingleInputToSingleOutputAndSetOpQuantAnnoAttr(writer.getContext(), - node->cast_()); + return noSharingSingleInAndSingleOutQuantAnnoAttr(writer.getContext(), node->cast_()); } //===----------------------------------------------------------------------===// @@ -651,8 +650,15 @@ bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr return false; } - MLLM_RETURN_FALSE_IF_NOT(i_0->getAttr("quant_recipe")); - MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); + // Create quant_recipe if not present + if (!i_0->getAttr("quant_recipe")) { + auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); + i_0->setAttr("quant_recipe", i_0_spec); + } + if (!i_1->getAttr("quant_recipe")) { + auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); + i_1->setAttr("quant_recipe", i_1_spec); + } o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); @@ -795,7 +801,8 @@ bool LLMQuantRecipeWherePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_ MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); MLLM_RETURN_FALSE_IF_NOT(i_2->getAttr("quant_recipe")); - o_0->setAttr("quant_recipe", i_2->getAttr("quant_recipe")); + auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); + o_0->setAttr("quant_recipe", o_0_spec); auto annotation_attr = writer.create(); annotation_attr->annotation_.inputs.emplace_back( @@ -979,6 +986,7 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ auto annotation_attr = writer.create(); + // i_0 logic stays the same if (!i_0->getAttr("quant_recipe")) { auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); i_0->setAttr("quant_recipe", i_0_spec); @@ -989,16 +997,7 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ i_0->getAttr("quant_recipe")->cast_()->spec_); } - if (!o_0->getAttr("quant_recipe")) { - auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); - o_0->setAttr("quant_recipe", o_0_spec); - annotation_attr->annotation_.outputs.emplace_back(o_0_spec->spec_); - } else { - annotation_attr->annotation_.outputs.emplace_back( - o_0->getAttr("quant_recipe")->cast_()->spec_); - } - - // Weights + // Weights - must be uint16, force set to kUInt16PerTensorAsy auto weight_name = embedding_op->getAOp()->getName() + ".weight"; auto weight_reg_tensor_ir = writer.getContext()->lookupSymbolTable(weight_name); 
MLLM_RETURN_FALSE_IF_NOT(weight_reg_tensor_ir); @@ -1006,11 +1005,21 @@ bool LLMQuantRecipeEmbeddingPattern::rewrite(ir::IRWriter& writer, const ir::op_ MLLM_RETURN_FALSE_IF_NOT(weight_reg_tensor_ir->outputs().front()->isa_()); auto weight_tensor = weight_reg_tensor_ir->outputs().front()->cast_(); - // Embedding weight quantization method same as outputs, but not share, just same type - auto weight_spec_attr = genSimpleQuantizationSpecAttr(writer.getContext(), weight_tensor); + // Embedding weight dtype must be uint16, force set to kUInt16PerTensorAsy + MLLM_RETURN_FALSE_IF_NOT(weight_tensor->tensor_.dtype() == kUInt16 || weight_tensor->tensor_.dtype() == kUInt16PerTensorAsy); + weight_tensor->tensor_ = weight_tensor->tensor_.__unsafeSetDType(kUInt16PerTensorAsy); + + // Create weight spec with kUInt16PerTensorAsy (AsymPerTensor) + auto weight_spec = + ir::linalg::QuantizationSpecAsymPerTensor::create(0, 65535, kUInt16, kFloat32, kInt32, Tensor::nil(), Tensor::nil()); + auto weight_spec_attr = writer.getContext()->create(weight_spec); weight_reg_tensor_ir->outputs().front()->setAttr("quant_recipe", weight_spec_attr); annotation_attr->annotation_.weights.insert({"weight", weight_spec_attr->spec_}); + // o_0's quant recipe shares with weight + o_0->setAttr("quant_recipe", weight_spec_attr); + annotation_attr->annotation_.outputs.emplace_back(weight_spec_attr->spec_); + // Attach to quantize node node->setAttr("quant_recipe", annotation_attr); diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 1d42d58d3..d9f1d97cb 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -111,6 +111,22 @@ void solveEmbeddingWeight(const ir::IRContext::ptr_t& ctx, const ParameterFile:: weight_spec->solved = true; break; } + case ir::linalg::QuantizationSpecType::kAsymPerTensor: { + auto this_spec = std::static_pointer_cast(weight_spec); + auto scale = pf->pull(mllm_op->getName() + ".scale"); + auto zero_point = pf->pull(mllm_op->getName() + ".zero_point"); + this_spec->scale = scale; + this_spec->zero_point = zero_point; + checkTypeLimits(pf->pull(mllm_op->getName() + ".weight"), this_spec->quant_min, this_spec->quant_max); + MLLM_RT_ASSERT(scale.dtype() == kFloat32); + MLLM_RT_ASSERT(scale.rank() == 1); + MLLM_RT_ASSERT(scale.item() > 0); + MLLM_RT_ASSERT(zero_point.dtype() == kInt32); + MLLM_RT_ASSERT(zero_point.rank() == 1); + MLLM_RT_ASSERT(zero_point.item() >= 0); + weight_spec->solved = true; + break; + } default: { NYI("quant recipe type not support"); } @@ -203,6 +219,9 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p auto _attr = ctx->create(std::vector{(uint16_t)ptq_constant_v}); tv->removeAttr("constant"); tv->setAttr("constant", _attr); + + MLLM_INFO("Constant tensor '{}' quantized (AsymPerTensor): before={}, after={}", tv->name(), constant_v, + ptq_constant_v); } this_spec->solved = true; @@ -262,6 +281,8 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p auto _attr = ctx->create(std::vector{(uint16_t)ptq_constant_v}); tv->removeAttr("constant"); tv->setAttr("constant", _attr); + + MLLM_INFO("Constant tensor '{}' quantized (SymPerTensor): before={}, after={}", tv->name(), constant_v, ptq_constant_v); } this_spec->solved = true; @@ -273,7 +294,7 @@ void _recursiveSolveNormalImpl(const ir::IRContext::ptr_t& ctx, const ir::Val::p break; } default: { - NYI("quant recipe type not support on tensor: {}", v->name()); + NYI("Quant recipe type 
not support on tensor: {}", v->name()); } } } @@ -300,6 +321,135 @@ void recursiveSolveNormal(const std::shared_ptr& ir_ctx, const ir }); } +void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto linalg_op = op->cast_(); + std::string op_name = linalg_op->getAOp()->getName(); + + auto inputs = op->inputs(); + auto outputs = op->outputs(); + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, used by Op: '{}'", tv->name(), op_name); + } + } + + for (auto ooo : outputs) { + if (!ooo->isa_()) continue; + auto tv = ooo->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, produced by Op: '{}'", tv->name(), op_name); + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckUnsolved(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + +void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto concat_op = op->cast_(); + std::string op_name = concat_op->getAOp()->getName(); + + auto inputs = op->inputs(); + if (inputs.empty()) { return ir::IRWriter::WALK_CONTINUE; } + + // Get first input's scale and zero_point as reference + Tensor ref_scale; + Tensor ref_zero_point; + bool has_ref = false; + std::string ref_input_name; + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + + if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kAsymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_zero_point = this_spec->zero_point; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale and zero_point match + auto cur_scale = this_spec->scale; + auto cur_zero_point = this_spec->zero_point; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(ref_zero_point.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_zero_point.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + auto ref_zp_v = ref_zero_point.item(); + auto cur_zp_v = cur_zero_point.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. 
" + "Input '{}': scale={}, zp={}, scale_name={}, zp_name={}; Input '{}': scale={}, zp={}, scale_name={}, " + "zp_name={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, ref_scale.name(), ref_zero_point.name(), tv->name(), + cur_scale_v, cur_zp_v, cur_scale.name(), cur_zero_point.name()); + } + } + } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale matches + auto cur_scale = this_spec->scale; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale between inputs. " + "Input '{}': scale={}; Input '{}': scale={}", + op_name, ref_input_name, ref_scale_v, tv->name(), cur_scale_v); + } + } + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckConcatInputs(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -330,6 +480,16 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_(), pf); + // Check for unsolved tensorValues and warn + recursiveCheckUnsolved( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + + // Check Concat inputs have consistent scale and zero_point + recursiveCheckConcatInputs( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp index 27f72e2e2..351e2562a 100644 --- a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp +++ b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp @@ -47,9 +47,12 @@ bool QnnAOTRMSNormPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) auto bias_tensor = mllm::Tensor::zeros(weight->tensor_.shape(), weight->tensor_.dtype()); auto bias_node = ir::tensor::TensorValue::build(writer.getContext().get(), bias_tensor); bias_node->tensor_.setName(a->getName() + "_runtime_bias"); + bias_node->name() = a->getName() + "_runtime_bias"; // fake bias quant recipe - auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(0, 0, kInt32, kFloat32, Tensor::ones({1})); + auto bias_scale = Tensor::ones({1}); + bias_scale.at({0}) = 1.0 / 32767; + auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(-32768, 32767, kInt16, kFloat32, bias_scale); auto quant_attr = mllm::ir::linalg::LinalgIRQuantizatonSpecAttr::build(writer.getContext().get(), quant_spec); bias_node->setAttr("quant_recipe", quant_attr); diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp index 50396955d..99cd22db9 100644 --- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp +++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp @@ -1,4 +1,3 @@ - // Copyright (c) MLLM Team. // Licensed under the MIT License. 
diff --git a/mllm/core/aops/EmbeddingOp.cpp b/mllm/core/aops/EmbeddingOp.cpp index b67eeff7a..a5a6400dd 100644 --- a/mllm/core/aops/EmbeddingOp.cpp +++ b/mllm/core/aops/EmbeddingOp.cpp @@ -70,8 +70,10 @@ void EmbeddingOp::reshape(const std::vector& inputs, std::vector std::vector o_shape{/*batch*/ shape[0], /*seq*/ shape[1], /*feat dim*/ options_.hidden_size}; - // FIXME: We should tell embedding output to use what kinds of data types. Currently it's hardcoded to float32. - outputs.emplace_back(Tensor::empty(o_shape, kFloat32, i.device())); + // Output dtype should match weight dtype (e.g., uint16 for AsymPerTensor quantization) + auto out_dtype = weight_.dtype(); + if (weight_.dtype() == kUInt16) { out_dtype = kUInt16PerTensorAsy; } + outputs.emplace_back(Tensor::empty(o_shape, out_dtype, i.device())); } void EmbeddingOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 22449f883..cb999191d 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -53,9 +53,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); + // Floating point types refl::GlobalDef().def("mllm.float32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat32); }); refl::GlobalDef().def("mllm.float16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat16); }); refl::GlobalDef().def("mllm.bfloat16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kBFloat16); }); + + // Signed integer types + refl::GlobalDef().def("mllm.int8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt8); }); + refl::GlobalDef().def("mllm.int16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt16); }); + refl::GlobalDef().def("mllm.int32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt32); }); + refl::GlobalDef().def("mllm.int64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt64); }); + + // Unsigned integer types + refl::GlobalDef().def("mllm.uint8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); + refl::GlobalDef().def("mllm.uint16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt16); }); + refl::GlobalDef().def("mllm.uint32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt32); }); + refl::GlobalDef().def("mllm.uint64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt64); }); + + // Bool type + refl::GlobalDef().def("mllm.bool_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); } //===----------------------------------------------------------------------===// diff --git a/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp new file mode 100644 index 000000000..392bfc17b --- /dev/null +++ b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp @@ -0,0 +1,240 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" + +namespace mllm::models::qwen2_5omni { + +inline float hertz_to_mel_slaney(float freq) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = 27.0f / std::log(6.4f); + + if (freq < kMinLogHertz) { + return 3.0f * freq / 200.0f; + } + return kMinLogMel + std::log(freq / kMinLogHertz) * logstep; +} + +inline float mel_to_hertz_slaney(float mel) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = std::log(6.4f) / 27.0f; + + if (mel < kMinLogMel) { + return 200.0f * mel / 3.0f; + } + return kMinLogHertz * std::exp(logstep * (mel - kMinLogMel)); +} + +inline Tensor create_hann_window(int32_t window_length, bool periodic = true) { + int32_t length = periodic ? window_length + 1 : window_length; + auto window = Tensor::empty({1, window_length}, kFloat32, kCPU).alloc(); + float* window_ptr = window.ptr(); + + for (int32_t i = 0; i < window_length; ++i) { + float n = static_cast(i); + float denominator = periodic ? static_cast(length) : static_cast(length - 1); + window_ptr[i] = 0.5f - 0.5f * std::cos(2.0f * M_PI * n / denominator); + } + + return window; +} + +inline Tensor create_mel_filterbank(int32_t num_frequency_bins, int32_t num_mel_filters, float min_frequency, + float max_frequency, int32_t sampling_rate) { + std::vector fft_freqs(num_frequency_bins); + for (int32_t i = 0; i < num_frequency_bins; ++i) { + fft_freqs[i] = static_cast(i) * (sampling_rate / 2.0f) / (num_frequency_bins - 1); + } + + float mel_min = hertz_to_mel_slaney(min_frequency); + float mel_max = hertz_to_mel_slaney(max_frequency); + + std::vector mel_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { + mel_freqs[i] = mel_min + static_cast(i) * (mel_max - mel_min) / (num_mel_filters + 1); + } + + std::vector filter_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { filter_freqs[i] = mel_to_hertz_slaney(mel_freqs[i]); } + + auto mel_filters = Tensor::empty({num_frequency_bins, num_mel_filters}, kFloat32, kCPU).alloc(); + float* filters_ptr = mel_filters.ptr(); + std::fill_n(filters_ptr, num_frequency_bins * num_mel_filters, 0.0f); + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float left_freq = filter_freqs[mel_idx]; + float center_freq = filter_freqs[mel_idx + 1]; + float right_freq = filter_freqs[mel_idx + 2]; + + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + float freq = fft_freqs[freq_idx]; + float value = 0.0f; + + if (freq >= left_freq && freq <= center_freq && center_freq != left_freq) { + value = (freq - left_freq) / (center_freq - left_freq); + } else if (freq >= center_freq && freq <= right_freq && right_freq != center_freq) { + value = (right_freq - freq) / (right_freq - center_freq); + } + + filters_ptr[freq_idx * num_mel_filters + mel_idx] = value; + } + } + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float enorm = 2.0f / (filter_freqs[mel_idx + 2] - filter_freqs[mel_idx]); + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + filters_ptr[freq_idx * num_mel_filters + mel_idx] *= enorm; + } + } + + return mel_filters; +} + +class MelSpectrogramFeatures final : public nn::Module { + int32_t 
n_fft_; + int32_t hop_length_; + int32_t win_length_; + int32_t n_mels_; + std::string padding_; + int power_; + nn::STFT stft_; + Tensor window_; + Tensor melscale_fbanks_; + + public: + MelSpectrogramFeatures() = default; + + explicit inline MelSpectrogramFeatures(const std::string& name, int32_t sample_rate = 16000, int32_t n_fft = 400, + int32_t hop_length = 160, int32_t n_mels = 128, + const std::string& padding = "center", int power = 2) + : nn::Module(name), n_fft_(n_fft), hop_length_(hop_length), n_mels_(n_mels), padding_(padding), power_(power) { + if (padding != "center" && padding != "same") { throw std::invalid_argument("Padding must be 'center' or 'same'."); } + + win_length_ = n_fft_; + stft_ = reg("stft", n_fft_, hop_length_, win_length_, true, true, "reflect", true); + window_ = create_hann_window(win_length_, true); + melscale_fbanks_ = create_mel_filterbank(n_fft_ / 2 + 1, n_mels_, 0.0f, 8000.0f, sample_rate); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto audio = inputs[0]; // [B, T] + + if (padding_ == "same") { + NYI("apply same padding in MelSpectrogramFeatures not implemented"); + } + + auto stft_result = stft_(audio, window_); + auto specgram = stft_result.abs(); + if (power_ == 2) { + specgram = specgram * specgram; + } else if (power_ != 1) { + NYI("power != 1 and power != 2 not implemented"); + } + + auto mel_specgram = nn::functional::matmul(specgram.T(), melscale_fbanks_).T(); + mel_specgram = nn::functional::clip(mel_specgram, 1e-10f, std::numeric_limits::max()); + mel_specgram = nn::functional::log(mel_specgram) / std::log(10.0f); + auto max_val = mel_specgram.max(); + float threshold = max_val.item() - 8.0f; + mel_specgram = nn::functional::clip(mel_specgram, threshold, std::numeric_limits::max()); + mel_specgram = (mel_specgram + 4.0f) / 4.0f; + + return {mel_specgram}; + } +}; + +struct Qwen2_5OmniAudioFeatures { + Tensor input_features = Tensor::nil(); + int32_t feature_length = 0; +}; + +class Qwen2_5OmniAudioPreprocessor { + MelSpectrogramFeatures mel_extractor_; + int32_t sample_rate_; + int32_t n_mels_; + int32_t hop_length_; + int32_t chunk_length_; + int32_t n_samples_; + + public: + explicit Qwen2_5OmniAudioPreprocessor(int32_t sample_rate = 16000, int32_t n_mels = 128, int32_t hop_length = 160, + int32_t chunk_length = 300) + : mel_extractor_("feature_extractor.mel_spec", sample_rate, 400, hop_length, n_mels, "center", 2), + sample_rate_(sample_rate), + n_mels_(n_mels), + hop_length_(hop_length), + chunk_length_(chunk_length), + n_samples_(chunk_length * sample_rate) {} + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioFile(const std::string& audio_file_path) { + auto audio_data = mllm::audio::readWAV(audio_file_path, sample_rate_); + if (audio_data.empty()) { return {}; } + return processAudioData(audio_data.data(), static_cast(audio_data.size())); + } + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioData(const float* audio_data, int32_t audio_length) { + Qwen2_5OmniAudioFeatures result; + if (audio_data == nullptr || audio_length <= 0) { return result; } + + int32_t padded_length = n_samples_; + int32_t effective_length = std::min(audio_length, padded_length); + + auto audio_tensor = Tensor::empty({1, padded_length}, kFloat32, kCPU).alloc(); + float* audio_ptr = audio_tensor.ptr(); + + if (audio_length <= padded_length) { + std::memcpy(audio_ptr, audio_data, audio_length * sizeof(float)); + std::fill(audio_ptr + audio_length, audio_ptr + padded_length, 0.0f); + } else { + 
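// [Editor's note] Clips longer than chunk_length_ seconds are truncated:
// only the first padded_length (= chunk_length_ * sample_rate_) samples
// are copied and the remainder of the recording is dropped.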
std::memcpy(audio_ptr, audio_data, padded_length * sizeof(float)); + } + + auto mel_spec = mel_extractor_.forward({audio_tensor}, {})[0]; // [1, n_mels, n_frames] + + int32_t valid_frames = calcFeatureLength(effective_length); + int32_t max_frames = mel_spec.shape()[2]; + if (valid_frames > max_frames) { valid_frames = max_frames; } + if (valid_frames <= 0) { return result; } + + auto trimmed = Tensor::empty({1, n_mels_, valid_frames}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < n_mels_; ++m) { + auto src_ptr = mel_spec.offsettedPtr({0, m, 0}); + auto dst_ptr = trimmed.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, valid_frames * sizeof(float)); + } + + result.input_features = trimmed; + result.feature_length = valid_frames; + return result; + } + + [[nodiscard]] int32_t calcFeatureLength(int32_t audio_length) const { + if (audio_length <= 0) { return 0; } + return (audio_length + hop_length_ - 1) / hop_length_; + } + + [[nodiscard]] int32_t calcAudioTokenLength(int32_t feature_length) const { + if (feature_length <= 0) { return 0; } + int32_t after_conv = (feature_length - 1) / 2 + 1; + if (after_conv < 2) { return 0; } + int32_t after_pool = (after_conv - 2) / 2 + 1; + return std::max(0, after_pool); + } +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp new file mode 100644 index 000000000..d0e000642 --- /dev/null +++ b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp @@ -0,0 +1,179 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::qwen2_5omni { + +struct Qwen2_5OmniConfig : protected ConfigFile { + Qwen2_5OmniConfig() = default; + + explicit Qwen2_5OmniConfig(const std::string& file_path) : ConfigFile(file_path) { + auto& root = data(); + + if (root.contains("thinker_config")) { + auto& thinker_cfg = root["thinker_config"]; + auto& text_cfg = thinker_cfg["text_config"]; + + hidden_size = text_cfg["hidden_size"]; + intermediate_size = text_cfg["intermediate_size"]; + num_attention_heads = text_cfg["num_attention_heads"]; + num_key_value_heads = text_cfg["num_key_value_heads"]; + num_hidden_layers = text_cfg["num_hidden_layers"]; + max_position_embeddings = text_cfg["max_position_embeddings"]; + rms_norm_eps = text_cfg["rms_norm_eps"]; + vocab_size = text_cfg["vocab_size"]; + rope_theta = text_cfg["rope_theta"]; + tie_word_embeddings = text_cfg.value("tie_word_embeddings", false); + + if (text_cfg.contains("rope_scaling") && text_cfg["rope_scaling"].contains("mrope_section")) { + mrope_section = text_cfg["rope_scaling"]["mrope_section"].get>(); + } + + if (thinker_cfg.contains("vision_config")) { + auto& vision_cfg = thinker_cfg["vision_config"]; + visual_in_chans = vision_cfg.value("in_channels", vision_cfg.value("in_chans", visual_in_chans)); + visual_hidden_size = vision_cfg.value("hidden_size", vision_cfg.value("embed_dim", visual_hidden_size)); + visual_patch_size = vision_cfg.value("patch_size", vision_cfg.value("spatial_patch_size", visual_patch_size)); + visual_temporal_patch_size = vision_cfg.value("temporal_patch_size", visual_temporal_patch_size); + visual_spatial_merge_size = vision_cfg.value("spatial_merge_size", visual_spatial_merge_size); + visual_out_hidden_size = vision_cfg.value("out_hidden_size", visual_out_hidden_size); + visual_num_heads = vision_cfg.value("num_heads", 
visual_num_heads); + visual_depth = vision_cfg.value("depth", visual_depth); + visual_intermediate_size = vision_cfg.value("intermediate_size", visual_intermediate_size); + if (vision_cfg.contains("fullatt_block_indexes")) { + visual_fullatt_block_indexes = vision_cfg["fullatt_block_indexes"].get>(); + } + visual_window_size = vision_cfg.value("window_size", visual_window_size); + } + + if (thinker_cfg.contains("audio_config")) { + auto& audio_cfg = thinker_cfg["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + + bos_token_id = thinker_cfg.value("bos_token_id", bos_token_id); + eos_token_id = thinker_cfg.value("eos_token_id", eos_token_id); + pad_token_id = thinker_cfg.value("pad_token_id", pad_token_id); + image_token_id = thinker_cfg.value("image_token_index", image_token_id); + audio_token_id = thinker_cfg.value("audio_token_index", audio_token_id); + video_token_id = thinker_cfg.value("video_token_index", video_token_id); + audio_start_token_id = thinker_cfg.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = thinker_cfg.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = thinker_cfg.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = thinker_cfg.value("vision_end_token_id", vision_end_token_id); + vision_token_id = thinker_cfg.value("vision_token_id", vision_token_id); + position_id_per_seconds = thinker_cfg.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = thinker_cfg.value("seconds_per_chunk", seconds_per_chunk); + } else { + hidden_size = root["hidden_size"]; + intermediate_size = root["intermediate_size"]; + num_attention_heads = root["num_attention_heads"]; + num_key_value_heads = root["num_key_value_heads"]; + num_hidden_layers = root["num_hidden_layers"]; + max_position_embeddings = root["max_position_embeddings"]; + rms_norm_eps = root["rms_norm_eps"]; + vocab_size = root["vocab_size"]; + rope_theta = root["rope_theta"]; + tie_word_embeddings = root.value("tie_word_embeddings", tie_word_embeddings); + if (root.contains("mrope_section")) { + mrope_section = root["mrope_section"].get>(); + } + if (root.contains("audio_config")) { + auto& audio_cfg = root["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + bos_token_id = root.value("bos_token_id", bos_token_id); + 
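// [Editor's note] In this flat (non-"thinker_config") branch the special
// token ids are read under plain "*_token_id" keys, whereas the nested
// branch above uses the HF "*_token_index" spelling for the
// image/audio/video tokens; both fall back to the 7B defaults declared
// further down in this struct.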
eos_token_id = root.value("eos_token_id", eos_token_id); + pad_token_id = root.value("pad_token_id", pad_token_id); + image_token_id = root.value("image_token_id", image_token_id); + audio_token_id = root.value("audio_token_id", audio_token_id); + video_token_id = root.value("video_token_id", video_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + vision_token_id = root.value("vision_token_id", vision_token_id); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); + } + + max_cache_length = root.value("max_cache_length", max_position_embeddings); + + if (root.contains("linear_impl_type")) { + linear_impl_type = aops::str2LinearImplTypes(root["linear_impl_type"]); + } + } + + int32_t hidden_size = 3584; + int32_t intermediate_size = 18944; + int32_t num_attention_heads = 28; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 28; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 152064; + std::vector mrope_section = {16, 24, 24}; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + int32_t visual_in_chans = 3; + int32_t visual_hidden_size = 1280; + int32_t visual_patch_size = 14; + int32_t visual_temporal_patch_size = 2; + int32_t visual_spatial_merge_size = 2; + int32_t visual_out_hidden_size = 3584; + int32_t visual_num_heads = 16; + int32_t visual_depth = 32; + int32_t visual_intermediate_size = 3420; + std::vector visual_fullatt_block_indexes = {7, 15, 23, 31}; + int32_t visual_window_size = 112; + + int32_t audio_d_model = 1280; + int32_t audio_num_mel_bins = 128; + int32_t audio_encoder_layers = 32; + int32_t audio_encoder_attention_heads = 20; + int32_t audio_encoder_ffn_dim = 5120; + int32_t audio_max_source_positions = 1500; + int32_t audio_n_window = 100; + int32_t audio_output_dim = 3584; + + int32_t max_cache_length = 32768; + + int64_t bos_token_id = 151644; + int64_t eos_token_id = 151645; + int64_t pad_token_id = 151643; + int64_t image_token_id = 151655; + int64_t audio_token_id = 151646; + int64_t video_token_id = 151656; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t vision_token_id = 151654; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp new file mode 100644 index 000000000..fac087bae --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp @@ -0,0 +1,1378 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
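// [Editor's note: illustrative overview, not part of the original patch.]
// This header wires up the Qwen2.5-Omni "thinker": a Qwen2.5-VL style vision
// encoder (windowed attention plus patch merger), a Whisper-like audio tower
// (conv front end, transformer layers, avg-pool, projection) and the text
// decoder with multimodal RoPE. A rough worked example of the rotary
// bookkeeping, assuming the 7B defaults in configuration_qwen2_5omni.hpp:
//
//   head_dim      = hidden_size / num_attention_heads = 3584 / 28 = 128
//   rotary pairs  = head_dim / 2                       = 64
//   mrope_section = {16, 24, 24}   // 16 + 24 + 24 = 64 pairs split across
//                                  // the temporal / height / width streams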
+#pragma once + +#include +#include +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +inline auto makeMultimodalRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, int seq_len, int output_dim, + const std::vector& mrope_section) -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if (!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + +inline auto makeWindowIndex(const Tensor& grid_thw, int window_size, int spatial_merge_size, + int patch_size) -> std::pair, std::vector> { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + const int grid_num = grid_thw.shape()[0]; + + const int vit_merger_window_size = window_size / spatial_merge_size / patch_size; + const int spatial_merge_unit = spatial_merge_size * spatial_merge_size; + + std::vector 
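// [Editor's note] window_index is the permutation that groups merged
// patches into attention windows. With the 7B vision defaults
// (window_size 112, spatial_merge_size 2, patch_size 14) the merged grid
// is tiled into 4 x 4 windows: vit_merger_window_size = 112 / 2 / 14 = 4,
// and cu_window_seqlens accumulates the per-window token counts (times
// spatial_merge_unit = 4) used later for the block-diagonal attention mask.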
window_index; + std::vector cu_window_seqlens = {0}; + int window_index_id = 0; + + for (int grid_idx = 0; grid_idx < grid_num; ++grid_idx) { + const int grid_t = grid_thw.constAt({grid_idx, 0}); + const int grid_h = grid_thw.constAt({grid_idx, 1}); + const int grid_w = grid_thw.constAt({grid_idx, 2}); + + const int llm_grid_h = grid_h / spatial_merge_size; + const int llm_grid_w = grid_w / spatial_merge_size; + const int pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size; + const int pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size; + + const int num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size; + const int num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size; + const int total_windows = grid_t * num_windows_h * num_windows_w; + + std::vector>> index( + grid_t, std::vector>(llm_grid_h, std::vector(llm_grid_w))); + + int counter = 0; + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index[t][h][w] = counter++; } + } + } + + std::vector>> index_padded( + grid_t, std::vector>(llm_grid_h + pad_h, std::vector(llm_grid_w + pad_w, -100))); + + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index_padded[t][h][w] = index[t][h][w]; } + } + } + + std::vector seqlens(total_windows, 0); + for (int t = 0; t < grid_t; t++) { + for (int wh = 0; wh < num_windows_h; wh++) { + for (int ww = 0; ww < num_windows_w; ww++) { + const int window_idx = t * num_windows_h * num_windows_w + wh * num_windows_w + ww; + for (int h = 0; h < vit_merger_window_size; h++) { + for (int w = 0; w < vit_merger_window_size; w++) { + const int orig_h = wh * vit_merger_window_size + h; + const int orig_w = ww * vit_merger_window_size + w; + if (index_padded[t][orig_h][orig_w] != -100) { + window_index.push_back(index_padded[t][orig_h][orig_w] + window_index_id); + seqlens[window_idx]++; + } + } + } + } + } + } + + int cumulative = cu_window_seqlens.back(); + for (int i = 0; i < total_windows; i++) { + cumulative += seqlens[i] * spatial_merge_unit; + cu_window_seqlens.push_back(cumulative); + } + + window_index_id += grid_t * llm_grid_h * llm_grid_w; + } + + return {window_index, cu_window_seqlens}; +} + +inline auto makeVisualRoPEInvFreq(int32_t dims, float theta) -> Tensor { + const int half_dim = dims / (2 * 2); + Tensor inv_freq = Tensor::empty({half_dim}, kFloat32).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + const float dims_inv = 1.0f / static_cast(dims / 2); + for (int i = 0; i < half_dim; ++i) { + const float exponent = (2.0f * i) * dims_inv; + inv_freq_ptr[i] = 1.0f / std::pow(theta, exponent); + } + return inv_freq; +} + +inline auto makeVisualRotaryPosEmbIds(Tensor& grid_thw, int32_t spatial_merge_size) -> Tensor { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + + const auto img_nums = grid_thw.shape()[0]; + int total_positions = 0; + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({total_positions, 2}, kInt32).alloc(); + int* out_ptr = out.ptr(); + int out_offset = 0; + + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + const int t = dims[0]; + const int h = dims[1]; + const int w = dims[2]; + + const int num_h_blocks = h / spatial_merge_size; + const int num_w_blocks = w / 
spatial_merge_size; + const int total_blocks = num_h_blocks * num_w_blocks; + const int block_area = spatial_merge_size * spatial_merge_size; + const int grid_size = h * w; + + std::vector flatten_hpos(grid_size); + std::vector flatten_wpos(grid_size); + + for (int block_idx = 0; block_idx < total_blocks; ++block_idx) { + const int i_h = block_idx / num_w_blocks; + const int i_w = block_idx % num_w_blocks; + const int start_idx = block_idx * block_area; + + const int base_h = i_h * spatial_merge_size; + const int base_w = i_w * spatial_merge_size; + + for (int j_h = 0; j_h < spatial_merge_size; ++j_h) { + const int global_h = base_h + j_h; + for (int j_w = 0; j_w < spatial_merge_size; ++j_w) { + const int global_w = base_w + j_w; + const int pos = start_idx + j_h * spatial_merge_size + j_w; + flatten_hpos[pos] = global_h; + flatten_wpos[pos] = global_w; + } + } + } + + for (int frame = 0; frame < t; ++frame) { + for (int pos = 0; pos < grid_size; ++pos) { + const int out_idx = out_offset + (frame * grid_size + pos) * 2; + out_ptr[out_idx] = flatten_hpos[pos]; + out_ptr[out_idx + 1] = flatten_wpos[pos]; + } + } + out_offset += t * grid_size * 2; + } + + return out; +} + +inline auto makeVisualRotaryPosEmbFull(Tensor& inv_freq, int seq_len) -> Tensor { + MLLM_RT_ASSERT(seq_len > 0); + const int32_t dim = inv_freq.shape()[0]; + Tensor freqs = Tensor::empty({seq_len, dim}, kFloat32, kCPU).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + float* freqs_ptr = freqs.ptr(); + for (int i = 0; i < seq_len; ++i) { + const float i_val = static_cast(i); + float* row_ptr = freqs_ptr + i * dim; + for (int j = 0; j < dim; ++j) { row_ptr[j] = i_val * inv_freq_ptr[j]; } + } + return freqs; +} + +inline auto makeVisualRotaryPosEmb(Tensor& rotary_pos_emb_full, Tensor& pos_ids, Tensor& grid_thw) -> Tensor { + const int32_t dim = rotary_pos_emb_full.shape()[1]; + const int32_t batch_size = pos_ids.shape()[0]; + const int32_t seq_len = pos_ids.shape()[1]; + + int total_positions = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({batch_size, seq_len * dim}, kFloat32, kCPU).alloc(); + + auto rotary_pos_emb_full_ptr = rotary_pos_emb_full.ptr(); + auto pos_ids_ptr = pos_ids.ptr(); + + if (rotary_pos_emb_full.shape()[0] <= 0 || dim <= 0 || batch_size <= 0) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Invalid tensor dimensions"); + } + + if (total_positions != batch_size) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Grid dimensions mismatch with batch size"); } + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + if (idx < 0 || idx >= rotary_pos_emb_full.shape()[0]) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Position index out of bounds"); + } + } + } + + for (int i = 0; i < batch_size; ++i) { + auto batch_ptr = out.offsettedPtr({i, 0}); + size_t offset = 0; + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + auto emb_ptr = rotary_pos_emb_full_ptr + idx * dim; + std::copy(emb_ptr, emb_ptr + dim, batch_ptr + offset); + offset += dim; + } + } + + return out; +} + +inline auto makeVisualRotarySinCos(Tensor& rotary_pos_emb) -> std::pair { + const auto seq = rotary_pos_emb.shape()[0]; + const auto dim = rotary_pos_emb.shape()[1]; + + auto rotary_pos_emb_ptr = rotary_pos_emb.ptr(); + + Tensor sin_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + Tensor cos_pos_emb 
= Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + + auto sin_pos_emb_ptr = sin_pos_emb.ptr(); + auto cos_pos_emb_ptr = cos_pos_emb.ptr(); + + for (int i = 0; i < seq; i++) { + for (int j = 0; j < dim; j++) { + sin_pos_emb_ptr[i * dim + j] = std::sin(rotary_pos_emb_ptr[i * dim + j]); + cos_pos_emb_ptr[i * dim + j] = std::cos(rotary_pos_emb_ptr[i * dim + j]); + } + } + + return {sin_pos_emb, cos_pos_emb}; +} + +inline auto makeAudioSinusoidalPosEmb(int32_t length, int32_t channels, float max_timescale = 10000.0f) -> Tensor { + MLLM_RT_ASSERT(channels % 2 == 0); + auto pos_emb = Tensor::empty({length, channels}, kFloat32, kCPU).alloc(); + auto pos_ptr = pos_emb.ptr(); + + const int half = channels / 2; + const float log_timescale_increment = std::log(max_timescale) / static_cast(half - 1); + + std::vector inv_timescales(half); + for (int i = 0; i < half; ++i) { + inv_timescales[i] = std::exp(-log_timescale_increment * static_cast(i)); + } + + for (int t = 0; t < length; ++t) { + for (int i = 0; i < half; ++i) { + const float scaled_time = static_cast(t) * inv_timescales[i]; + pos_ptr[t * channels + i] = std::sin(scaled_time); + pos_ptr[t * channels + half + i] = std::cos(scaled_time); + } + } + + return pos_emb; +} + +class Qwen2_5OmniPatchEmbed final : public nn::Module { + int32_t in_chans_; + int32_t embed_dim_; + int32_t patch_size_; + int32_t temporal_patch_size_; + + nn::Conv3D proj_; + + public: + Qwen2_5OmniPatchEmbed() = default; + + explicit Qwen2_5OmniPatchEmbed(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + in_chans_ = cfg.visual_in_chans; + embed_dim_ = cfg.visual_hidden_size; + patch_size_ = cfg.visual_patch_size; + temporal_patch_size_ = cfg.visual_temporal_patch_size; + + proj_ = reg("proj", cfg.visual_in_chans, cfg.visual_hidden_size, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + hidden_states = hidden_states.view({-1, in_chans_, temporal_patch_size_, patch_size_, patch_size_}); + hidden_states = proj_(hidden_states).view({-1, embed_dim_}); + return {hidden_states}; + } +}; + +class Qwen2_5OmniPatchMerger final : public nn::Module { + int32_t hidden_size_; + int32_t spatial_merge_size_; + int32_t context_dim_; + + nn::RMSNorm ln_q_; + nn::Linear mlp_0_; + nn::Linear mlp_2_; + nn::GELU mlp_gelu_; + + public: + Qwen2_5OmniPatchMerger() = default; + + explicit Qwen2_5OmniPatchMerger(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + context_dim_ = cfg.visual_hidden_size; + spatial_merge_size_ = cfg.visual_spatial_merge_size; + hidden_size_ = context_dim_ * spatial_merge_size_ * spatial_merge_size_; + + ln_q_ = reg("ln_q", 1e-6); + mlp_0_ = reg("mlp.0", hidden_size_, hidden_size_, true, cfg.linear_impl_type); + mlp_gelu_ = reg("mlp.gelu"); + mlp_2_ = reg("mlp.2", hidden_size_, cfg.visual_out_hidden_size, true, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto o = ln_q_(inputs[0]).view({-1, hidden_size_}); + o = mlp_0_(o); + o = mlp_gelu_(o); + o = mlp_2_(o); + return {o}; + } +}; + +class Qwen2_5OmniVisionMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniVisionMLP() = default; + 
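// [Editor's note] SiLU-gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).
// With the default vision config this is a 1280 -> 3420 -> 1280 projection
// (visual_hidden_size -> visual_intermediate_size -> visual_hidden_size).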
explicit Qwen2_5OmniVisionMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + down_proj_ = reg("down_proj", cfg.visual_intermediate_size, cfg.visual_hidden_size, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniVisionAttention final : public nn::Module { + int32_t dim_; + int32_t num_heads_; + int32_t head_dim_; + + nn::Linear q_; + nn::Linear k_; + nn::Linear v_; + nn::Linear proj_; + nn::Softmax softmax_; + nn::VisionRoPE vision_rope_q_; + nn::VisionRoPE vision_rope_k_; + + public: + Qwen2_5OmniVisionAttention() = default; + + explicit Qwen2_5OmniVisionAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + dim_ = cfg.visual_hidden_size; + num_heads_ = cfg.visual_num_heads; + head_dim_ = dim_ / num_heads_; + + q_ = reg("q", dim_, dim_, true, cfg.linear_impl_type); + k_ = reg("k", dim_, dim_, true, cfg.linear_impl_type); + v_ = reg("v", dim_, dim_, true, cfg.linear_impl_type); + proj_ = reg("proj", dim_, dim_, true, cfg.linear_impl_type); + softmax_ = reg("softmax", -1); + + vision_rope_q_ = reg("vision_rope_q", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + vision_rope_k_ = reg("vision_rope_k", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto& mask = inputs[3]; + + auto seq_length = hidden_states.shape()[0]; + + auto query_states = q_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto key_states = k_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto value_states = v_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + + query_states = vision_rope_q_(query_states, visual_embedding_sin, visual_embedding_cos); + key_states = vision_rope_k_(key_states, visual_embedding_sin, visual_embedding_cos); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + auto attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + if (mask) { attn = attn + mask; } + attn = softmax_(attn); + + auto attn_output = nn::functional::matmul(attn, value_states); + attn_output = attn_output.transpose(1, 2).view({seq_length, -1}); + attn_output = proj_(attn_output); + return {attn_output}; + } +}; + +class Qwen2_5OmniVisionBlock final : public nn::Module { + nn::RMSNorm norm1_; + nn::RMSNorm norm2_; + + Qwen2_5OmniVisionAttention attn_; + Qwen2_5OmniVisionMLP mlp_; + + public: + Qwen2_5OmniVisionBlock() = default; + + explicit Qwen2_5OmniVisionBlock(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + norm1_ = reg("norm1", 1e-6); + norm2_ = reg("norm2", 1e-6); + attn_ = 
reg("attn", cfg); + mlp_ = reg("mlp", cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto mask = inputs[3]; + + hidden_states = hidden_states + attn_(norm1_(hidden_states), visual_embedding_sin, visual_embedding_cos, mask)[0]; + hidden_states = hidden_states + mlp_(norm2_(hidden_states))[0]; + return {hidden_states}; + } +}; + +class Qwen2_5OmniVisionEncoder final : public nn::Module { + Qwen2_5OmniPatchEmbed patch_embed_; + Qwen2_5OmniPatchMerger patch_merger_; + nn::ModuleList blocks_; + std::vector visual_fullatt_block_indexes_; + int32_t visual_window_size_ = 0; + int32_t visual_spatial_merge_size_ = 1; + int32_t visual_patch_size_ = 1; + int32_t spatial_merge_unit_ = 1; + + public: + Qwen2_5OmniVisionEncoder() = default; + + explicit Qwen2_5OmniVisionEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + visual_window_size_ = cfg.visual_window_size; + visual_spatial_merge_size_ = cfg.visual_spatial_merge_size; + visual_patch_size_ = cfg.visual_patch_size; + spatial_merge_unit_ = visual_spatial_merge_size_ * visual_spatial_merge_size_; + visual_fullatt_block_indexes_ = cfg.visual_fullatt_block_indexes; + patch_embed_ = reg("patch_embed", cfg); + patch_merger_ = reg("merger", cfg); + blocks_ = reg>("blocks", cfg.visual_depth, cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto embedding_sin = inputs[1]; + auto embedding_cos = inputs[2]; + auto& grid_thw = inputs[3]; + + hidden_states = patch_embed_(hidden_states)[0]; + auto [window_index, cu_window_seqlens] = + makeWindowIndex(grid_thw, visual_window_size_, visual_spatial_merge_size_, visual_patch_size_); + + auto seq_len = hidden_states.shape()[0]; + hidden_states = hidden_states.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + hidden_states = hidden_states[{window_index, {kAll}, {kAll}}]; + hidden_states = hidden_states.view({seq_len, -1}); + + embedding_sin = embedding_sin.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_sin = embedding_sin[{window_index, {kAll}, {kAll}}]; + embedding_sin = embedding_sin.view({seq_len, -1}); + embedding_cos = embedding_cos.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_cos = embedding_cos[{window_index, {kAll}, {kAll}}]; + embedding_cos = embedding_cos.view({seq_len, -1}); + + auto mask = Tensor::empty({1, 1, seq_len, seq_len}, DataTypes::kFloat32, DeviceTypes::kCPU).alloc(); + { + auto mask_ptr = mask.ptr(); + const mllm_fp32_t neg_inf = -1e12f; + for (int i = 0; i < seq_len * seq_len; ++i) { mask_ptr[i] = neg_inf; } + for (int i = 1; i < cu_window_seqlens.size(); ++i) { + const int start = cu_window_seqlens[i - 1]; + const int end = cu_window_seqlens[i]; + for (int r = start; r < end; ++r) { + for (int c = start; c < end; ++c) { mask_ptr[r * seq_len + c] = 0.0f; } + } + } + } + + for (auto [layer_idx, b] : enumerate(blocks_.list())) { + if (std::find(visual_fullatt_block_indexes_.begin(), visual_fullatt_block_indexes_.end(), layer_idx) + != visual_fullatt_block_indexes_.end()) { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, Tensor::nil())[0]; + } else { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, mask)[0]; + } + } + + hidden_states = patch_merger_(hidden_states)[0]; + + std::vector 
reverse_indices(window_index.size()); + std::iota(reverse_indices.begin(), reverse_indices.end(), 0); + std::sort(reverse_indices.begin(), reverse_indices.end(), + [&window_index](int i, int j) { return window_index[i] < window_index[j]; }); + hidden_states = hidden_states[{reverse_indices, {kAll}}]; + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioAttention final : public nn::Module { + int32_t embed_dim_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear q_proj_; + nn::Linear out_proj_; + + public: + Qwen2_5OmniAudioAttention() = default; + + explicit Qwen2_5OmniAudioAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + embed_dim_ = cfg.audio_d_model; + num_heads_ = cfg.audio_encoder_attention_heads; + head_dim_ = embed_dim_ / num_heads_; + + k_proj_ = reg("k_proj", embed_dim_, embed_dim_, false); + v_proj_ = reg("v_proj", embed_dim_, embed_dim_, true); + q_proj_ = reg("q_proj", embed_dim_, embed_dim_, true); + out_proj_ = reg("out_proj", embed_dim_, embed_dim_, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; // [seq_len, embed_dim] + auto seq_len = hidden_states.shape()[0]; + + auto hidden = hidden_states.unsqueeze(0); // [1, seq_len, embed_dim] + auto query_states = q_proj_(hidden); + auto key_states = k_proj_(hidden); + auto value_states = v_proj_(hidden); + + query_states = query_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + key_states = key_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + value_states = value_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + + float scale = 1.0f / std::sqrt(static_cast(head_dim_)); + auto attn_weights = nn::functional::matmul(query_states, key_states.transpose(-2, -1)) * scale; + attn_weights = nn::functional::softmax(attn_weights, -1); + auto attn_output = nn::functional::matmul(attn_weights, value_states); + + attn_output = attn_output.transpose(1, 2).contiguous().view({1, seq_len, embed_dim_}); + attn_output = out_proj_(attn_output); + + return {attn_output.squeeze(0)}; + } +}; + +class Qwen2_5OmniAudioEncoderLayer final : public nn::Module { + Qwen2_5OmniAudioAttention self_attn_; + nn::LayerNorm self_attn_layer_norm_; + nn::Linear fc1_; + nn::Linear fc2_; + nn::LayerNorm final_layer_norm_; + nn::GELU activation_fn_; + + public: + Qwen2_5OmniAudioEncoderLayer() = default; + + explicit Qwen2_5OmniAudioEncoderLayer(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + const int32_t embed_dim = cfg.audio_d_model; + self_attn_ = reg("self_attn", cfg); + self_attn_layer_norm_ = + reg("self_attn_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + fc1_ = reg("fc1", embed_dim, cfg.audio_encoder_ffn_dim, true); + fc2_ = reg("fc2", cfg.audio_encoder_ffn_dim, embed_dim, true); + final_layer_norm_ = reg("final_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + activation_fn_ = reg("activation_fn"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto residual = hidden_states; + + hidden_states = self_attn_layer_norm_(hidden_states); + hidden_states = self_attn_(hidden_states)[0]; + hidden_states = residual + hidden_states; + + residual = hidden_states; + hidden_states = final_layer_norm_(hidden_states); + hidden_states = fc1_(hidden_states); + hidden_states = activation_fn_(hidden_states); + 
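// [Editor's note] fc1/fc2 form the pre-norm feed-forward block; with the 7B
// audio defaults (d_model 1280, encoder_ffn_dim 5120) this is a
// 1280 -> 5120 -> 1280 projection, followed by the residual add and, for
// fp16 activations, a clamp just below the fp16 maximum.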
hidden_states = fc2_(hidden_states); + hidden_states = residual + hidden_states; + + if (hidden_states.dtype() == kFloat16) { + const float clamp_value = 65504.0f - 1000.0f; + hidden_states = nn::functional::clip(hidden_states, -clamp_value, clamp_value); + } + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioEncoder final : public nn::Module { + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::GELU gelu_; + nn::ModuleList layers_; + nn::LayerNorm ln_post_; + nn::AvgPool1d avg_pooler_; + nn::Linear proj_; + nn::Embedding audio_bos_eos_token_; + + int32_t num_mel_bins_ = 0; + int32_t embed_dim_ = 0; + int32_t n_window_ = 0; + int32_t output_dim_ = 0; + + public: + Qwen2_5OmniAudioEncoder() = default; + + explicit Qwen2_5OmniAudioEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + num_mel_bins_ = cfg.audio_num_mel_bins; + embed_dim_ = cfg.audio_d_model; + n_window_ = cfg.audio_n_window; + output_dim_ = cfg.audio_output_dim; + + conv1_ = reg("conv1", num_mel_bins_, embed_dim_, 3, 1, 1); + conv2_ = reg("conv2", embed_dim_, embed_dim_, 3, 2, 1); + gelu_ = reg("gelu"); + audio_bos_eos_token_ = reg("audio_bos_eos_token", 2, cfg.audio_output_dim); + layers_ = reg>("layers", cfg.audio_encoder_layers, cfg); + ln_post_ = reg("ln_post", std::vector{embed_dim_}, true, true, 1e-5); + avg_pooler_ = reg("avg_pooler", 2, 2); + proj_ = reg("proj", embed_dim_, cfg.audio_output_dim, true); + + auto pos_emb = makeAudioSinusoidalPosEmb(cfg.audio_max_source_positions, embed_dim_); + registerBuffer("positional_embedding", pos_emb); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto input_features = inputs[0]; // [B, n_mels, T] + MLLM_RT_ASSERT_EQ(input_features.shape().size(), 3); + + const int32_t batch_size = input_features.shape()[0]; + MLLM_RT_ASSERT_EQ(input_features.shape()[1], num_mel_bins_); + const int32_t feature_len = input_features.shape()[2]; + MLLM_RT_ASSERT(feature_len > 0); + + auto pos_emb = getBuffer("positional_embedding"); + + std::vector audio_outputs; + audio_outputs.reserve(batch_size); + + for (int32_t b = 0; b < batch_size; ++b) { + Tensor audio_b = input_features[make_slice(b), kAll, kAll].view({1, num_mel_bins_, feature_len}).contiguous(); + + const int32_t chunk_size = n_window_ * 2; + const int32_t num_chunks = (feature_len + chunk_size - 1) / chunk_size; + + std::vector chunk_outputs; + chunk_outputs.reserve(num_chunks); + + for (int32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + const int32_t start = chunk_idx * chunk_size; + const int32_t chunk_len = std::min(chunk_size, feature_len - start); + auto chunk = Tensor::empty({1, num_mel_bins_, chunk_len}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < num_mel_bins_; ++m) { + auto src_ptr = audio_b.offsettedPtr({0, m, start}); + auto dst_ptr = chunk.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, chunk_len * sizeof(float)); + } + + auto x = conv1_(chunk); + x = gelu_(x); + x = conv2_(x); + x = gelu_(x); + x = x.transpose(1, 2).contiguous(); // [1, T2, D] + + const int32_t t2 = x.shape()[1]; + MLLM_RT_ASSERT(t2 <= pos_emb.shape()[0]); + auto pos_ptr = pos_emb.ptr(); + auto x_ptr = x.ptr(); + for (int32_t t = 0; t < t2; ++t) { + const float* pos_row = pos_ptr + t * embed_dim_; + float* x_row = x_ptr + t * embed_dim_; + for (int32_t d = 0; d < embed_dim_; ++d) { x_row[d] += pos_row[d]; } + } + + auto hidden_states = x.squeeze(0); // [T2, D] + for (auto& layer : layers_.list()) { hidden_states = layer(hidden_states)[0]; } + if 
(hidden_states.shape()[0] < 2) { continue; } + + auto pooled = hidden_states.unsqueeze(0).transpose(1, 2); // [1, D, T] + pooled = avg_pooler_(pooled); + pooled = pooled.transpose(1, 2).squeeze(0); // [T', D] + pooled = ln_post_(pooled); + pooled = proj_(pooled); + chunk_outputs.push_back(pooled); + } + + int32_t total_len = 0; + for (const auto& chunk : chunk_outputs) { total_len += chunk.shape()[0]; } + + auto merged = Tensor::empty({total_len, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& chunk : chunk_outputs) { + const int32_t len = chunk.shape()[0]; + const float* src_ptr = chunk.ptr(); + float* dst_ptr = merged.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + audio_outputs.push_back(merged); + } + + int32_t total_audio_tokens = 0; + for (const auto& out : audio_outputs) { total_audio_tokens += out.shape()[0]; } + + auto output = Tensor::empty({total_audio_tokens, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& out : audio_outputs) { + const int32_t len = out.shape()[0]; + const float* src_ptr = out.ptr(); + float* dst_ptr = output.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + return {output}; + } +}; + +class Qwen2_5OmniMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniMLP() = default; + Qwen2_5OmniMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniAttention() = default; + + Qwen2_5OmniAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false, cfg.linear_impl_type); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = 
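// [Editor's note] With the 7B defaults, mrope_section = {16, 24, 24} splits
// the 64 rotary frequency pairs of each 128-dim head across the temporal /
// height / width position streams; makeMultimodalPositionEmbedding doubles
// the section so the same split covers both the sin and cos halves.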
cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_; +}; + +class Qwen2_5OmniDecoder final : public nn::Module { + public: + Qwen2_5OmniAttention self_attn_; + Qwen2_5OmniMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniDecoder() = default; + + Qwen2_5OmniDecoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniText final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniText() = default; + + Qwen2_5OmniText(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + + auto inv = makeMultimodalRoPEInvFreq(cfg.hidden_size / 
cfg.num_attention_heads, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +class Qwen2_5OmniThinker final : public nn::Module { + public: + Qwen2_5OmniThinker() = default; + Qwen2_5OmniThinker(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + audio_tower_ = reg("audio_tower", cfg); + visual_ = reg("visual", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + Qwen2_5OmniText model_; + Qwen2_5OmniAudioEncoder audio_tower_; + Qwen2_5OmniVisionEncoder visual_; + nn::Linear lm_head_; +}; + +class Qwen2_5OmniForCausalLM : public ARGeneration { + public: + explicit Qwen2_5OmniForCausalLM(const Qwen2_5OmniConfig& cfg) : cfg_(cfg), thinker_("thinker", cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, + cfg.num_key_value_heads, + cfg.hidden_size / cfg.num_attention_heads, + kFloat32, + kFloat32, + kCPU, + false); + eos_token_id_ = cfg.eos_token_id; + max_length_ = cfg.max_cache_length; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + auto input_embeddings = thinker_.model_.embedding_(sequence); + + if (input.count("input_features")) { + auto input_features = input.at("input_features"); + auto audio_embeddings = thinker_.audio_tower_(input_features)[0]; + MLLM_RT_ASSERT_EQ(audio_embeddings.shape()[1], input_embeddings.shape()[2]); + if (audio_embeddings.dtype() != input_embeddings.dtype()) { + audio_embeddings = audio_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector audio_positions; + audio_positions.reserve(audio_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.audio_token_id) { audio_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(audio_positions.size()), audio_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni audio input."); + } + } + + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + + auto inv_freq = makeVisualRoPEInvFreq(cfg_.visual_hidden_size / cfg_.visual_num_heads, 10000.0f); + auto pos_ids = makeVisualRotaryPosEmbIds(grid_thw, cfg_.visual_spatial_merge_size); + + int 
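// [Editor's note] max_grid is the largest spatial extent (in patches)
// across all images in grid_thw; it bounds the row/column position ids, so
// a rotary table of that length is sufficient for makeVisualRotaryPosEmb.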
max_grid = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + max_grid = std::max({max_grid, dims[1], dims[2]}); + } + MLLM_RT_ASSERT(max_grid > 0); + auto rotary_pos_emb_full = makeVisualRotaryPosEmbFull(inv_freq, max_grid); + auto pos_emb = makeVisualRotaryPosEmb(rotary_pos_emb_full, pos_ids, grid_thw); + auto [visual_embedding_sin, visual_embedding_cos] = makeVisualRotarySinCos(pos_emb); + + auto visual_embeddings = thinker_.visual_(img, visual_embedding_sin, visual_embedding_cos, grid_thw)[0]; + MLLM_RT_ASSERT_EQ(visual_embeddings.shape()[1], input_embeddings.shape()[2]); + if (visual_embeddings.dtype() != input_embeddings.dtype()) { + visual_embeddings = visual_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector image_positions; + image_positions.reserve(visual_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.image_token_id) { image_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), visual_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni image input."); + } + } + + Tensor position_ids = input.count("position_ids") ? input.at("position_ids") : Tensor::nil(); + Tensor img = input.count("img") ? input.at("img") : Tensor::nil(); + Tensor grid_thw = input.count("grid_thw") ? 
input.at("grid_thw") : Tensor::nil(); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = + makeMultimodalPositionEmbedding(position_ids, thinker_.model_.getBuffer("inv_freq"), cfg_.max_position_embeddings, + cfg_.hidden_size / cfg_.num_attention_heads, cfg_.mrope_section); + + auto hidden_states = thinker_.model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1}, kAll}]; + auto logits = thinker_.lm_head_(last_hidden); + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + }; + } + + Qwen2_5OmniThinker thinker_; + + private: + Tensor getPositionIds(Tensor& img, Tensor& grid_thw, Tensor& input_ids, Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(Tensor& input_ids, Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? 
next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + MLLM_RT_ASSERT(image_idx < total_images); + if (image_grid_thw.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing grid_thw for Qwen2.5-Omni vision input."); + } + MLLM_RT_ASSERT_EQ(image_grid_thw.shape().size(), 2); + + std::vector image_positions; + for (int i = segment_start + 1; i < vision_end; ++i) { + if (input_ids_ptr[i] == cfg_.image_token_id) { + image_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside vision segment."); + } + } + + const int* grid_dims = image_grid_thw.offsettedPtr({image_idx, 0}); + const int grid_t = grid_dims[0]; + const int grid_h = grid_dims[1]; + const int grid_w = grid_dims[2]; + + const int image_token_len = (grid_t * grid_h * grid_w) + / (cfg_.visual_spatial_merge_size * cfg_.visual_spatial_merge_size); + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), image_token_len); + + const int inputs_t = grid_t; + const int inputs_h = grid_h / cfg_.visual_spatial_merge_size; + const int inputs_w = grid_w / cfg_.visual_spatial_merge_size; + + const int64_t vision_start_id = current_max_position_id + 1; + int pos_counter = 0; + for (int ti = 0; ti < inputs_t; ++ti) { + const int64_t t_id = vision_start_id + static_cast(ti) * cfg_.position_id_per_seconds; + for (int hi = 0; hi < inputs_h; ++hi) { + for (int wi = 0; wi < inputs_w; ++wi) { + const auto seq_pos = image_positions[pos_counter++]; + *position_ids.offsettedPtr({0, 0, seq_pos}) = t_id; + *position_ids.offsettedPtr({1, 0, seq_pos}) = vision_start_id + hi; + *position_ids.offsettedPtr({2, 0, seq_pos}) = vision_start_id + wi; + } + } + } + + const int64_t dim_0_tail = vision_start_id + static_cast(inputs_t - 1) * cfg_.position_id_per_seconds; + const int64_t dim_1_tail = vision_start_id + inputs_h - 1; + const int64_t dim_2_tail = vision_start_id + inputs_w - 1; + current_max_position_id = std::max({dim_0_tail, dim_1_tail, dim_2_tail}); + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + + const int audio_len = static_cast(audio_positions.size()); + if (audio_len == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); + } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + 
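// Audio tokens get the same position id on all three M-RoPE axes (temporal, height, width),
+ // i.e. they advance the position exactly like ordinary text tokens.
+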
*position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = audio_end + 1; + } + } + + MLLM_RT_ASSERT_EQ(image_idx, total_images); + return position_ids; + } + + const Qwen2_5OmniConfig& cfg_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp new file mode 100644 index 000000000..961b5c8f2 --- /dev/null +++ b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp @@ -0,0 +1,385 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen2vl/image_preprocessor_qwen2vl.hpp" +#include "mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::models::qwen2_5omni { + +// same regex as Qwen2/Qwen2-VL tokenizers +inline bool qwen2_5OmniTokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + pos = original_pos; + } + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos 
= start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen2_5OmniRegex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen2_5OmniTokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen2_5OmniMessage { + std::string prompt; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniVisionMessage { + std::string prompt; + std::string img_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniAudioMessage { + std::string prompt; + std::string audio_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen2_5OmniTokenizer(const std::string& file_path, + int32_t spatial_merge_size = 2, + int32_t min_pixels = 56 * 56, + int32_t max_pixels = 1280 * 1280, + int32_t audio_sample_rate = 16000, + int32_t audio_n_mels = 128, + int32_t audio_hop_length = 160, + int32_t audio_chunk_length = 300) + //interestingly, the answer went bad when setting max_pixels higher, eg. 
3584*3584) + : image_preprocessor_(min_pixels, max_pixels), + audio_preprocessor_(audio_sample_rate, audio_n_mels, audio_hop_length, audio_chunk_length), + spatial_merge_size_(spatial_merge_size) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_bos|>"); + special_tokens_trie_.add(L"<|vision_eos|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L"<|AUDIO|>"); + special_tokens_trie_.add(L"<|audio_bos|>"); + special_tokens_trie_.add(L"<|audio_eos|>"); + special_tokens_trie_.add(L"<|IMAGE|>"); + special_tokens_trie_.add(L"<|VIDEO|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen2_5omni::qwen2_5OmniRegex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen2_5OmniMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + 
.setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return {{"sequence", sequence}}; + } + + ARGenerationOutputPast convertVisionMessage(const Qwen2_5OmniVisionMessage& message) { + auto applied_string = message.buildChatMessage(); + + auto [img, grid_thw] = image_preprocessor_(message.img_file_path); + + auto sequence_str = tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto grid_t = grid_thw.ptr()[0]; + auto grid_h = grid_thw.ptr()[1]; + auto grid_w = grid_thw.ptr()[2]; + int32_t img_token_nums = grid_t * grid_h * grid_w; + img_token_nums /= (spatial_merge_size_ * spatial_merge_size_); + + auto image_token_id = bpe_._lookup_vocab(L"<|IMAGE|>"); + { + auto it = std::find(ids.begin(), ids.end(), image_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|IMAGE|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, img_token_nums - 1, image_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + {"img", img}, + {"grid_thw", grid_thw}, + }; + } + + ARGenerationOutputPast convertAudioMessage(const Qwen2_5OmniAudioMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto audio_result = audio_preprocessor_.processAudioFile(message.audio_file_path); + if (audio_result.input_features.isNil() || audio_result.feature_length <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to extract audio features for Qwen2.5-Omni."); + } + + int32_t audio_token_nums = audio_preprocessor_.calcAudioTokenLength(audio_result.feature_length); + if (audio_token_nums <= 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Invalid audio token length for Qwen2.5-Omni."); + } + + auto audio_token_id = bpe_._lookup_vocab(L"<|AUDIO|>"); + { + auto it = std::find(ids.begin(), ids.end(), audio_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|AUDIO|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, audio_token_nums - 1, audio_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + audio_result.input_features.setName("input_features"); + + return { + {"sequence", sequence}, + {"input_features", audio_result.input_features}, + }; + } + + private: + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; + mllm::models::qwen2vl::Qwen2VLImagePreprocessor image_preprocessor_; + Qwen2_5OmniAudioPreprocessor audio_preprocessor_; + int32_t spatial_merge_size_ = 2; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 66240b714..1bd31cd6c 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -12,12 +12,27 @@ from . 
import service from . import backends from .ffi import ( + # Floating point types float32, float16, bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices cpu, cuda, qnn, + # Tensor and utilities Tensor, empty, echo, @@ -26,7 +41,6 @@ is_numpy_available, from_torch, from_numpy, - empty, zeros, ones, arange, diff --git a/pymllm/backends/qualcomm/transformers/core/embedding.py b/pymllm/backends/qualcomm/transformers/core/embedding.py new file mode 100644 index 000000000..84c4d61fe --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/embedding.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +from torch.ao.quantization import FakeQuantize, MinMaxObserver + + +class QEmbedding(nn.Module): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + quant_bits=16, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.quant_bits = quant_bits + + self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim)) + nn.init.normal_(self.weight) + + if padding_idx is not None: + with torch.no_grad(): + self.weight[padding_idx].fill_(0) + + # Quantization configuration for Weight + self.weight_fake_quant = FakeQuantize( + observer=MinMaxObserver.with_args( + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, + ), + quant_min=0, + quant_max=2 ** (quant_bits) - 1, + dtype=torch.qint32, + qscheme=torch.per_tensor_affine, + ) + + def forward(self, x): + # 1. Weight fake quantization + # If observer is not closed, this step will continuously update scale/zp + # If freeze_weight() is called, this will just use fixed scale/zp for quantization + w_q = self.weight_fake_quant(self.weight) + + # 2. Embedding lookup (Gather operation) + return nn.functional.embedding( + x, + w_q, + padding_idx=self.padding_idx, + ) + + @torch.no_grad() + def convert_to_deploy(self): + """ + In-place replacement of self.weight: + Float Parameter -> Int Buffer + """ + # 1. Ensure quantization parameters are ready + if self.weight_fake_quant.scale is None: + self.freeze_weight() + + scale = self.weight_fake_quant.scale + zero_point = self.weight_fake_quant.zero_point + quant_min = self.weight_fake_quant.quant_min + quant_max = self.weight_fake_quant.quant_max + + # 2. Calculate integer values + # w_int = round(w / s + zp) + w_int = torch.round(self.weight / scale + zero_point).clamp( + quant_min, quant_max + ) + + # 3. Set target integer type + if self.quant_bits <= 8: + target_dtype = torch.uint8 + elif self.quant_bits <= 16: + target_dtype = torch.uint16 + else: + target_dtype = torch.uint32 + + w_int = w_int.to(target_dtype) + + # === Key steps: Replacement operations === + + # A. Delete original Parameter 'weight' + # Must delete first, otherwise cannot register buffer with same name + del self.weight + + # B. Register Buffer with same name 'weight' + # This makes state_dict['weight'] become Int Tensor + self.register_buffer("weight", w_int) + + # C. Register Scale (usually needed by engine) + self.register_buffer("scale", scale) + self.register_buffer("zero_point", zero_point) + + # D. Clean up unnecessary modules + if hasattr(self, "weight_fake_quant"): + del self.weight_fake_quant + + class_name = self.__class__.__name__ + instance_class_name = type(self).__name__ + print( + f"Class: {class_name}, Instance: {instance_class_name}, Deploy Mode Activated. 
'weight' is now {self.weight.dtype} buffer. zp is {zero_point}" + ) + + @torch.no_grad() + def freeze_weight(self): + """ + Manually trigger Observer to observe and calculate scale, then lock it. + Solve the problem of output being 0 on first run. + """ + self.weight_fake_quant.activation_post_process(self.weight) + s, zp = self.weight_fake_quant.activation_post_process.calculate_qparams() + self.weight_fake_quant.scale.copy_(s) + self.weight_fake_quant.zero_point.copy_(zp) + self.weight_fake_quant.disable_observer() + class_name = self.__class__.__name__ + instance_class_name = type(self).__name__ + print( + f"Class: {class_name}, Instance: {instance_class_name}, Weight Quantized: scale={self.weight_fake_quant.scale}, zp={self.weight_fake_quant.zero_point}" + ) + + def disable_quant(self): + """Completely turn off quantization noise and return to floating point mode""" + self.weight_fake_quant.disable_fakequant() + + def extra_repr(self): + s = f"{self.num_embeddings}, {self.embedding_dim}" + if self.padding_idx is not None: + s += f", padding_idx={self.padding_idx}" + return s diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/backends/qualcomm/transformers/core/observer.py new file mode 100644 index 000000000..67a946b10 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/observer.py @@ -0,0 +1,56 @@ +import torch +from torchao.quantization.pt2e import UniformQuantizationObserverBase + + +class ConcatObserver(UniformQuantizationObserverBase): + """ + Fetch maximum data range of all tensors to be concatenated + """ + + def __init__( + self, + dtype=torch.uint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + # get concat node and its inputs + self.input_observers = [] + + def add_observer(self, observer): + self.input_observers.append(observer) + + def forward(self, x_orig): + # calculate the min / max first + self.min_val = min(self.min_val, x_orig.min()) + self.max_val = max(self.max_val, x_orig.max()) + + # update min / max for all observers of input nodes + for observers in self.input_observers: + observers.min_val = self.min_val + observers.max_val = self.max_val + + return x_orig + + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index ce67729f4..c13011a51 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -1,6 +1,13 @@ import torch import torch.nn as nn -from torch.ao.quantization import FakeQuantize, MinMaxObserver +from torch.ao.quantization import ( + FakeQuantize, + MinMaxObserver, +) +from torch.ao.quantization.observer import FixedQParamsObserver + +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 class ActivationQDQ(nn.Module): @@ -30,16 +37,24 @@ def __init__(self, bits=8, 
qscheme=torch.per_tensor_affine): self.quant_min = 0 self.quant_max = (2**bits) - 1 + if bits == 8: + eps = DEFAULT_EPS_8BIT + elif bits == 16: + eps = DEFAULT_EPS_16BIT + else: + raise ValueError(f"Unsupported bit width: {bits}") + # 2. Initialize FakeQuantize - # MinMaxObserver calculates scale and zero_point based on observed tensors. + # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=self.qscheme, dtype=self.dtype, + qscheme=self.qscheme, quant_min=self.quant_min, quant_max=self.quant_max, reduce_range=False, + eps=eps, ), quant_min=self.quant_min, quant_max=self.quant_max, @@ -63,12 +78,106 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" return f"bits={self.bits}, mode={mode}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" + + +class FixedActivationQDQ(nn.Module): + """ + Fixed activation Quantization-DeQuantization (QDQ) module. + Uses pre-determined scale and zero_point instead of dynamic observation. + Supports both Symmetric and Asymmetric (Affine) quantization. + Uses torch.qint32 as a unified type to support various bit-widths. + """ + + def __init__(self, scale, zero_point, bits=8, qscheme=torch.per_tensor_affine): + super().__init__() + self.bits = bits + self.qscheme = qscheme + + # Define the simulation dtype as qint32 to avoid overflow across different bit-widths + self.dtype = torch.qint32 + + # 1. Calculate quantization range based on bits and scheme + if qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]: + # Symmetric: range is [-(2^(bits-1)), 2^(bits-1) - 1] + # e.g., 8-bit: -128 to 127 + self.quant_min = -(2 ** (bits - 1)) + self.quant_max = 2 ** (bits - 1) - 1 + else: + # Asymmetric (Affine): range is [0, 2^bits - 1] + # e.g., 8-bit: 0 to 255 + self.quant_min = 0 + self.quant_max = (2**bits) - 1 + + if bits not in [8, 16]: + raise ValueError(f"Unsupported bit width: {bits}") + + # 2. Convert scale and zero_point to tensors if needed + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float32) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.int32) + + # 3. Initialize FakeQuantize with fixed parameters + # Use FakeQuantize with FixedQParamsObserver for fixed scale and zero_point + self.fake_quant = FakeQuantize.with_args( + observer=FixedQParamsObserver.with_args( + scale=scale, + zero_point=zero_point, + ), + dtype=self.dtype, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + )() + + def forward(self, x): + # Applies fake quantization with fixed scale and zero_point: + # rounds to nearest integer and clamps to [min, max], + # then dequantizes back to float to simulate quantization noise. 
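+ # Concretely: q = clamp(round(x / scale) + zero_point, quant_min, quant_max),
+ # then x_hat = (q - zero_point) * scale. For 16-bit affine parameters such as
+ # scale = 1/65536 and zero_point = 0, an input of 0.5 maps to q = 32768 and
+ # dequantizes back to exactly 0.5.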
+ return self.fake_quant(x) + + # Control methods for quantization-aware training (QAT) + # Note: FixedActivationQDQ doesn't have observer, so these methods + # only control fake quantization behavior + def enable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def disable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def enable_fakequant(self): + """Enable simulation of quantization error.""" + self.fake_quant.enable_fake_quant() + + def disable_fakequant(self): + """Disable quantization simulation (act as identity).""" + self.fake_quant.disable_fake_quant() + + @property + def scale(self): + """Get the fixed scale value.""" + return self.fake_quant.scale + + @property + def zero_point(self): + """Get the fixed zero_point value.""" + return self.fake_quant.zero_point + + def extra_repr(self): + mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" + scale_val = self.scale.item() if self.scale.numel() == 1 else self.scale + zp_val = ( + self.zero_point.item() if self.zero_point.numel() == 1 else self.zero_point + ) + return f"bits={self.bits}, mode={mode}, scale={scale_val}, zero_point={zp_val}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d9c55e759..255f52ffb 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -296,7 +296,9 @@ def convert_to_conv2d_deploy_hwio(self): s1_permuted = ( s1.view(self.out_features, -1).t().contiguous() ) # [Out, Blocks] -> [Blocks, Out] - s1_hwio = s1_permuted.view(1, 1, -1, self.out_features) # Shape: [1, 1, Blocks, Out] + s1_hwio = s1_permuted.view( + 1, 1, -1, self.out_features + ) # Shape: [1, 1, Blocks, Out] del self.weight self.register_buffer("weight", w_hwio) diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/backends/qualcomm/transformers/core/rms_norm.py index 0101d6aee..b3964469f 100644 --- a/pymllm/backends/qualcomm/transformers/core/rms_norm.py +++ b/pymllm/backends/qualcomm/transformers/core/rms_norm.py @@ -21,7 +21,9 @@ def __init__( # Quantization configuration for Weight self.weight_fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=torch.per_tensor_affine, dtype=torch.qint32 + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, ), quant_min=0, quant_max=2 ** (quant_bits) - 1, diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 9c0696328..cf71a48ba 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -49,9 +49,13 @@ from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, - QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): @@ -76,7 +80,12 @@ def __init__(self, config): self.gate_proj_output_qdq = ActivationQDQ(bits=16) self.act_output_qdq = 
ActivationQDQ(bits=16) self.down_proj_input_qdq = ActivationQDQ(bits=16) - self.sigmoid_output_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) def forward(self, x): x = self.up_proj_input_qdq(x) @@ -93,11 +102,13 @@ def forward(self, x): return o -def rotate_half(x): +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -207,6 +218,39 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_norm_output_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_norm_output_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + # In qnn, is uint8 sym. 
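+ # i.e. QNN keeps this tensor in an 8-bit symmetric encoding, so a matching
+ # symmetric 8-bit QDQ is applied to the key states below (per_tensor_symmetric
+ # fixes the zero point so the quantized range stays centred on zero).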
self.k_cast_to_int8_qdq = ActivationQDQ( bits=8, qscheme=torch.per_tensor_symmetric @@ -224,6 +268,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.minus_0_output_qdq = ActivationQDQ(bits=16) self.softmax_output_qdq = ActivationQDQ(bits=16) self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -256,11 +301,27 @@ def forward( sin = sin.unsqueeze(1) query_states = self.q_rope_add_0_output_qdq( self.q_rope_mul_0_output_qdq(query_states * cos) - + self.q_rope_mul_1_output_qdq(rotate_half(query_states) * sin) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_norm_output_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_rope_add_0_output_qdq( self.k_rope_mul_0_output_qdq(key_states * cos) - + self.k_rope_mul_1_output_qdq(rotate_half(key_states) * sin) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_norm_output_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_cast_to_int8_qdq(key_states) @@ -281,7 +342,7 @@ def forward( torch.matmul(query_states, key_states.transpose(2, 3)) ) * self.scaling_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) * self.scaling ) ) @@ -292,10 +353,13 @@ def forward( attn_vv = self.minus_0_output_qdq( attn_min + self.neg_20_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) * (-20) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) ) ) - attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) attn_weights = self.softmax_output_qdq( nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( @@ -315,6 +379,7 @@ def forward( class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() + self.layer_dix = layer_idx self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) @@ -329,7 +394,8 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.attention_type = config.layer_types[layer_idx] # QDQ - self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + if self.layer_dix != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) self.add_0_output_qdq = ActivationQDQ(bits=16) self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) @@ -348,7 +414,8 @@ def forward( ] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: - hidden_states = self.input_layernorm_input_qdq(hidden_states) + if self.layer_dix != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention @@ -362,6 +429,7 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -448,9 +516,8 @@ def __init__(self, config: Qwen3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = 
nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 ) self.layers = nn.ModuleList( [ @@ -567,6 +634,12 @@ def forward( self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( hidden_states, max_position_ids ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( self.mllm_max_cos_embedding ) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 53ab40a9e..416816875 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -2,12 +2,16 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding from pymllm.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM @@ -31,11 +35,23 @@ def enable_qdq_observer(m): m.enable_observer() +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + + def convert_weight(m): if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): m.convert_to_conv2d_deploy_hwio() if isinstance(m, QRMSNorm): m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() class Qwen3Quantizer: @@ -44,6 +60,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", + dtype=torch.float32, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length @@ -60,6 +77,12 @@ def freeze_activation(self): def enable_activation_update(self): self.model.apply(enable_qdq_observer) + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + def compile(self): print("Compile Start.") self.model = torch.compile( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 13ad2785a..9c4604d8f 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -37,16 +37,24 @@ def main(): args = parser.parse_args() m = Qwen3Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) - # m.compile() + m.enable_fake_quant() m.infer(args.infer_text) # !!! # Things below is for deploy. We will turn all fp32 weights and some buffers(rope) to quantized dtype. # !!! 
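+ # (convert_weight() in runner.py folds QLinearLPBQ / QLinearW8A16_PerChannelSym
+ # into conv2d HWIO deploy buffers and calls convert_to_deploy() on QRMSNorm and
+ # QEmbedding; QEmbedding swaps its float weight Parameter for an integer buffer
+ # plus the frozen scale / zero_point.)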
- m.model.lm_head.weight = torch.nn.Parameter( - m.model.model.embed_tokens.weight.clone() - ) + # This line maybe error. we need use quantized weight!!! not embed_tokens.weight!!! + # m.model.lm_head.weight = torch.nn.Parameter( + # m.model.model.embed_tokens.weight.clone() + # ) + if "1.7B" in args.model_path: + raise ValueError( + "1.7B model is not supported for now due to tied embedding weights is not supported." + ) m.convert() os.makedirs(args.output_dir, exist_ok=True) diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/convertor/model_file_v2.py index 302e3e21b..976c04411 100644 --- a/pymllm/convertor/model_file_v2.py +++ b/pymllm/convertor/model_file_v2.py @@ -24,6 +24,14 @@ MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 +def _torch_tensor_bytes(tensor: "torch.Tensor") -> bytes: + # Use uint8 view to preserve raw bytes for dtypes not supported by numpy. + t = tensor.detach().cpu().contiguous() + if t.dim() == 0: + t = t.reshape(1) + return t.view(torch.uint8).numpy().tobytes() + + class ModelFileV2Descriptor: SIZE = 532 @@ -132,7 +140,7 @@ def streaming_write(self, tensor_name, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor_obj, torch.Tensor): # PyTorch tensor shape = list(tensor_obj.shape) - tensor_data = tensor_obj.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor_obj) true_dtype = MLLM_TYPE_MAPPING[tensor_obj.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor_obj, np.ndarray): # Numpy array @@ -203,7 +211,7 @@ def static_write(self, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor, torch.Tensor): # PyTorch tensor shape = list(tensor.shape) - tensor_data = tensor.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor) true_dtype = MLLM_TYPE_MAPPING[tensor.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor, np.ndarray): # Numpy array diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index 17bd04c19..9780eabb0 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -48,6 +48,10 @@ def to_pod(self) -> int: return tvm_ffi.get_global_func("mllm.DType.to_pod")(self) +# ============================================================================= +# DType factory functions +# ============================================================================= +# Floating point types def float32_() -> DType: return _ffi_api.float32_() @@ -60,6 +64,45 @@ def bfloat16_() -> DType: return _ffi_api.bfloat16_() +# Signed integer types +def int8_() -> DType: + return _ffi_api.int8_() + + +def int16_() -> DType: + return _ffi_api.int16_() + + +def int32_() -> DType: + return _ffi_api.int32_() + + +def int64_() -> DType: + return _ffi_api.int64_() + + +# Unsigned integer types +def uint8_() -> DType: + return _ffi_api.uint8_() + + +def uint16_() -> DType: + return _ffi_api.uint16_() + + +def uint32_() -> DType: + return _ffi_api.uint32_() + + +def uint64_() -> DType: + return _ffi_api.uint64_() + + +# Bool type (backed by uint8) +def bool_() -> DType: + return _ffi_api.bool_() + + def cpu_() -> Device: return _ffi_api.cpu_() @@ -219,10 +262,32 @@ def is_contiguous(self): return tvm_ffi.get_global_func("mllm.Tensor.is_contiguous")(self) -# Global dtypes +# ============================================================================= +# Global dtype instances +# ============================================================================= +# Floating point types float32: DType = float32_() float16: DType = float16_() bfloat16: DType = bfloat16_() + +# Signed integer types +int8: 
DType = int8_() +int16: DType = int16_() +int32: DType = int32_() +int64: DType = int64_() + +# Unsigned integer types +uint8: DType = uint8_() +uint16: DType = uint16_() +uint32: DType = uint32_() +uint64: DType = uint64_() + +# Bool type (use 'boolean' to avoid shadowing Python's built-in 'bool') +boolean: DType = bool_() + +# ============================================================================= +# Global device instances +# ============================================================================= cpu: Device = cpu_() cuda: Device = cuda_() qnn: Device = qnn_() diff --git a/tasks/build_sdk_android_qnn_aot.yaml b/tasks/build_sdk_android_qnn_aot.yaml new file mode 100644 index 000000000..f0e983b75 --- /dev/null +++ b/tasks/build_sdk_android_qnn_aot.yaml @@ -0,0 +1,22 @@ +Tasks: + - CMakeConfigTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" + cmake_build_type: "ReleaseWithDebInfo" + cmake_toolchain_file: "$ANDROID_NDK_PATH/build/cmake/android.toolchain.cmake" + cmake_extra_args: + - "-DMLLM_CROSS_COMPILE=ON" + - "-DMLLM_BUILD_ARM_BACKEND=ON" + - "-DMLLM_BUILD_QNN_BACKEND=ON" + - "-DANDROID_PLATFORM=android-28" + - "-DANDROID_ABI=arm64-v8a" + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=armv8.2-a+fp16+fp16fml+dotprod+i8mm;-ffast-math;-Wno-nan-infinity-disabled"' + - "-DCMAKE_INSTALL_PREFIX=mllm-install-android-arm64-v8a-qnn" + - "-DMLLM_KERNEL_USE_THREADS=ON" + - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" + + - CMakeBuildTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" + + - CMakeInstallTask: + cmake_cfg_path: "build-android-arm64-v8a-qnn" diff --git a/tasks/build_sdk_x86_qnn_aot.yaml b/tasks/build_sdk_x86_qnn_aot.yaml index f33281616..fd9131d2e 100644 --- a/tasks/build_sdk_x86_qnn_aot.yaml +++ b/tasks/build_sdk_x86_qnn_aot.yaml @@ -1,7 +1,7 @@ Tasks: - CMakeConfigTask: cmake_cfg_path: "build-qnn-aot" - cmake_build_type: "Release" + cmake_build_type: "ReleaseWithDebInfo" cmake_extra_args: # Optional, If use Highway - "-DHWY_ENABLE_TESTS=OFF"