diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
index 14241684a..c24d41cd0 100644
--- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
+++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
@@ -336,10 +336,14 @@ class Qwen3Text final : public nn::Module {
     // Quantization
     x = x.to(kUInt16PerTensorAsy);
-    auto position_ids = inputs[1];
+    const auto& position_ids = inputs[1];
     auto causal_mask = inputs[2];
-    auto llm_embedding_sin = ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq")[{{0}, position_ids, {kAll}}];
-    auto llm_embedding_cos = ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq")[{{0}, position_ids, {kAll}}];
+
+    auto llm_embedding_sin =
+        nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids);
+
+    auto llm_embedding_cos =
+        nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids);
 
     std::vector<Tensor> keys;
     std::vector<Tensor> values;
diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp
index f8a3d8d1c..0964cba0d 100644
--- a/mllm/backends/cpu/CPUBackend.cpp
+++ b/mllm/backends/cpu/CPUBackend.cpp
@@ -23,6 +23,7 @@
 #include "mllm/backends/cpu/ops/FlashAttention2Op.hpp"
 #include "mllm/backends/cpu/ops/FlashAttn2WithSinkAndSwaOp.hpp"
 #include "mllm/backends/cpu/ops/GELUOp.hpp"
+#include "mllm/backends/cpu/ops/GatherOp.hpp"
 #include "mllm/backends/cpu/ops/InterpolateOp.hpp"
 #include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp"
 #include "mllm/backends/cpu/ops/MaskedScatterOp.hpp"
@@ -81,7 +82,8 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) {
       CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory,
       CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory,
       CPUArgsortOpFactory, CPUCloneOpFactory, CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory,
-      CPURadixAttnRelaxOpFactory, CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory>();
+      CPURadixAttnRelaxOpFactory, CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory,
+      CPUGatherOpFactory>();
 }
 
 CPUBackend::~CPUBackend() {
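Note: the old tensor-index expression and the new `nn::functional::gather(..., 1, position_ids)` call are intended to select the same rows of the QDQ'd sin/cos tables. A standalone sketch of that selection semantics, assuming the table is `[1, max_pos, head_dim]` and `position_ids` is `[1, seq]` (shapes inferred from the RoPE usage above; the function below is illustrative, not mllm code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// out[s, :] = table[position_ids[s], :] -- a gather along the position dim.
std::vector<float> gather_rows(const std::vector<float>& table, int64_t head_dim,
                               const std::vector<int32_t>& position_ids) {
  std::vector<float> out(position_ids.size() * head_dim);
  for (size_t s = 0; s < position_ids.size(); ++s) {
    const float* src = table.data() + (int64_t)position_ids[s] * head_dim;
    std::copy(src, src + head_dim, out.data() + s * head_dim);
  }
  return out;
}
```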
+ +#include "mllm/backends/cpu/ops/GatherOp.hpp" +#include "mllm/core/Tensor.hpp" + +namespace mllm::cpu { + +CPUGatherOp::CPUGatherOp(const aops::GatherOpOptions& options) : aops::GatherOp(options) {} + +void CPUGatherOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& table = inputs[0]; + auto& indices = inputs[1]; + auto& output = outputs[0]; + + int dim = options_.dim; + if (dim < 0) dim += table.shape().size(); + + int64_t outer_size = 1; + for (int i = 0; i < dim; ++i) outer_size *= table.shape()[i]; + + int64_t inner_size = 1; + for (int i = dim + 1; i < table.shape().size(); ++i) inner_size *= table.shape()[i]; + + int64_t dim_size = table.shape()[dim]; + int64_t indices_count = indices.numel(); + + size_t data_type_size = 4; + switch (table.dtype()) { + case MLLM_TYPE_F32: data_type_size = sizeof(float); break; + case MLLM_TYPE_F16: data_type_size = sizeof(mllm_fp16_t); break; + case MLLM_TYPE_I32: data_type_size = sizeof(int32_t); break; + default: MLLM_ERROR("GatherOp table type not supported: {}", (int)table.dtype()); + } + + const uint8_t* table_ptr = table.ptr(); + uint8_t* output_ptr = output.ptr(); + + const int32_t* indices_i32 = indices.dtype() == MLLM_TYPE_I32 ? indices.ptr() : nullptr; + const float* indices_f32 = !indices_i32 && indices.dtype() == MLLM_TYPE_F32 ? indices.ptr() : nullptr; + + if (!indices_i32 && !indices_f32) { + MLLM_ERROR("GatherOp indices type not supported: {}", (int)indices.dtype()); + return; + } + + // FIXME: parallel + for (int64_t o = 0; o < outer_size; ++o) { + for (int64_t i = 0; i < indices_count; ++i) { + int64_t idx = 0; + if (indices_i32) { + idx = indices_i32[i]; + } else if (indices_f32) { + idx = (int64_t)indices_f32[i]; + } + + if (idx < 0) idx += dim_size; + + if (idx < 0 || idx >= dim_size) { continue; } + + int64_t src_offset = (o * dim_size + idx) * inner_size * data_type_size; + int64_t dst_offset = (o * indices_count + i) * inner_size * data_type_size; + + std::memcpy(output_ptr + dst_offset, table_ptr + src_offset, inner_size * data_type_size); + } + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/GatherOp.hpp b/mllm/backends/cpu/ops/GatherOp.hpp new file mode 100644 index 000000000..5f6c2719b --- /dev/null +++ b/mllm/backends/cpu/ops/GatherOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
diff --git a/mllm/backends/cpu/ops/GatherOp.hpp b/mllm/backends/cpu/ops/GatherOp.hpp
new file mode 100644
index 000000000..5f6c2719b
--- /dev/null
+++ b/mllm/backends/cpu/ops/GatherOp.hpp
@@ -0,0 +1,25 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/aops/GatherOp.hpp"
+
+namespace mllm::cpu {
+
+class CPUGatherOp final : public aops::GatherOp {
+ public:
+  explicit CPUGatherOp(const aops::GatherOpOptions& options);
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+};
+
+class CPUGatherOpFactory : public TypedOpFactory<OpTypes::kGather, aops::GatherOpOptions> {
+ public:
+  std::shared_ptr<BaseOp> createOpImpl(const aops::GatherOpOptions& options) override {
+    return std::make_shared<CPUGatherOp>(options);
+  }
+};
+
+} // namespace mllm::cpu
diff --git a/mllm/backends/qnn/QNNUtils.cpp b/mllm/backends/qnn/QNNUtils.cpp
index c31d71222..e6e8580b7 100644
--- a/mllm/backends/qnn/QNNUtils.cpp
+++ b/mllm/backends/qnn/QNNUtils.cpp
@@ -5,6 +5,7 @@
 #include "mllm/core/DataTypes.hpp"
 #include "mllm/core/DeviceTypes.hpp"
 #include "mllm/engine/Context.hpp"
+#include "mllm/mllm.hpp"
 #include "mllm/utils/Common.hpp"
 #include "mllm/utils/Log.hpp"
 #include "mllm/compile/ir/tensor/Value.hpp"
@@ -303,11 +304,11 @@ Qnn_DataType_t mllmDataTypeToQnnDataType(DataTypes dtype) {
   Qnn_DataType_t ret = QNN_DATATYPE_UNDEFINED;
   switch (dtype) {
     case kInt8: {
-      ret = QNN_DATATYPE_INT_8;
+      ret = QNN_DATATYPE_SFIXED_POINT_8;
       break;
     }
     case kInt16: {
-      ret = QNN_DATATYPE_INT_16;
+      ret = QNN_DATATYPE_UFIXED_POINT_16;
       break;
     }
     case kInt32: {
@@ -319,11 +320,11 @@
       break;
     }
     case kUInt8: {
-      ret = QNN_DATATYPE_UINT_8;
+      ret = QNN_DATATYPE_UFIXED_POINT_8;
       break;
     }
     case kUInt16: {
-      ret = QNN_DATATYPE_UINT_16;
+      ret = QNN_DATATYPE_UFIXED_POINT_16;
       break;
     }
     case kUInt32: {
@@ -449,7 +450,8 @@ std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::create(const std::string& na
   // in this case, the tensor may be a placeholder(input/output except for graph IO)
   // it will be allocated to QNN shared buffer via QNNTensorWrapper::alloc() later
   MLLM_RT_ASSERT(!name.empty());
-  if (type != QNN_TENSOR_TYPE_STATIC) { MLLM_RT_ASSERT(tensor.device() == kQNN); }
+  // in the AOT case, all tensors are on CPU (TODO: handle this)
+  // if (type != QNN_TENSOR_TYPE_STATIC) { MLLM_RT_ASSERT(tensor.device() == kQNN); }
 
   Qnn_DataType_t dataType = mllmDataTypeToQnnDataType(tensor.dtype());
 
@@ -467,11 +469,6 @@ std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::createStaticTensor(const std
                                                                        Qnn_QuantizeParams_t quantize) {
   MLLM_RT_ASSERT(!name.empty() && tensor.rank() > 0 && !tensor.isNil());
 
-  // mllm currently support float16/float32/sfixed8(int8) as static tensor (weight) data type
-  // uint8 and int32 is caused by QNNLinear which uses Conv2d
-  MLLM_RT_ASSERT(tensor.dtype() == kFloat16 || tensor.dtype() == kFloat32 || tensor.dtype() == kInt8 || tensor.dtype() == kUInt8
-                 || tensor.dtype() == kInt32);
-
   std::shared_ptr<QNNTensorWrapper> tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, tensor, quantize);
 
   tensorWrapper->isAlloc_ = true;
@@ -618,4 +615,75 @@ void propagateQuantScale(const Tensor& input, Tensor& output) {
   }
 }
 
+void __printQnnTensor(const Qnn_Tensor_t* tensor) {
+  if (tensor == nullptr) {
+    MLLM_ERROR("Tensor is null");
+    return;
+  }
+  if (tensor->version != QNN_TENSOR_VERSION_2) {
+    MLLM_ERROR("Only Qnn_TensorV2_t is supported");
+    return;
+  }
+
+  const Qnn_TensorV2_t& t = tensor->v2;
+
+  std::string tensor_type = "";
+
+  switch (t.type) {
+    case QNN_TENSOR_TYPE_APP_READ: tensor_type = "APP_READ"; break;
+    case QNN_TENSOR_TYPE_APP_WRITE: tensor_type = "APP_WRITE"; break;
+    case QNN_TENSOR_TYPE_NATIVE: tensor_type = "NATIVE"; break;
+    case QNN_TENSOR_TYPE_STATIC: tensor_type = "STATIC"; break;
+    default: tensor_type = "UNKNOWN";
+  }
+
+  std::string dtype_str;
+  switch (t.dataType) {
+    case QNN_DATATYPE_INT_8: dtype_str = "INT_8"; break;
+    case QNN_DATATYPE_INT_16: dtype_str = "INT_16"; break;
+    case QNN_DATATYPE_INT_32: dtype_str = "INT_32"; break;
+    case QNN_DATATYPE_INT_64: dtype_str = "INT_64"; break;
+    case QNN_DATATYPE_UINT_8: dtype_str = "UINT_8"; break;
+    case QNN_DATATYPE_UINT_16: dtype_str = "UINT_16"; break;
+    case QNN_DATATYPE_UINT_32: dtype_str = "UINT_32"; break;
+    case QNN_DATATYPE_UINT_64: dtype_str = "UINT_64"; break;
+    case QNN_DATATYPE_FLOAT_16: dtype_str = "FLOAT_16"; break;
+    case QNN_DATATYPE_FLOAT_32: dtype_str = "FLOAT_32"; break;
+    case QNN_DATATYPE_FLOAT_64: dtype_str = "FLOAT_64"; break;
+    case QNN_DATATYPE_SFIXED_POINT_4: dtype_str = "SFIXED_POINT_4"; break;
+    case QNN_DATATYPE_SFIXED_POINT_8: dtype_str = "SFIXED_POINT_8"; break;
+    case QNN_DATATYPE_SFIXED_POINT_16: dtype_str = "SFIXED_POINT_16"; break;
+    case QNN_DATATYPE_SFIXED_POINT_32: dtype_str = "SFIXED_POINT_32"; break;
+    case QNN_DATATYPE_UFIXED_POINT_4: dtype_str = "UFIXED_POINT_4"; break;
+    case QNN_DATATYPE_UFIXED_POINT_8: dtype_str = "UFIXED_POINT_8"; break;
+    case QNN_DATATYPE_UFIXED_POINT_16: dtype_str = "UFIXED_POINT_16"; break;
+    case QNN_DATATYPE_UFIXED_POINT_32: dtype_str = "UFIXED_POINT_32"; break;
+    case QNN_DATATYPE_BOOL_8: dtype_str = "BOOL_8"; break;
+    case QNN_DATATYPE_STRING: dtype_str = "STRING"; break;
+    default: dtype_str = "UNKNOWN"; break;
+  }
+
+  std::string shape_str = "[";
+  for (uint32_t i = 0; i < t.rank; ++i) {
+    shape_str += std::to_string(t.dimensions[i]);
+    if (i < t.rank - 1) shape_str += ", ";
+  }
+  shape_str += "]";
+
+  std::string quant_str = "None";
+  if (t.quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED) {
+    if (t.quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
+      quant_str = "Scale: " + std::to_string(t.quantizeParams.scaleOffsetEncoding.scale)
+                  + ", Offset: " + std::to_string(t.quantizeParams.scaleOffsetEncoding.offset);
+    } else if (t.quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
+      quant_str = "Axis Scale Offset (Axis: " + std::to_string(t.quantizeParams.axisScaleOffsetEncoding.axis) + ")";
+    } else if (t.quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) {
+      quant_str = "Blockwise Expansion (axis:" + std::to_string(t.quantizeParams.blockwiseExpansion->axis)
+                  + ", numBlocksPerAxis:" + std::to_string(t.quantizeParams.blockwiseExpansion->numBlocksPerAxis) + ")";
+    }
+  }
+
+  MLLM_INFO("Tensor: {}, Type: {}, Shape: {}, Dtype: {}, Quant: {}", t.name, tensor_type, shape_str, dtype_str, quant_str);
+}
+
 } // namespace mllm::qnn
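The datatype remapping above is the substantive change: quantized mllm integer tensors now advertise QNN `*FIXED_POINT*` datatypes, which is what QNN expects when a scale/offset quantization encoding is attached. A minimal sketch of the resulting tensor configuration, assuming a uint16 asymmetrically quantized activation (values illustrative; struct and enum names from QnnTypes.h, the same fields `__printQnnTensor` reads):

```cpp
Qnn_Tensor_t t = QNN_TENSOR_INIT;
t.version = QNN_TENSOR_VERSION_2;
t.v2.dataType = QNN_DATATYPE_UFIXED_POINT_16;  // previously mapped to QNN_DATATYPE_UINT_16
t.v2.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;
t.v2.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
t.v2.quantizeParams.scaleOffsetEncoding = Qnn_ScaleOffset_t{/*scale=*/1.5e-4f, /*offset=*/-32768};
```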
diff --git a/mllm/backends/qnn/QNNUtils.hpp b/mllm/backends/qnn/QNNUtils.hpp
index a57f62982..8b9931e0a 100644
--- a/mllm/backends/qnn/QNNUtils.hpp
+++ b/mllm/backends/qnn/QNNUtils.hpp
@@ -6,7 +6,6 @@
 #include "QnnTypes.h"
 #include "mllm/core/Tensor.hpp"
-#include
 #include
 #include
 #include
@@ -89,6 +88,8 @@ bool freeQnnTensor(Qnn_Tensor_t& tensor);
 
 bool freeQnnTensors(Qnn_Tensor_t*& tensors, uint32_t numTensors);
 
+void __printQnnTensor(const Qnn_Tensor_t* tensor);  // for debug use
+
 inline void __mllmQnnLoggerCallback(const char* fmt, QnnLog_Level_t level, uint64_t times_tamp, va_list argp) {
   const char* level_str = "";
   const char* color_start = "";
@@ -277,9 +278,12 @@ QNNParamScalarWrapper::QNNParamScalarWrapper(const std::string& name, T value) :
   if constexpr (std::is_same_v<T, bool>) {
     qnnParam_.scalarParam.dataType = QNN_DATATYPE_BOOL_8;
     qnnParam_.scalarParam.bool8Value = static_cast<uint8_t>(value);
-  } else if constexpr (std::is_same_v<T, uint32_t> || std::is_same_v<T, int32_t>) {
+  } else if constexpr (std::is_same_v<T, uint32_t>) {
     qnnParam_.scalarParam.dataType = QNN_DATATYPE_UINT_32;
     qnnParam_.scalarParam.uint32Value = static_cast<uint32_t>(value);
+  } else if constexpr (std::is_same_v<T, int32_t>) {
+    qnnParam_.scalarParam.dataType = QNN_DATATYPE_INT_32;
+    qnnParam_.scalarParam.int32Value = static_cast<int32_t>(value);
   } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
     qnnParam_.scalarParam.dataType = QNN_DATATYPE_FLOAT_32;
     qnnParam_.scalarParam.floatValue = static_cast<float>(value);
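Previously `int32_t` values were folded into the `uint32_t` branch and emitted as `QNN_DATATYPE_UINT_32`; the split gives signed scalar params their own `QNN_DATATYPE_INT_32` encoding, which the Gather and Index visitors below rely on for possibly-negative axes. Illustrative dispatch (wrapper API as declared in this header):

```cpp
auto u = QNNParamScalarWrapper::create("axis", (uint32_t)2);  // -> QNN_DATATYPE_UINT_32
auto s = QNNParamScalarWrapper::create("axis", (int32_t)-1);  // -> QNN_DATATYPE_INT_32, sign preserved
```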
diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
index bddb33db6..ce51c419d 100644
--- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
+++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp
@@ -31,6 +31,7 @@ QnnAOTNodeTensor::QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool
   } else {
     tensor_wrapper_ = mllm::qnn::QNNTensorWrapper::create(name, type, v->tensor_, quant);
   }
+  setupComplexTensorQuantization(v);  // per-channel and LPBQ cases
 }
 
 Qnn_TensorType_t QnnAOTNodeTensor::parseQnnTensorTypeFromIR(const ir::tensor::TensorValue::ptr_t& v) {
@@ -90,7 +91,7 @@ Qnn_TensorType_t QnnAOTNodeTensor::parseQnnTensorTypeFromIR(const ir::tensor::Te
 
   // Check Attribute. The Attribute priority is higher than tensor type
   if (v->getAttr("qnn_graph_outputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READ; }
-  if (v->getAttr("qnn_graph_inputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READWRITE; }
+  if (v->getAttr("qnn_graph_inputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE; }
   if (v->getAttr("constant")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_STATIC; }
 
   return ret_qnn_tensor_type;
@@ -109,7 +110,16 @@
   auto quant_spec = v->getAttr("quant_recipe")->cast_()->spec_;
 
   switch (quant_spec->type) {
-    case ir::linalg::QuantizationSpecType::kRaw: {
+    case ir::linalg::QuantizationSpecType::kRaw:
+    case ir::linalg::QuantizationSpecType::kSymPerChannel:
+    case ir::linalg::QuantizationSpecType::kLPBQ: {
+      break;
+    }
+    case ir::linalg::QuantizationSpecType::kAsymPerTensor: {
+      auto cfg = std::static_pointer_cast<ir::linalg::QuantizationSpecAsymPerTensor>(quant_spec);
+      ret.encodingDefinition = QNN_DEFINITION_DEFINED;
+      ret.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
+      ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item<float>(), .offset = cfg->zero_point.item<int32_t>()};
       break;
     }
     case ir::linalg::QuantizationSpecType::kSymPerTensor: {
@@ -119,6 +129,19 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten
       ret.scaleOffsetEncoding = Qnn_ScaleOffset_t{.scale = cfg->scale.item<float>(), .offset = 0};
       break;
     }
+    default: {
+      MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't handle kNone type");
+    }
+  }
+
+  return ret;
+}
+
+void QnnAOTNodeTensor::setupComplexTensorQuantization(const ir::tensor::TensorValue::ptr_t& v) {
+  MLLM_RT_ASSERT(v->getAttr("quant_recipe"));
+  auto quant_spec = v->getAttr("quant_recipe")->cast_()->spec_;
+
+  switch (quant_spec->type) {
     case ir::linalg::QuantizationSpecType::kSymPerChannel: {
       auto cfg = std::static_pointer_cast<ir::linalg::QuantizationSpecSymPerChannel>(quant_spec);
 
@@ -135,12 +158,6 @@ Qnn_QuantizeParams_t QnnAOTNodeTensor::parseQnnQuantizeParamFromIR(const ir::ten
       tensor_wrapper_->setScaleOffsetQuantization(scale_offsets, cfg->ch_axis);
       break;
     }
-    case ir::linalg::QuantizationSpecType::kSymPerBlock:
-    case ir::linalg::QuantizationSpecType::kAsymPerTensor:
-    case ir::linalg::QuantizationSpecType::kAsymPerChannel:
-    case ir::linalg::QuantizationSpecType::kAsymPerBlock: {
-      MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't handle [kSymPerBlock, kAsymPerTensor, kAsymPerChannel, kAsymPerBlock] type");
-    }
     case ir::linalg::QuantizationSpecType::kLPBQ: {
       auto cfg = std::static_pointer_cast<ir::linalg::QuantizationSpecLPBQ>(quant_spec);
 
@@ -150,15 +167,14 @@
       MLLM_RT_ASSERT_EQ(num_scale_offsets, cfg->scale_level_1_fp.size(0));
       MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.dtype(), kUInt8);
       for (int i = 0; i < num_scale_offsets; ++i) {
-        scale_offsets[i].scale = cfg->scale_level_1_fp.at<float>({i});
+        scale_offsets[i].scale = cfg->scale_level_1_fp.at<float>({i, 0, 0});
         scale_offsets[i].offset = 0;
       }
 
       Qnn_BlockwiseExpansion_t blockwise_expansion;
       blockwise_expansion.axis = cfg->ch_axis;
-      blockwise_expansion.axis = cfg->ch_axis;
       blockwise_expansion.scaleOffsets = nullptr;  // Will be set by setBlockwiseQuantization
-      blockwise_expansion.numBlocksPerAxis = v->tensor_.size(cfg->ch_axis) / cfg->block_size;
+      blockwise_expansion.numBlocksPerAxis = v->tensor_.size(1) / cfg->block_size;
       blockwise_expansion.blockScaleBitwidth = 12;  // 12 bits for 4 to 16 expansion
       blockwise_expansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;
       blockwise_expansion.blocksScale8 = cfg->scale_level_0_int.ptr<uint8_t>();
@@ -166,12 +182,8 @@
 
       tensor_wrapper_->setBlockwiseQuantization(blockwise_expansion, scale_offsets);
       break;
     }
-    default: {
-      MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't handle kNone type");
-    }
+    default: break;
   }
-
-  return ret;
 }
 
 // QnnAOTNodeOperation implementations
diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp
index 1c5189d55..ebbf7f3b2 100644
--- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp
+++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp
@@ -61,6 +61,9 @@ class QnnAOTNodeTensor : public std::enable_shared_from_this<QnnAOTNodeTensor> {
 
   Qnn_QuantizeParams_t parseQnnQuantizeParamFromIR(const ir::tensor::TensorValue::ptr_t& v);
 
+  // intended for per-channel and LPBQ quantization
+  void setupComplexTensorQuantization(const ir::tensor::TensorValue::ptr_t& v);
+
   std::shared_ptr<QNNTensorWrapper> tensor_wrapper_;
 };
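The refactor splits quantization handling in two: simple per-tensor encodings are returned by value from `parseQnnQuantizeParamFromIR`, while per-channel and LPBQ encodings, which must attach scale arrays to the already-created wrapper, are applied afterwards in `setupComplexTensorQuantization`. Worked numbers for the LPBQ bookkeeping after the axis fixes, with hypothetical sizes not taken from the patch:

```cpp
// A [2048, 1024] w4a16 weight with block_size = 64 and channel axis 0:
const int out_channels = 2048;
const int in_features = 1024;
const int block_size = 64;
const int num_scale_offsets = out_channels;                // one fp32 level-1 scale per output channel (axis 0)
const int num_blocks_per_axis = in_features / block_size;  // = 16 level-0 int scales, counted along tensor dim 1
```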
diff --git a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
index 34a670ef4..0e01d91b9 100644
--- a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
+++ b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
@@ -14,14 +14,18 @@
 #include "mllm/backends/qnn/aot/visitor/Elewise.hpp"
 #include "mllm/backends/qnn/aot/visitor/Embedding.hpp"
+#include "mllm/backends/qnn/aot/visitor/Gather.hpp"
 #include "mllm/backends/qnn/aot/visitor/CastType.hpp"
 #include "mllm/backends/qnn/aot/visitor/View.hpp"
 #include "mllm/backends/qnn/aot/visitor/Index.hpp"
+#include "mllm/backends/qnn/aot/visitor/RMSNorm.hpp"
+#include "mllm/backends/qnn/aot/visitor/Linear.hpp"
 
 namespace mllm::qnn::aot {
 
 LLM2QnnLoweringPass::LLM2QnnLoweringPass() {
-  registerPatterns();
+  registerPatterns();
 }
 
 uint8_t LLM2QnnLoweringPass::run(const ir::node_ptr_t& op) {
diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
index 37fdaffec..66bd99c78 100644
--- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
+++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp
@@ -433,6 +433,36 @@ bool LLMQuantRecipeIndexPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_
                               node->cast_());
 }
 
+//===----------------------------------------------------------------------===//
+// Gather Pattern
+//===----------------------------------------------------------------------===//
+bool LLMQuantRecipeGatherPattern::isMatch(const mllm::ir::op_ptr_t& op) {
+  if (op->isa_<ir::linalg::GatherOp>()) { return true; }
+  return false;
+}
+
+bool LLMQuantRecipeGatherPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) {
+  auto gather_ir = node->cast_<ir::linalg::GatherOp>();
+  auto i_0 = *(node->inputs().begin());
+
+  if (!i_0->getAttr("quant_recipe")) {
+    auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_<ir::tensor::TensorValue>());
+    i_0->setAttr("quant_recipe", i_0_spec);
+  }
+
+  auto annotation_attr = writer.getContext()->create();
+  auto op = node->cast_<ir::linalg::GatherOp>();
+
+  // Share the input's quantization spec with the output
+  auto quant_spec = op->inputs().front()->getAttr("quant_recipe")->cast_();
+  annotation_attr->annotation_.inputs.emplace_back(quant_spec->spec_);
+  annotation_attr->annotation_.outputs.emplace_back(quant_spec->spec_);
+  op->outputs().front()->setAttr("quant_recipe", quant_spec);
+  op->setAttr("quant_recipe", annotation_attr);
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Slice Pattern
 //===----------------------------------------------------------------------===//
@@ -764,7 +794,7 @@ bool LLMQuantRecipeLinearPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr
 
   if (precision == "w4a16") {
     weight_quant_spec =
-        ir::linalg::QuantizationSpecLPBQ::create(-8, 7, block_size, -1, 4, kUInt4, kFloat32, Tensor::nil(), Tensor::nil());
+        ir::linalg::QuantizationSpecLPBQ::create(-8, 7, block_size, 0, 4, kUInt4, kFloat32, Tensor::nil(), Tensor::nil());
 
     // output sym int16
     auto out_quant_spec = ir::linalg::QuantizationSpecAsymPerTensor::create(0, 65536 - 1, kUInt16, kFloat32, kInt32,
@@ -980,6 +1010,7 @@ LLMQuantRecipePass::LLMQuantRecipePass() {
   addPattern(LLMQuantRecipeLinearPattern::create(), "linear", 0);
   addPattern(LLMQuantRecipeEmbeddingPattern::create(), "embedding", 0);
   addPattern(LLMQuantRecipeViewPattern::create(), "view", 0);
+  addPattern(LLMQuantRecipeGatherPattern::create(), "gather", 0);
 }
 
 uint8_t LLMQuantRecipePass::run(const ir::node_ptr_t& op) {
diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp
index abd7cdbcc..4b9189242 100644
--- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp
+++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp
@@ -296,6 +296,20 @@ class LLMQuantRecipeEmbeddingPattern : public ir::Pattern {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Gather Pattern
+//===----------------------------------------------------------------------===//
+class LLMQuantRecipeGatherPattern : public ir::Pattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) override;
+
+  static inline std::shared_ptr<LLMQuantRecipeGatherPattern> create() {
+    return std::make_shared<LLMQuantRecipeGatherPattern>();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Qwen3 Attention Pattern
 //===----------------------------------------------------------------------===//
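The Gather recipe pattern reuses the input's quantization spec for the output instead of computing a new one. That is sound because gather only relocates quantized values; dequantizing before or after gathering is bit-identical when both sides share `(scale, zero_point)`. A standalone illustration (not mllm code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const float scale = 0.05f;
  const int32_t zero_point = 128;
  std::vector<uint8_t> q = {10, 200, 30};  // quantized table
  std::vector<int> idx = {2, 0, 1};
  for (int i : idx) {
    float dequant_first = (q[i] - zero_point) * scale;  // dequantize(table)[i]
    float gather_first = (q[i] - zero_point) * scale;   // dequantize(table[i])
    assert(dequant_first == gather_first);  // gather adds no arithmetic, so the spec can be shared
  }
  return 0;
}
```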
+ +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Concat.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/core/aops/ConcatOp.hpp" + +namespace mllm::qnn::aot { + +bool QnnAOTConcatPattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto concat_op = op->cast_(); + if (!concat_op) { + MLLM_ERROR("Failed to cast to linalg::ConcatOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + auto output = op->outputs().front()->cast_(); + + auto base_op = concat_op->getAOp(); + auto real_concat_op = dynamic_cast(base_op); + if (!real_concat_op) { + MLLM_ERROR("Failed to cast BaseOp to mllm::aops::ConcatOp"); + return false; + } + + int axis = real_concat_op->options().dim; + + // Handle negative axis + // We can use the first input to determine rank, assuming all inputs have same rank + auto first_input = op->inputs().front()->cast_(); + auto input_shape = first_input->tensor_.shape(); + int rank = input_shape.size(); + if (axis < 0) { axis += rank; } + + // Create QNN Op Node + auto qnn_op_node = QnnAOTNodeOperation::create("Concat"); + qnn_op_node->setPackageName("qti.aisw"); + + // Add Inputs + for (auto& input_val : op->inputs()) { + auto input = input_val->cast_(); + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input)); + } + + // Add Params + qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("axis", (uint32_t)axis)); + + // Add Output + qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(base_op->getName()); + + // Register + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Concat.hpp b/mllm/backends/qnn/aot/visitor/Concat.hpp index e69de29bb..f5fa10986 100644 --- a/mllm/backends/qnn/aot/visitor/Concat.hpp +++ b/mllm/backends/qnn/aot/visitor/Concat.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/OpTypes.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/backends/qnn/aot/visitor/Base.hpp" + +namespace mllm::qnn::aot { + +class QnnAOTConcatPattern : public QnnAOTBasePattern { + public: + bool isMatch(const mllm::ir::op_ptr_t& op) override; + + bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override; + + static inline std::pair> create() { + return {OpTypes::kConcat, std::make_shared()}; + } +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Equal.cpp b/mllm/backends/qnn/aot/visitor/Equal.cpp index e69de29bb..95f073553 100644 --- a/mllm/backends/qnn/aot/visitor/Equal.cpp +++ b/mllm/backends/qnn/aot/visitor/Equal.cpp @@ -0,0 +1,54 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Equal.hpp" +#include +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" + +namespace mllm::qnn::aot { + +bool QnnAOTEqualPattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTEqualPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto equal_op = op->cast_(); + if (!equal_op) { + MLLM_ERROR("Failed to cast to linalg::EqualOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + // Inputs + auto input0 = op->inputs().front()->cast_(); + auto input1 = (*std::next(op->inputs().begin()))->cast_(); + + // Output + auto output = op->outputs().front()->cast_(); + + // Create QNN ElementWiseEqual Op + auto qnn_op_node = QnnAOTNodeOperation::create("ElementWiseEqual"); + qnn_op_node->setPackageName("qti.aisw"); + + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input0)) + ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input1)) + ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(equal_op->getAOp()->getName()); + + // Register this op node into one graph. + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Equal.hpp b/mllm/backends/qnn/aot/visitor/Equal.hpp index e69de29bb..7a3d91c32 100644 --- a/mllm/backends/qnn/aot/visitor/Equal.hpp +++ b/mllm/backends/qnn/aot/visitor/Equal.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/OpTypes.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/backends/qnn/aot/visitor/Base.hpp" + +namespace mllm::qnn::aot { + +class QnnAOTEqualPattern : public QnnAOTBasePattern { + public: + bool isMatch(const mllm::ir::op_ptr_t& op) override; + + bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override; + + static inline std::pair> create() { + return {OpTypes::kEqual, std::make_shared()}; + } +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Gather.cpp b/mllm/backends/qnn/aot/visitor/Gather.cpp new file mode 100644 index 000000000..abdcfae34 --- /dev/null +++ b/mllm/backends/qnn/aot/visitor/Gather.cpp @@ -0,0 +1,58 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/GatherOp.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Gather.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" + +namespace mllm::qnn::aot { + +bool QnnAOTGatherPattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTGatherPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto gather_op = op->cast_(); + if (!gather_op) { + MLLM_ERROR("Failed to cast to linalg::GatherOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + // Inputs + auto table = op->inputs().front()->cast_(); + auto indices = (*std::next(op->inputs().begin()))->cast_(); + + // Output + auto output = op->outputs().front()->cast_(); + + // Create QNN Gather Op + auto qnn_op_node = QnnAOTNodeOperation::create("Gather"); + qnn_op_node->setPackageName("qti.aisw"); + + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, table)) + ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, indices)) + ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(gather_op->getAOp()->getName()); + + // Add scalar param axis + int axis = dynamic_cast(gather_op->getAOp())->options().dim; + qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("axis", (int32_t)axis)); + + // Register this op node into one graph. + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Gather.hpp b/mllm/backends/qnn/aot/visitor/Gather.hpp new file mode 100644 index 000000000..00ecef0dc --- /dev/null +++ b/mllm/backends/qnn/aot/visitor/Gather.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
diff --git a/mllm/backends/qnn/aot/visitor/Gather.hpp b/mllm/backends/qnn/aot/visitor/Gather.hpp
new file mode 100644
index 000000000..00ecef0dc
--- /dev/null
+++ b/mllm/backends/qnn/aot/visitor/Gather.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTGatherPattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kGather, std::make_shared<QnnAOTGatherPattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/Index.cpp b/mllm/backends/qnn/aot/visitor/Index.cpp
index 0d98cab8a..4d19acfaf 100644
--- a/mllm/backends/qnn/aot/visitor/Index.cpp
+++ b/mllm/backends/qnn/aot/visitor/Index.cpp
@@ -87,12 +87,11 @@ bool QnnAOTIndexPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) {
   qnn_op_node->setPackageName("qti.aisw");
 
   qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input))
-      ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, indices_tv, true))
+      ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, indices_tv))
       ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output))
+      ->emplaceParamScalar(QNNParamScalarWrapper::create("axis", (int32_t)axis))
      ->setName(base_op->getName());
 
-  qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("axis", (int32_t)axis));
-
   env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
 
   return true;
+ +#include "mllm/core/DataTypes.hpp" +#include "mllm/mllm.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Linear.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/core/aops/LinearOp.hpp" + +namespace mllm::qnn::aot { + +bool QnnAOTLinearPattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTLinearPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto linear_op = op->cast_(); + if (!linear_op) { + MLLM_ERROR("Failed to cast to linalg::LinearOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + auto input = op->inputs().front()->cast_(); + auto output = op->outputs().front()->cast_(); + + auto base_op = linear_op->getAOp(); + auto real_linear_op = dynamic_cast(base_op); + if (!real_linear_op) { + MLLM_ERROR("Failed to cast BaseOp to mllm::aops::LinearOp"); + return false; + } + + // Retrieve weight from symbol table + auto weight_val = writer.getContext() + ->lookupSymbolTable(base_op->getName() + ".weight") + ->outputs() + .front() + ->cast_(); + + // Create QNN FullyConnected Op + auto qnn_op_node = QnnAOTNodeOperation::create("FullyConnected"); + qnn_op_node->setPackageName("qti.aisw"); + + weight_val->tensor_ = weight_val->tensor_.to(kUInt8); + + weight_val->tensor_ = weight_val->tensor_.view( + {weight_val->tensor_.shape()[0], weight_val->tensor_.shape()[1] * weight_val->tensor_.shape()[2]}); + + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input)) + ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, weight_val, true)); + + // Handle Bias + if (real_linear_op->options().bias) { + auto bias_val = writer.getContext() + ->lookupSymbolTable(base_op->getName() + ".bias") + ->outputs() + .front() + ->cast_(); + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, bias_val, true)); + } + + qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(base_op->getName()); + + // Add params: keep_dims + qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("keep_dims", true)); + + // Register this op node into one graph. + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Linear.hpp b/mllm/backends/qnn/aot/visitor/Linear.hpp index e69de29bb..a9e5f08fe 100644 --- a/mllm/backends/qnn/aot/visitor/Linear.hpp +++ b/mllm/backends/qnn/aot/visitor/Linear.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
diff --git a/mllm/backends/qnn/aot/visitor/Linear.hpp b/mllm/backends/qnn/aot/visitor/Linear.hpp
index e69de29bb..a9e5f08fe 100644
--- a/mllm/backends/qnn/aot/visitor/Linear.hpp
+++ b/mllm/backends/qnn/aot/visitor/Linear.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTLinearPattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kLinear, std::make_shared<QnnAOTLinearPattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/Matmul.cpp b/mllm/backends/qnn/aot/visitor/Matmul.cpp
index e69de29bb..b44afd780 100644
--- a/mllm/backends/qnn/aot/visitor/Matmul.cpp
+++ b/mllm/backends/qnn/aot/visitor/Matmul.cpp
@@ -0,0 +1,53 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/utils/Common.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+#include "mllm/compile/ir/builtin/Attribute.hpp"
+#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
+#include "mllm/backends/qnn/aot/visitor/Matmul.hpp"
+#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
+
+namespace mllm::qnn::aot {
+
+bool QnnAOTMatMulPattern::isMatch(const mllm::ir::op_ptr_t& op) {
+  return op->isa_<ir::linalg::MatMulOp>() && (op->getAttr("using_qnn") != nullptr);
+}
+
+bool QnnAOTMatMulPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) {
+  auto env = AOTCompileContext::getInstance().getEnv();
+
+  auto matmul_op = op->cast_<ir::linalg::MatMulOp>();
+  if (!matmul_op) {
+    MLLM_ERROR("Failed to cast to linalg::MatMulOp");
+    return false;
+  }
+
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name"));
+  auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data();
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name"));
+  auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data();
+
+  // Inputs
+  auto input0 = op->inputs().front()->cast_<ir::tensor::TensorValue>();
+  auto input1 = (*std::next(op->inputs().begin(), 1))->cast_<ir::tensor::TensorValue>();
+
+  // Output
+  auto output = op->outputs().front()->cast_<ir::tensor::TensorValue>();
+
+  // Create QNN MatMul Op
+  auto qnn_op_node = QnnAOTNodeOperation::create("MatMul");
+  qnn_op_node->setPackageName("qti.aisw");
+
+  qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input0))
+      ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input1))
+      ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output))
+      ->setName(matmul_op->getAOp()->getName());
+
+  // Register this op node into one graph.
+  env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
+
+  return true;
+}
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/Matmul.hpp b/mllm/backends/qnn/aot/visitor/Matmul.hpp
index e69de29bb..af905c3a2 100644
--- a/mllm/backends/qnn/aot/visitor/Matmul.hpp
+++ b/mllm/backends/qnn/aot/visitor/Matmul.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTMatMulPattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kMatMul, std::make_shared<QnnAOTMatMulPattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp
index e69de29bb..f27ff77ba 100644
--- a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp
+++ b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp
@@ -0,0 +1,68 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/utils/Common.hpp"
+#include "mllm/core/aops/RMSNormOp.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+#include "mllm/compile/ir/builtin/Attribute.hpp"
+#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
+#include "mllm/backends/qnn/aot/visitor/RMSNorm.hpp"
+#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
+
+namespace mllm::qnn::aot {
+
+bool QnnAOTRMSNormPattern::isMatch(const mllm::ir::op_ptr_t& op) {
+  return op->isa_<ir::linalg::RMSNormOp>() && (op->getAttr("using_qnn") != nullptr);
+}
+
+bool QnnAOTRMSNormPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) {
+  auto env = AOTCompileContext::getInstance().getEnv();
+
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("quant_recipe"));
+  auto rms_op = op->cast_<ir::linalg::RMSNormOp>();
+  if (!rms_op) {
+    MLLM_ERROR("Failed to cast to linalg::RMSNormOp");
+    return false;
+  }
+
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name"));
+  auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data();
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name"));
+  auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data();
+
+  auto a = rms_op->getAOp();
+  auto rms_aop = dynamic_cast<mllm::aops::RMSNormOp*>(a);
+  if (!rms_aop) {
+    MLLM_ERROR("Failed to cast to aops::RMSNormOp");
+    return false;
+  }
+
+  auto weight =
+      writer.getContext()->lookupSymbolTable(a->getName() + ".weight")->outputs().front()->cast_<ir::tensor::TensorValue>();
+
+  // Start to attach
+  auto i_0 = op->inputs().front()->cast_<ir::tensor::TensorValue>();
+  auto o_0 = op->outputs().front()->cast_<ir::tensor::TensorValue>();
+  auto qnn_op_node = QnnAOTNodeOperation::create("RmsNorm");
+  qnn_op_node->setPackageName("qti.aisw");
+
+  qnn_op_node->emplaceParamScalar(mllm::qnn::QNNParamScalarWrapper::create("epsilon", rms_aop->options().epsilon));
+
+  std::vector<uint32_t> axes_dims = {1};
+  auto axes_param = mllm::qnn::QNNParamTensorWrapper::create("axes", a->getName() + "_axes", QNN_DATATYPE_UINT_32, axes_dims);
+  uint32_t* axes_data = (uint32_t*)axes_param->alloc();
+  axes_data[0] = i_0->tensor_.shape().size() - 1;  // normalize over the last axis
+  qnn_op_node->emplaceParamTensor(axes_param);
+
+  qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, i_0))
+      ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, weight, true))
+      ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, o_0))
+      ->setName(rms_op->getAOp()->getName());
+
+  // Register this op node into one graph.
+  env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
+
+  return true;
+}
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.hpp b/mllm/backends/qnn/aot/visitor/RMSNorm.hpp
index e69de29bb..5bd27ad82 100644
--- a/mllm/backends/qnn/aot/visitor/RMSNorm.hpp
+++ b/mllm/backends/qnn/aot/visitor/RMSNorm.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTRMSNormPattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kRMSNorm, std::make_shared<QnnAOTRMSNormPattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/Repeat.cpp b/mllm/backends/qnn/aot/visitor/Repeat.cpp
index e69de29bb..e6eb0542a 100644
--- a/mllm/backends/qnn/aot/visitor/Repeat.cpp
+++ b/mllm/backends/qnn/aot/visitor/Repeat.cpp
@@ -0,0 +1,86 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/utils/Common.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+#include "mllm/compile/ir/builtin/Attribute.hpp"
+#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
+#include "mllm/backends/qnn/aot/visitor/Repeat.hpp"
+#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
+#include "mllm/core/aops/RepeatOp.hpp"
+#include <cstring>
+
+namespace mllm::qnn::aot {
+
+bool QnnAOTRepeatPattern::isMatch(const mllm::ir::op_ptr_t& op) {
+  return op->isa_<ir::linalg::RepeatOp>() && (op->getAttr("using_qnn") != nullptr);
+}
+
+bool QnnAOTRepeatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) {
+  auto env = AOTCompileContext::getInstance().getEnv();
+
+  auto repeat_op = op->cast_<ir::linalg::RepeatOp>();
+  if (!repeat_op) {
+    MLLM_ERROR("Failed to cast to linalg::RepeatOp");
+    return false;
+  }
+
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name"));
+  auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data();
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name"));
+  auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data();
+
+  auto input = op->inputs().front()->cast_<ir::tensor::TensorValue>();
+  auto output = op->outputs().front()->cast_<ir::tensor::TensorValue>();
+
+  auto base_op = repeat_op->getAOp();
+  auto real_repeat_op = dynamic_cast<mllm::aops::RepeatOp*>(base_op);
+  if (!real_repeat_op) {
+    MLLM_ERROR("Failed to cast BaseOp to mllm::aops::RepeatOp");
+    return false;
+  }
+
+  const auto& options = real_repeat_op->options();
+  int dim = options.dim;
+  int repeat_times = options.repeat_times;
+
+  auto input_shape = input->tensor_.shape();
+  int rank = input_shape.size();
+
+  if (dim < 0) { dim += rank; }
+
+  std::vector<uint32_t> multiples(rank, 1);
+  if (dim >= 0 && dim < rank) {
+    multiples[dim] = (uint32_t)repeat_times;
+  } else {
+    MLLM_ERROR("Invalid dimension for RepeatOp: {}", dim);
+    return false;
+  }
+
+  // Create QNN Op Node
+  // QNN uses "Tile" for repeat
+  auto qnn_op_node = QnnAOTNodeOperation::create("Tile");
+  qnn_op_node->setPackageName("qti.aisw");
+
+  // Add Input
+  qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input));
+
+  // Add multiples Param
+  auto multiplesName = base_op->getName() + ".multiples";
+  auto multiplesParam =
+      QNNParamTensorWrapper::create("multiples", multiplesName, QNN_DATATYPE_UINT_32, std::vector<uint32_t>{(uint32_t)rank});
+  uint32_t* multiplesData = static_cast<uint32_t*>(multiplesParam->alloc());
+  std::memcpy(multiplesData, multiples.data(), rank * sizeof(uint32_t));
+  qnn_op_node->emplaceParamTensor(multiplesParam);
+
+  // Add Output
+  qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output))
+      ->setName(base_op->getName());
+
+  // Register
+  env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
+
+  return true;
+}
+
+} // namespace mllm::qnn::aot
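The `multiples` computation above, in isolation: mllm's single-dim Repeat maps onto QNN Tile by setting one entry of an otherwise all-ones multiples vector. A standalone sketch with example values:

```cpp
#include <cstdint>
#include <vector>

// rank=3, dim=1, repeat_times=4 -> {1, 4, 1}
std::vector<uint32_t> tile_multiples(int rank, int dim, int repeat_times) {
  std::vector<uint32_t> m(rank, 1);
  if (dim < 0) dim += rank;       // normalize negative dims, as the visitor does
  m[dim] = (uint32_t)repeat_times;
  return m;
}
```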
diff --git a/mllm/backends/qnn/aot/visitor/Repeat.hpp b/mllm/backends/qnn/aot/visitor/Repeat.hpp
index e69de29bb..bec7d7e61 100644
--- a/mllm/backends/qnn/aot/visitor/Repeat.hpp
+++ b/mllm/backends/qnn/aot/visitor/Repeat.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTRepeatPattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kRepeat, std::make_shared<QnnAOTRepeatPattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/backends/qnn/aot/visitor/Softmax.cpp b/mllm/backends/qnn/aot/visitor/Softmax.cpp
index e69de29bb..cb830da57 100644
--- a/mllm/backends/qnn/aot/visitor/Softmax.cpp
+++ b/mllm/backends/qnn/aot/visitor/Softmax.cpp
@@ -0,0 +1,71 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/utils/Common.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+#include "mllm/compile/ir/builtin/Attribute.hpp"
+#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
+#include "mllm/backends/qnn/aot/visitor/Softmax.hpp"
+#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
+#include "mllm/core/aops/SoftmaxOp.hpp"
+
+namespace mllm::qnn::aot {
+
+bool QnnAOTSoftmaxPattern::isMatch(const mllm::ir::op_ptr_t& op) {
+  return op->isa_<ir::linalg::SoftmaxOp>() && (op->getAttr("using_qnn") != nullptr);
+}
+
+bool QnnAOTSoftmaxPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) {
+  auto env = AOTCompileContext::getInstance().getEnv();
+
+  auto softmax_op = op->cast_<ir::linalg::SoftmaxOp>();
+  if (!softmax_op) {
+    MLLM_ERROR("Failed to cast to linalg::SoftmaxOp");
+    return false;
+  }
+
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name"));
+  auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data();
+  MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name"));
+  auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data();
+
+  auto input = op->inputs().front()->cast_<ir::tensor::TensorValue>();
+  auto output = op->outputs().front()->cast_<ir::tensor::TensorValue>();
+
+  auto base_op = softmax_op->getAOp();
+  auto real_softmax_op = dynamic_cast<mllm::aops::SoftmaxOp*>(base_op);
+  if (!real_softmax_op) {
+    MLLM_ERROR("Failed to cast BaseOp to mllm::aops::SoftmaxOp");
+    return false;
+  }
+
+  int axis = real_softmax_op->options().axis;
+  float beta = 1.0f;  // Default beta
+
+  // Handle negative axis
+  auto input_shape = input->tensor_.shape();
+  int rank = input_shape.size();
+  if (axis < 0) { axis += rank; }
+
+  // Create QNN Op Node
+  auto qnn_op_node = QnnAOTNodeOperation::create("Softmax");
+  qnn_op_node->setPackageName("qti.aisw");
+
+  // Add Input
+  qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input));
+
+  // Add Params
+  qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("axis", (uint32_t)axis));
+  qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("beta", beta));
+
+  // Add Output
+  qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output))
+      ->setName(base_op->getName());
+
+  // Register
+  env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node);
+
+  return true;
+}
+
+} // namespace mllm::qnn::aot
qnn_op_node->emplaceParamScalar(QNNParamScalarWrapper::create("beta", beta)); + + // Add Output + qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(base_op->getName()); + + // Register + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Softmax.hpp b/mllm/backends/qnn/aot/visitor/Softmax.hpp index e69de29bb..c3a2fe011 100644 --- a/mllm/backends/qnn/aot/visitor/Softmax.hpp +++ b/mllm/backends/qnn/aot/visitor/Softmax.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/OpTypes.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/backends/qnn/aot/visitor/Base.hpp" + +namespace mllm::qnn::aot { + +class QnnAOTSoftmaxPattern : public QnnAOTBasePattern { + public: + bool isMatch(const mllm::ir::op_ptr_t& op) override; + + bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override; + + static inline std::pair> create() { + return {OpTypes::kSoftmax, std::make_shared()}; + } +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Transpose.cpp b/mllm/backends/qnn/aot/visitor/Transpose.cpp index e69de29bb..149510133 100644 --- a/mllm/backends/qnn/aot/visitor/Transpose.cpp +++ b/mllm/backends/qnn/aot/visitor/Transpose.cpp @@ -0,0 +1,78 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Transpose.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/core/aops/TransposeOp.hpp" +#include + +namespace mllm::qnn::aot { + +bool QnnAOTTransposePattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTTransposePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto transpose_op = op->cast_(); + if (!transpose_op) { + MLLM_ERROR("Failed to cast to linalg::TransposeOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + auto input = op->inputs().front()->cast_(); + auto output = op->outputs().front()->cast_(); + + auto base_op = transpose_op->getAOp(); + auto real_transpose_op = dynamic_cast(base_op); + if (!real_transpose_op) { + MLLM_ERROR("Failed to cast BaseOp to mllm::aops::TransposeOp"); + return false; + } + + const auto& options = real_transpose_op->options(); + + // Calculate perm + auto input_shape = input->tensor_.shape(); + int rank = input_shape.size(); + + std::vector perm(rank); + for (int i = 0; i < rank; ++i) { perm[i] = i; } + + if (options.dim0 < rank && options.dim1 < rank) { std::swap(perm[options.dim0], perm[options.dim1]); } + + // Create QNN Op Node + auto qnn_op_node = QnnAOTNodeOperation::create("Transpose"); + qnn_op_node->setPackageName("qti.aisw"); + + // Add Input + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, input)); + + // Add Perm Param + auto permName = base_op->getName() + ".perm"; + auto permParam = 
QNNParamTensorWrapper::create("perm", permName, QNN_DATATYPE_UINT_32, std::vector{(uint32_t)rank}); + uint32_t* permData = static_cast(permParam->alloc()); + std::memcpy(permData, perm.data(), rank * sizeof(uint32_t)); + qnn_op_node->emplaceParamTensor(permParam); + + // Add Output + qnn_op_node->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(base_op->getName()); + + // Register + env->captureAOTNodeOp(qnn_context_name, qnn_graph_name, qnn_op_node); + + return true; +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Transpose.hpp b/mllm/backends/qnn/aot/visitor/Transpose.hpp index e69de29bb..85e0593cb 100644 --- a/mllm/backends/qnn/aot/visitor/Transpose.hpp +++ b/mllm/backends/qnn/aot/visitor/Transpose.hpp @@ -0,0 +1,23 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/OpTypes.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/backends/qnn/aot/visitor/Base.hpp" + +namespace mllm::qnn::aot { + +class QnnAOTTransposePattern : public QnnAOTBasePattern { + public: + bool isMatch(const mllm::ir::op_ptr_t& op) override; + + bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override; + + static inline std::pair> create() { + return {OpTypes::kTranspose, std::make_shared()}; + } +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/visitor/Where.cpp b/mllm/backends/qnn/aot/visitor/Where.cpp index e69de29bb..5490e7234 100644 --- a/mllm/backends/qnn/aot/visitor/Where.cpp +++ b/mllm/backends/qnn/aot/visitor/Where.cpp @@ -0,0 +1,55 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/visitor/Where.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" + +namespace mllm::qnn::aot { + +bool QnnAOTWherePattern::isMatch(const mllm::ir::op_ptr_t& op) { + return op->isa_() && (op->getAttr("using_qnn") != nullptr); +} + +bool QnnAOTWherePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) { + auto env = AOTCompileContext::getInstance().getEnv(); + + auto where_op = op->cast_(); + if (!where_op) { + MLLM_ERROR("Failed to cast to linalg::WhereOp"); + return false; + } + + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_graph_name")); + auto qnn_graph_name = op->getAttr("qnn_graph_name")->cast_()->data(); + MLLM_RETURN_FALSE_IF_NOT(op->getAttr("qnn_context_name")); + auto qnn_context_name = op->getAttr("qnn_context_name")->cast_()->data(); + + // Inputs: Condition, True, False + auto condition = op->inputs().front()->cast_(); + auto true_input = (*std::next(op->inputs().begin(), 1))->cast_(); + auto false_input = (*std::next(op->inputs().begin(), 2))->cast_(); + + // Output + auto output = op->outputs().front()->cast_(); + + // Create QNN ElementWiseSelect Op + auto qnn_op_node = QnnAOTNodeOperation::create("ElementWiseSelect"); + qnn_op_node->setPackageName("qti.aisw"); + + qnn_op_node->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, condition)) + ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, true_input)) + ->emplaceInput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, false_input)) + ->emplaceOutput(env->captureQnnAOTNodeTensor(qnn_context_name, qnn_graph_name, output)) + ->setName(where_op->getAOp()->getName()); + + // 
diff --git a/mllm/backends/qnn/aot/visitor/Where.hpp b/mllm/backends/qnn/aot/visitor/Where.hpp
index e69de29bb..7538514a4 100644
--- a/mllm/backends/qnn/aot/visitor/Where.hpp
+++ b/mllm/backends/qnn/aot/visitor/Where.hpp
@@ -0,0 +1,23 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/OpTypes.hpp"
+#include "mllm/compile/ir/Node.hpp"
+#include "mllm/backends/qnn/aot/visitor/Base.hpp"
+
+namespace mllm::qnn::aot {
+
+class QnnAOTWherePattern : public QnnAOTBasePattern {
+ public:
+  bool isMatch(const mllm::ir::op_ptr_t& op) override;
+
+  bool rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) override;
+
+  static inline std::pair<OpTypes, std::shared_ptr<QnnAOTBasePattern>> create() {
+    return {OpTypes::kWhere, std::make_shared<QnnAOTWherePattern>()};
+  }
+};
+
+} // namespace mllm::qnn::aot
diff --git a/mllm/core/aops/GatherOp.cpp b/mllm/core/aops/GatherOp.cpp
new file mode 100644
index 000000000..86591e151
--- /dev/null
+++ b/mllm/core/aops/GatherOp.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#include "mllm/core/aops/GatherOp.hpp"
+#include "mllm/core/BaseOp.hpp"
+#include "mllm/core/Tensor.hpp"
+#include "mllm/utils/Common.hpp"
+#include "mllm/compile/ir/linalg/Op.hpp"
+
+namespace mllm::aops {
+
+GatherOp::GatherOp(const GatherOpOptions& options) : BaseOp(OpTypes::kGather), options_(options) {}
+
+void GatherOp::trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  auto ir_ctx = (ir::IRContext*)trace_context;
+  auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs);
+  auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs);
+  ir_ctx->create<ir::linalg::GatherOp>(shared_from_this(), i_irs, o_irs);
+}
+
+void GatherOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  NYI("GatherOp::forward not implemented in aops base.");
+}
+
+void GatherOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
+  auto& table = inputs[0];
+  auto& indices = inputs[1];
+
+  auto shape = table.shape();
+  auto indices_shape = indices.shape();
+  int dim = options_.dim;
+  if (dim < 0) dim += shape.size();
+
+  MLLM_RT_ASSERT(dim >= 0 && dim < shape.size());
+
+  // Output shape: table dims before `dim`, then the indices shape, then table dims after `dim`.
+  std::vector<int32_t> new_shape;
+  new_shape.reserve(dim);
+  for (int i = 0; i < dim; ++i) new_shape.push_back(shape[i]);
+  for (int s : indices_shape) new_shape.push_back(s);
+  for (int i = dim + 1; i < shape.size(); ++i) new_shape.push_back(shape[i]);
+
+  auto o = Tensor::empty(new_shape, table.dtype(), table.device());
+  outputs.emplace_back(o);
+}
+
+} // namespace mllm::aops
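The output-shape rule implemented by `GatherOp::reshape`, in isolation: the gathered dim is replaced by the full indices shape. Standalone sketch with hypothetical shapes:

```cpp
#include <vector>

// e.g. table {50000, 4096} + indices {2, 16} at dim 0 -> {2, 16, 4096}
std::vector<int> gather_out_shape(std::vector<int> table, const std::vector<int>& indices, int dim) {
  if (dim < 0) dim += (int)table.size();
  std::vector<int> out(table.begin(), table.begin() + dim);  // dims before `dim`
  out.insert(out.end(), indices.begin(), indices.end());     // full indices shape
  out.insert(out.end(), table.begin() + dim + 1, table.end());  // dims after `dim`
  return out;
}
```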
diff --git a/mllm/core/aops/GatherOp.hpp b/mllm/core/aops/GatherOp.hpp
new file mode 100644
index 000000000..cab952e5e
--- /dev/null
+++ b/mllm/core/aops/GatherOp.hpp
@@ -0,0 +1,30 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "mllm/core/BaseOp.hpp"
+
+namespace mllm::aops {
+
+struct GatherOpOptions : public BaseOpOptions {
+  int dim;
+};
+
+class GatherOp : public BaseOp {
+ public:
+  explicit GatherOp(const GatherOpOptions& options);
+
+  void trace(void* trace_context, const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+
+  inline const GatherOpOptions& options() const { return options_; }
+
+ protected:
+  GatherOpOptions options_;
+};
+
+} // namespace mllm::aops
diff --git a/mllm/nn/Functional.cpp b/mllm/nn/Functional.cpp
index ab0c12f80..4e70b092a 100644
--- a/mllm/nn/Functional.cpp
+++ b/mllm/nn/Functional.cpp
@@ -5,6 +5,7 @@
 #include "mllm/core/aops/ConcatOp.hpp"
 #include "mllm/core/aops/ElewiseOps.hpp"
 #include "mllm/core/aops/FlashAttention2Op.hpp"
+#include "mllm/core/aops/GatherOp.hpp"
 #include "mllm/core/aops/MatMulOp.hpp"
 #include "mllm/core/aops/ReduceOps.hpp"
 #include "mllm/core/aops/Scatter2ShardsOp.hpp"
@@ -211,4 +212,9 @@ mllm::Tensor sigmoid(const Tensor& x) {
   return ctx.buildOpAndSubmitTask(OpTypes::kSigmoid, aops::SigmoidOpOptions{}, {x})[0];
 }
 
+mllm::Tensor gather(const Tensor& x, int dim, const Tensor& indices) {
+  auto& ctx = mllm::Context::instance();
+  return ctx.buildOpAndSubmitTask(OpTypes::kGather, aops::GatherOpOptions{.dim = dim}, {x, indices})[0];
+}
+
 } // namespace mllm::nn::functional
diff --git a/mllm/nn/Functional.hpp b/mllm/nn/Functional.hpp
index bd0cca9dd..31a57812c 100644
--- a/mllm/nn/Functional.hpp
+++ b/mllm/nn/Functional.hpp
@@ -162,4 +162,6 @@ mllm::Tensor where(const Tensor& mask, const Tensor& original, const Tensor& v);
 
 mllm::Tensor sigmoid(const Tensor& x);
 
+mllm::Tensor gather(const Tensor& x, int dim, const Tensor& indices);
+
 } // namespace mllm::nn::functional
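Finally, a hypothetical call site for the new functional API (the tensor names are illustrative, not from the patch; the shape rule follows `GatherOp::reshape` above):

```cpp
// rope_table: [1, 4096, 128], position_ids: [1, 64]  ->  out: [1, 64, 128]
auto out = mllm::nn::functional::gather(rope_table, /*dim=*/1, position_ids);
```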