Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/coreml/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
return false;
}

// For some undocuemented reason, apple CoreML lib will fail loading the model if the model has
// dimension > 16384
// For some undocumented reason, Apple CoreML framework will fail loading the model if the model
// input has dimension > 16384
// See this issue, https://github.com/apple/coremltools/issues/1003
if (dim.dim_value() > 16384) {
LOGS(logger, WARNING) << "CoreML does not support input dim > 16384, input:" << input_name
Expand Down
18 changes: 10 additions & 8 deletions onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <core/common/safeint.h>
#include <core/providers/common.h>
#include "core/providers/shared/utils/utils.h"

#include "builder_utils.h"
#include "coreml/NeuralNetwork.pb.h"
Expand Down Expand Up @@ -80,18 +82,18 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
return Status::OK();
}

void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight,
const float* data, size_t num_elements) {
*weight.mutable_floatvalue() = {data, data + num_elements};
}
Comment thread
guoyu-wang marked this conversation as resolved.

common::Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight,
const ONNX_NAMESPACE::TensorProto& tensor) {
auto data_type = tensor.data_type();
if (data_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
const float* data =
tensor.float_data().empty() ? reinterpret_cast<const float*>(tensor.raw_data().data())
: tensor.float_data().data();

weight.mutable_floatvalue()->Clear();
auto num_elements = Product(tensor.dims());
std::copy(data, data + num_elements,
google::protobuf::RepeatedFieldBackInserter(weight.mutable_floatvalue()));
const float* data = GetTensorFloatData(tensor);
auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
CreateCoreMLWeight(weight, data, num_elements);
} else {
// TODO: support other type
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
common::Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight,
const ONNX_NAMESPACE::TensorProto& tensor);

// Copy the float array to a coreml weight
void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight,
const float* data, size_t num_elements);

} // namespace coreml
} // namespace onnxruntime
233 changes: 233 additions & 0 deletions onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <core/common/safeint.h>
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"

#include "base_op_builder.h"
#include "builder_utils.h"

namespace onnxruntime {
namespace coreml {

class GemmOpBuilder : public BaseOpBuilder {
// Add operator related
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
const logging::Logger& /* logger */) const override;
};

// Add operator related

void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
const auto& op = node.OpType();
const auto& input_defs(node.InputDefs());
// We have already embedded the weights (matrix B and C(if any)) into the coreml layer
// No need to copy them later to reduce memory consumption
model_builder.AddInitializerToSkip(input_defs[1]->Name());
if (op == "Gemm" && input_defs.size() > 2) {
model_builder.AddInitializerToSkip(input_defs[2]->Name());
}
}
Comment thread
guoyu-wang marked this conversation as resolved.

// This is an internal function, requires input tensor to be 2d float tensor
// TODO, add support of other data types
static std::vector<float> GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor) {
const float* src_data = GetTensorFloatData(tensor);
const auto& tensor_shape = tensor.dims();
auto x_t = SafeInt<size_t>(tensor_shape[0]);
auto y_t = SafeInt<size_t>(tensor_shape[1]);
std::vector<float> transposed_data(x_t * y_t);
for (size_t x = 0; x < x_t; x++) {
for (size_t y = 0; y < y_t; y++) {
transposed_data[y * x_t + x] = src_data[x * y_t + y];
}
}

return transposed_data;
}

Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(node);

const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const auto& b_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
const auto& b_shape = b_tensor.dims();

auto* coreml_inner_product = layer->mutable_innerproduct();

// The coreml innerproduct weight (matrix B) is stored transposed
// - for MatMul and Gemm (transB = 0), the coreml weight is B'
// - for Gemm (transB = 1), the coreml weight is B
if (op_type == "MatMul") {
coreml_inner_product->set_inputchannels(b_shape[0]);
coreml_inner_product->set_outputchannels(b_shape[1]);
// Add weight (b of MatMul)
const auto b_transposed = GetTensorFloatDataTransposed(b_tensor);
CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed.data(), b_transposed.size());
} else { // Gemm
NodeAttrHelper helper(node);
const auto transB = helper.Get("transB", 0);
if (transB == 0) {
coreml_inner_product->set_inputchannels(b_shape[0]);
coreml_inner_product->set_outputchannels(b_shape[1]);
const auto b_transposed = GetTensorFloatDataTransposed(b_tensor);
CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed.data(), b_transposed.size());
} else {
coreml_inner_product->set_inputchannels(b_shape[1]);
coreml_inner_product->set_outputchannels(b_shape[0]);
// Add weight (b of MatMul)
CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_tensor);
}

// Add bias if present
if (input_defs.size() > 2) {
coreml_inner_product->set_hasbias(true);
const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name());
CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_tensor);
}
}

*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
return Status::OK();
}

// Operator support related

bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
const auto& input_defs(node.InputDefs());
size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C

if (!Contains(initializers, input_defs[b_idx]->Name())) {
LOGS(logger, VERBOSE) << "B of Gemm/Matmul must be an initializer tensor";
return false;
}

std::vector<int64_t> a_shape;
{
if (!GetShape(*input_defs[a_idx], a_shape, logger))
return false;

if (a_shape.size() != 2) {
LOGS(logger, VERBOSE) << "A must be 2D";
return false;
}

if (Product(a_shape) == 0) {
LOGS(logger, VERBOSE) << "A must be non-empty";
return false;
}
}

std::vector<int64_t> b_shape;
{
if (!GetShape(*input_defs[b_idx], b_shape, logger))
return false;

if (b_shape.size() != 2) {
LOGS(logger, VERBOSE) << "B must be 2D";
return false;
}

if (Product(b_shape) == 0) {
LOGS(logger, VERBOSE) << "B must be non-empty";
return false;
}
}

if (op_type == "Gemm") {
NodeAttrHelper helper(node);
const auto transA = helper.Get("transA", 0);
const auto transB = helper.Get("transB", 0);
const auto alpha = helper.Get("alpha", 1.0f);
const auto beta = helper.Get("beta", 1.0f);
if (!(transA == 0 && alpha == 1.f && beta == 1.f)) {
LOGS(logger, VERBOSE) << "Only transA == 0, alpha == 1.0 "
<< "and beta == 1.0 is supported."
<< " transA " << transA
<< " alpha " << alpha
<< " beta " << beta;
return false;
}

// C of Gemm
// For now we only support {n} or {1,n} tensor
if (input_defs.size() == 3) {
if (!Contains(initializers, input_defs[c_idx]->Name())) {
LOGS(logger, VERBOSE) << "C of Gemm must be an initializer tensor";
return false;
}

std::vector<int64_t> c_shape;
if (!GetShape(*input_defs[c_idx], c_shape, logger))
return false;

size_t c_dim = c_shape.size();

if (c_dim == 0) {
LOGS(logger, VERBOSE) << "C of Gemm cannot be a scalar";
return false;
}

if (c_dim != 1) {
// If C is a (2+)d tensor, it must have the format {1, 1, ..., 1, n}
// where every except the last dimension should be 1
for (size_t i = 0; i < c_dim - 1; ++i) {
if (c_shape[i] != 1) {
LOGS(logger, VERBOSE) << "C of Gemm must be a vector or a tensor with only last dimension != 1";
return false;
}
}
}

auto c_size = c_shape[c_dim - 1];
if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) {
LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape["
<< (transB == 0 ? "1" : "0") << "]"
<< " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]"
<< " c_size: " << c_size;

return false;
}
}
}

return true;
}

void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend())
return;

static std::vector<std::string> op_types =
{
"Gemm",
"MatMul",
};

op_registrations.builders.push_back(onnxruntime::make_unique<GemmOpBuilder>());
for (const auto& op_type : op_types) {
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}
}
} // namespace coreml
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,11 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
: target_shape_tensor.int64_data().data();

const auto size = target_shape_tensor.dims()[0];
std::vector<int64_t> target_shape;
std::copy(raw_target_shape, raw_target_shape + size, std::back_inserter(target_shape));

std::vector<int64_t> target_shape{raw_target_shape, raw_target_shape + size};
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
ReshapeHelper helper(TensorShape(input_shape), target_shape);
std::copy(target_shape.cbegin(), target_shape.cend(),
google::protobuf::RepeatedFieldBackInserter(
layer->mutable_reshapestatic()->mutable_targetshape()));

*layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()};
*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateResizeOpBuilder("Resize", op_registrations);
}

{ // Gemm/MatMul
CreateGemmOpBuilder("Gemm", op_registrations);
CreateGemmOpBuilder("MatMul", op_registrations);
}

return op_registrations;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& o

void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace coreml
} // namespace onnxruntime
Loading