From f07fdf96b4ccc0a9c6e87fe2562e9359be85d3af Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 10:54:44 -0800 Subject: [PATCH 1/6] model moved over. everything builds clean. step ! --- cmake/winml.cmake | 26 ++-- winml/lib/Api.Core/ModelInfo.cpp | 116 --------------- winml/lib/Api.Core/WinMLAdapter.cpp | 195 +++++++++++++++++++++++++- winml/lib/Api.Core/inc/ModelInfo.h | 20 --- winml/lib/Api.Core/inc/WinMLAdapter.h | 23 ++- winml/lib/Api/LearningModel.cpp | 74 ++++------ winml/lib/Api/LearningModel.h | 12 +- 7 files changed, 251 insertions(+), 215 deletions(-) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 94e7db025bbd7..3fa0f4e72afcd 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -471,24 +471,24 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") endif("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") # Link libraries -target_link_libraries(winml_dll PRIVATE libprotobuf) -target_link_libraries(winml_dll PRIVATE onnx) -target_link_libraries(winml_dll PRIVATE onnxruntime_common) -target_link_libraries(winml_dll PRIVATE onnxruntime_graph) -target_link_libraries(winml_dll PRIVATE onnxruntime_framework) -target_link_libraries(winml_dll PRIVATE onnxruntime_mlas) -target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer) -target_link_libraries(winml_dll PRIVATE onnxruntime_providers) -target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml) -target_link_libraries(winml_dll PRIVATE onnxruntime_session) -target_link_libraries(winml_dll PRIVATE onnxruntime_util) -target_link_libraries(winml_dll PRIVATE onnx_proto) +#target_link_libraries(winml_dll PRIVATE libprotobuf) +#target_link_libraries(winml_dll PRIVATE onnx) +#target_link_libraries(winml_dll PRIVATE onnxruntime_common) +#target_link_libraries(winml_dll PRIVATE onnxruntime_graph) +#target_link_libraries(winml_dll PRIVATE onnxruntime_framework) +#target_link_libraries(winml_dll PRIVATE onnxruntime_mlas) +#target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer) +#target_link_libraries(winml_dll PRIVATE onnxruntime_providers) +#target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml) +#target_link_libraries(winml_dll PRIVATE onnxruntime_session) +#target_link_libraries(winml_dll PRIVATE onnxruntime_util) +#target_link_libraries(winml_dll PRIVATE onnx_proto) target_link_libraries(winml_dll PRIVATE onnxruntime) target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) target_link_libraries(winml_dll PRIVATE windowsapp.lib) target_link_libraries(winml_dll PRIVATE winml_lib_api) -target_link_libraries(winml_dll PRIVATE winml_lib_core) +#target_link_libraries(winml_dll PRIVATE winml_lib_core) target_link_libraries(winml_dll PRIVATE winml_lib_image) target_link_libraries(winml_dll PRIVATE winml_lib_telemetry) target_link_libraries(winml_dll PRIVATE ${DBGHELP}) diff --git a/winml/lib/Api.Core/ModelInfo.cpp b/winml/lib/Api.Core/ModelInfo.cpp index 5d34cb50623b3..46e9bc29bfa3d 100644 --- a/winml/lib/Api.Core/ModelInfo.cpp +++ b/winml/lib/Api.Core/ModelInfo.cpp @@ -10,121 +10,5 @@ using namespace Windows::AI::MachineLearning; -static std::vector -GetAllNodeOutputs(const onnx::ModelProto& model_proto) { - std::vector nodes_outputs; - auto& graph = model_proto.graph(); - auto& nodes = graph.node(); - for (auto& node : nodes) { - for (auto& node_output : node.output()) { - nodes_outputs.push_back(node_output.c_str()); - } - } - return nodes_outputs; -} - -static std::vector -GetInitializers(const onnx::ModelProto& model_proto) { - std::vector initializers; - auto& graph = model_proto.graph(); - auto& graph_initializers = graph.initializer(); - for (auto& initializer : graph_initializers) { - initializers.push_back(initializer.name().c_str()); - } - return initializers; -} - -static std::vector -GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { - auto initializers = GetInitializers(model_proto); - - std::vector inputs_without_initializers; - auto& graph = model_proto.graph(); - auto& inputs = graph.input(); - for (auto& input : inputs) { - if (input.has_name() && input.has_type()) { - auto found_it = std::find_if( - std::begin(initializers), - std::end(initializers), - [&](auto& initializer) { - return std::strcmp(initializer, input.name().c_str()) == 0; - }); - - auto is_initializer = found_it != std::end(initializers); - if (!is_initializer) { - inputs_without_initializers.push_back(&input); - } - } - } - return inputs_without_initializers; -} - -static std::vector -GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; - auto& graph = model_proto.graph(); - auto& outputs = graph.output(); - for (auto& output : outputs) { - if (output.has_name() && output.has_type()) { - outputs_with_name.push_back(&output); - } - } - return outputs_with_name; -} - -ModelInfo::ModelInfo( - const onnx::ModelProto* model_proto) { - Initialize(model_proto); -} - -void ModelInfo::Initialize( - const onnx::ModelProto* model_proto) { - // metadata - for (auto& prop : model_proto->metadata_props()) { - model_metadata_[prop.key()] = prop.value(); - } - - WinML::FeatureDescriptorFactory builder(model_metadata_); - - // Create inputs - auto inputs = GetInputsWithoutInitializers(*model_proto); - input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); - - // Create outputs - auto outputs = ::GetOutputs(*model_proto); - output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); - - // author - auto has_producer_name = model_proto->has_producer_name(); - author_ = has_producer_name - ? model_proto->producer_name() - : ""; - - // domain - auto has_domain = model_proto->has_domain(); - domain_ = has_domain - ? model_proto->domain() - : ""; - - // name - auto has_graph = model_proto->has_graph(); - auto graph_has_name = model_proto->graph().has_name(); - auto is_name_available = has_graph && graph_has_name; - name_ = is_name_available - ? model_proto->graph().name() - : ""; - - // description - auto has_description = model_proto->has_doc_string(); - description_ = has_description - ? model_proto->doc_string() - : ""; - - // version - auto has_version = model_proto->has_model_version(); - version_ = has_version - ? model_proto->model_version() - : 0; -} diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index 105705b7a917a..b40318f381537 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -4,10 +4,12 @@ #include "pch.h" #include "inc/WinMLAdapter.h" #include "inc/CustomRegistryHelper.h" +#include "PheonixSingleton.h" #include "inc/LotusEnvironment.h" #include "inc/AbiCustomRegistryImpl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" +#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" #include "LearningModelDevice.h" #include "TensorFeatureDescriptor.h" @@ -25,6 +27,8 @@ #include "ZeroCopyInputStreamWrapper.h" #include "google/protobuf/io/zero_copy_stream_impl.h" +#include "FeatureDescriptorFactory.h" + using namespace winrt::Windows::AI::MachineLearning; @@ -111,7 +115,7 @@ class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < *tensor = tensor_outer.Detach(); return S_OK; } -}; +}; // class AbiSafeOrtValue class ModelProto : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, @@ -128,12 +132,178 @@ class ModelProto : public Microsoft::WRL::RuntimeClass< private: std::shared_ptr model_proto_; -}; +}; // class ModelProto + + +class ModelInfo : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IModelInfo> { + +private: + std::string author_; + std::string name_; + std::string domain_; + std::string description_; + int64_t version_; + std::unordered_map model_metadata_; + wfc::IVector input_features_; + wfc::IVector output_features_; + +public: + + ModelInfo(const onnx::ModelProto* model_proto) { + Initialize(model_proto); + } + + std::string STDMETHODCALLTYPE author() override { + return author_; + } + std::string STDMETHODCALLTYPE name() override { + return name_; + } + std::string STDMETHODCALLTYPE domain() override { + return domain_; + } + std::string STDMETHODCALLTYPE description() override { + return description_; + } + int64_t STDMETHODCALLTYPE version() override { + return version_; + } + std::unordered_map STDMETHODCALLTYPE model_metadata() override { + return model_metadata_; + } + wfc::IVector STDMETHODCALLTYPE input_features() override { + return input_features_; + } + wfc::IVector STDMETHODCALLTYPE output_features() override { + return output_features_; + } + + static std::vector + GetAllNodeOutputs(const onnx::ModelProto& model_proto) { + std::vector nodes_outputs; + auto& graph = model_proto.graph(); + auto& nodes = graph.node(); + for (auto& node : nodes) { + for (auto& node_output : node.output()) { + nodes_outputs.push_back(node_output.c_str()); + } + } + return nodes_outputs; + } + + static std::vector + GetInitializers(const onnx::ModelProto& model_proto) { + std::vector initializers; + auto& graph = model_proto.graph(); + auto& graph_initializers = graph.initializer(); + for (auto& initializer : graph_initializers) { + initializers.push_back(initializer.name().c_str()); + } + return initializers; + } + + static std::vector + GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { + auto initializers = GetInitializers(model_proto); + + std::vector inputs_without_initializers; + auto& graph = model_proto.graph(); + auto& inputs = graph.input(); + for (auto& input : inputs) { + if (input.has_name() && input.has_type()) { + auto found_it = std::find_if( + std::begin(initializers), + std::end(initializers), + [&](auto& initializer) { + return std::strcmp(initializer, input.name().c_str()) == 0; + }); + + auto is_initializer = found_it != std::end(initializers); + if (!is_initializer) { + inputs_without_initializers.push_back(&input); + } + } + } + return inputs_without_initializers; + } + + static + std::vector GetOutputs(const onnx::ModelProto& model_proto) { + std::vector outputs_with_name; + auto& graph = model_proto.graph(); + auto& outputs = graph.output(); + for (auto& output : outputs) { + if (output.has_name() && output.has_type()) { + outputs_with_name.push_back(&output); + } + } + return outputs_with_name; + } + +private: + void Initialize(const onnx::ModelProto* model_proto) { + // metadata + for (auto& prop : model_proto->metadata_props()) { + model_metadata_[prop.key()] = prop.value(); + } + + WinML::FeatureDescriptorFactory builder(model_metadata_); + + // Create inputs + auto inputs = GetInputsWithoutInitializers(*model_proto); + input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); + + // Create outputs + auto outputs = GetOutputs(*model_proto); + output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); + + // author + auto has_producer_name = model_proto->has_producer_name(); + author_ = has_producer_name + ? model_proto->producer_name() + : ""; + + // domain + auto has_domain = model_proto->has_domain(); + domain_ = has_domain + ? model_proto->domain() + : ""; + + // name + auto has_graph = model_proto->has_graph(); + auto graph_has_name = model_proto->graph().has_name(); + auto is_name_available = has_graph && graph_has_name; + name_ = is_name_available + ? model_proto->graph().name() + : ""; + + // description + auto has_description = model_proto->has_doc_string(); + description_ = has_description + ? model_proto->doc_string() + : ""; + + // version + auto has_version = model_proto->has_model_version(); + version_ = has_version + ? model_proto->model_version() + : 0; + } +}; // class ModelInfo class WinMLAdapter : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, + Microsoft::WRL::RuntimeClassFlags, IWinMLAdapter> { +private: + std::shared_ptr lotus_environment_; + public: + WinMLAdapter() : lotus_environment_(PheonixSingleton()) { + + } + // factory methods for creating an ort model from a path HRESULT STDMETHODCALLTYPE CreateModelProto( const char* path, @@ -188,6 +358,12 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); } + HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) override { + auto model_info_outer = wil::MakeOrThrow(model_proto->get()); + return model_info_outer.CopyTo(__uuidof(IModelInfo), (void**)model_info); + } + + void STDMETHODCALLTYPE EnableDebugOutput() override { WinML::CWinMLLogSink::EnableDebugOutput(); } @@ -516,6 +692,19 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< return S_OK; } + // Override select shape inference functions which are incomplete in ONNX with versions that are complete, + // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being + // deferred until first evaluation. It also prevents a situation where inference functions in externally + // registered schema are reachable only after upstream schema have been revised in a later OS release, + // which would be a compatibility risk. + HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override { + static std::once_flag schema_override_once_flag; + std::call_once(schema_override_once_flag, []() { + SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); + }); + return S_OK; + } + }; diff --git a/winml/lib/Api.Core/inc/ModelInfo.h b/winml/lib/Api.Core/inc/ModelInfo.h index 95dd8a2988b6b..3546aee5416e5 100644 --- a/winml/lib/Api.Core/inc/ModelInfo.h +++ b/winml/lib/Api.Core/inc/ModelInfo.h @@ -3,27 +3,7 @@ #pragma once -#include "WinMLAdapter.h" - namespace Windows::AI::MachineLearning { -class ModelInfo { - public: - ModelInfo(const onnx::ModelProto* model_proto); - - public: - // model metadata - std::string author_; - std::string name_; - std::string domain_; - std::string description_; - int64_t version_; - std::unordered_map model_metadata_; - wfc::IVector input_features_; - wfc::IVector output_features_; - - private: - void Initialize(const onnx::ModelProto* model_proto); -}; } // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index 2ef2a3b9beac8..c31ffbd473dd9 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -4,9 +4,22 @@ #pragma once #include "IOrtSessionBuilder.h" +#include "ModelInfo.h" namespace Windows::AI::MachineLearning::Adapter { +MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ + // model metadata + virtual std::string STDMETHODCALLTYPE author() = 0; + virtual std::string STDMETHODCALLTYPE name() = 0; + virtual std::string STDMETHODCALLTYPE domain() = 0; + virtual std::string STDMETHODCALLTYPE description() = 0; + virtual int64_t STDMETHODCALLTYPE version() = 0; + virtual std::unordered_map STDMETHODCALLTYPE model_metadata() = 0; + virtual wfc::IVector STDMETHODCALLTYPE input_features() = 0; + virtual wfc::IVector STDMETHODCALLTYPE output_features() = 0; +}; + MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{ // these all return weak pointers virtual const onnxruntime::Tensor& STDMETHODCALLTYPE get() = 0; @@ -92,14 +105,11 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown ID3D12CommandQueue* queue, IOrtSessionBuilder** session_builder) = 0; - // factory methods for creating an ort model from a path + // factory methods for creating model protos virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0; - - // factory methods for creating an ort model from a stream virtual HRESULT STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0; - - // factory methods for creating an ort model from a model_proto virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0; + virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0; // Data types virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() = 0; @@ -142,7 +152,8 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown onnxruntime::MLDataType data_type, IOrtValue ** ort_value) = 0; - + // schema overrides (dml does this for us) + virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0; }; extern "C" diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 046bad17c029f..78f821e26e8be 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -5,10 +5,8 @@ #include "LearningModel.h" -#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" #include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" #include "ModelInfo.h" -#include "PheonixSingleton.h" #include "TelemetryEvent.h" #include "LotusEnvironment.h" @@ -21,21 +19,19 @@ namespace winrt::Windows::AI::MachineLearning::implementation { LearningModel::LearningModel( const hstring& path, const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(WinML::Strings::UTF8FromHString(path), - op_provider) {} + op_provider) { +} WINML_CATCH_ALL LearningModel::LearningModel( const std::string& path, - const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton()), - operator_provider_(operator_provider) { + const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::PerformanceTelemetryEvent kLoadModel_event( WinMLRuntimePerf::kLoadModel); - OverrideShapeInferenceMethods(); - - com_ptr<_winmla::IWinMLAdapter> adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->CreateModelProto(path.c_str(), model_proto_.put())); + WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); + WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); + WINML_THROW_IF_FAILED(adapter_->CreateModelProto(path.c_str(), model_proto_.put())); Initialize(); @@ -45,19 +41,16 @@ WINML_CATCH_ALL LearningModel::LearningModel( const wss::IRandomAccessStreamReference stream, - const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton()), - operator_provider_(operator_provider) { + const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::PerformanceTelemetryEvent kLoadModel_event( WinMLRuntimePerf::kLoadModel); - OverrideShapeInferenceMethods(); - - com_ptr<_winmla::IWinMLAdapter> adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->CreateModelProto( - static_cast(winrt::get_abi(stream)), + WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); + WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); + WINML_THROW_IF_FAILED(adapter_->CreateModelProto( + static_cast(winrt::get_abi(stream)), model_proto_.put())); - + Initialize(); LogCreationEvent(true); @@ -65,8 +58,7 @@ LearningModel::LearningModel( WINML_CATCH_ALL void LearningModel::Initialize() { - model_info_ = std::make_unique( - model_proto_.get()->get()); + WINML_THROW_IF_FAILED(adapter_->CreateModelInfo(model_proto_.get(), model_info_.put())); } void LearningModel::LogCreationEvent(bool fromStream) { @@ -80,13 +72,13 @@ void LearningModel::LogCreationEvent(bool fromStream) { } telemetry_helper.LogModelCreation( fromStream, - model_info_->author_, - model_info_->name_, - model_info_->domain_, - model_info_->description_, - model_info_->version_, + model_info_->author(), + model_info_->name(), + model_info_->domain(), + model_info_->description(), + model_info_->version(), use_fp16, - model_info_->model_metadata_); + model_info_->model_metadata()); } void LearningModel::ModelUseFP16( @@ -119,41 +111,41 @@ void LearningModel::ModelUseFP16( hstring LearningModel::Author() try { - return WinML::Strings::HStringFromUTF8(model_info_->author_); + return WinML::Strings::HStringFromUTF8(model_info_->author()); } WINML_CATCH_ALL hstring LearningModel::Name() try { return WinML::Strings::HStringFromUTF8( - model_info_->name_); + model_info_->name()); } WINML_CATCH_ALL hstring LearningModel::Domain() try { return WinML::Strings::HStringFromUTF8( - model_info_->domain_); + model_info_->domain()); } WINML_CATCH_ALL hstring LearningModel::Description() try { return WinML::Strings::HStringFromUTF8( - model_info_->description_); + model_info_->description()); } WINML_CATCH_ALL int64_t LearningModel::Version() try { - return model_info_->version_; + return model_info_->version(); } WINML_CATCH_ALL wfc::IMapView LearningModel::Metadata() try { std::unordered_map map_copy; - for (auto& pair : model_info_->model_metadata_) { + for (auto& pair : model_info_->model_metadata()) { auto key = WinML::Strings::HStringFromUTF8(pair.first); auto value = WinML::Strings::HStringFromUTF8(pair.second); map_copy.emplace(std::move(key), std::move(value)); @@ -183,13 +175,13 @@ LearningModel::GetOperatorRegistry() { wfc::IVectorView LearningModel::InputFeatures() try { - return model_info_->input_features_.GetView(); + return model_info_->input_features().GetView(); } WINML_CATCH_ALL wfc::IVectorView LearningModel::OutputFeatures() try { - return model_info_->output_features_.GetView(); + return model_info_->output_features().GetView(); } WINML_CATCH_ALL @@ -287,18 +279,6 @@ LearningModel::CopyModelProto() { return model_proto.detach(); } -static std::once_flag g_schema_override_once_flag; - -// Override select shape inference functions which are incomplete in ONNX with versions that are complete, -// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being -// deferred until first evaluation. It also prevents a situation where inference functions in externally -// registered schema are reachable only after upstream schema have been revised in a later OS release, -// which would be a compatibility risk. -void LearningModel::OverrideShapeInferenceMethods() { - std::call_once(g_schema_override_once_flag, []() { - SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); - }); -} } // namespace winrt::Windows::AI::MachineLearning::implementation namespace winrt::Windows::AI::MachineLearning::factory_implementation { diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index ef08359b6c730..c30057ea5c043 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -6,11 +6,6 @@ #include "LearningModel.g.h" #include "WinMLAdapter.h" - namespace Windows::AI::MachineLearning { - class LotusEnvironment; - class ModelInfo; -} // namespace Windows::AI::MachineLearning - namespace winrt::Windows::AI::MachineLearning::implementation { struct LearningModel : LearningModelT { @@ -121,13 +116,10 @@ struct LearningModel : LearningModelT { winml::ILearningModelFeatureDescriptor descriptor, bool& use_fp16); - void - OverrideShapeInferenceMethods(); - private: - std::shared_ptr lotus_environment_; + com_ptr<_winmla::IWinMLAdapter> adapter_; com_ptr<_winmla::IModelProto> model_proto_; - std::unique_ptr model_info_; + com_ptr<_winmla::IModelInfo> model_info_; ILearningModelOperatorProvider operator_provider_; }; From f32bbd5cb79b0d11d73ba640baff56c8ea76603d Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 13:15:57 -0800 Subject: [PATCH 2/6] weak ref comment --- winml/lib/Api.Core/inc/WinMLAdapter.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index c31ffbd473dd9..3c74a3726ba0c 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -38,11 +38,12 @@ MIDL_INTERFACE("72aa5eee-100c-4146-9008-4643d3b8af23") IOrtValue : IUnknown{ virtual OrtValue& STDMETHODCALLTYPE get() = 0; virtual onnxruntime::MLDataType STDMETHODCALLTYPE Type() = 0; virtual bool STDMETHODCALLTYPE IsTensor() = 0; -// end + // end virtual HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) = 0; }; MIDL_INTERFACE("438e7719-554a-4058-84d9-eb6226c34887") IIOBinding : IUnknown{ + // this returns a weak ref virtual onnxruntime::IOBinding* STDMETHODCALLTYPE get() = 0; virtual HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, IOrtValue * ml_value) = 0; virtual HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, IOrtValue * ml_value) = 0; From 7f9a7f5abef28851a5f1dfb8eab920e6f4eb8a91 Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 16:47:33 -0800 Subject: [PATCH 3/6] added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects. fixes model load. --- .../lib/Api.Core/FeatureDescriptorFactory.cpp | 92 +++++++++++++++++++ winml/lib/Api.Core/WinMLAdapter.cpp | 2 +- 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp index 62ea8aba96a9e..cdc5a1b5bcf5a 100644 --- a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp +++ b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp @@ -41,6 +41,98 @@ static const char* c_supported_nominal_ranges[] = "NominalRange_0_255"}; namespace Windows::AI::MachineLearning { + + +// since this code is now running inside ONNXRUNTIME we need to shortcut +// this a bit when creating winrt objects. This will help. + +/* extern "C" +HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; + +#ifdef _M_IX86 +#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") +#else +#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") +#endif +*/ + +bool starts_with(std::wstring_view value, std::wstring_view match) noexcept +{ + return 0 == value.compare(0, match.size(), match); +} + +EXTERN_C IMAGE_DOS_HEADER __ImageBase; + +std::wstring GetModulePath() +{ + std::wstring val; + wchar_t modulePath[MAX_PATH] = { 0 }; + GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath)); + wchar_t drive[_MAX_DRIVE]; + wchar_t dir[_MAX_DIR]; + wchar_t filename[_MAX_FNAME]; + wchar_t ext[_MAX_EXT]; + _wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); + + val = drive; + val += dir; + + return val; +} + +extern "C" +int32_t WINRT_CALL WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept +{ + *factory = nullptr; + HSTRING classId_hstring = (HSTRING)classId; + std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; + HMODULE library{ nullptr }; + + std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll"; + + if (starts_with(name, L"Windows.AI.MachineLearning.")) + { + const wchar_t* libPath = winmlDllPath.c_str(); + library = LoadLibraryW(libPath); + } + else + { + return RoGetActivationFactory(classId_hstring, iid, factory); + } + + if (!library) + { + return HRESULT_FROM_WIN32(GetLastError()); + } + + using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); + auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); + + if (!call) + { + HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + winrt::com_ptr activation_factory; + HRESULT const hr = call(classId_hstring, activation_factory.put_void()); + + if (FAILED(hr)) + { + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + if (winrt::guid(iid) != winrt::guid_of()) + { + return activation_factory->QueryInterface(iid, factory); + } + + *factory = activation_factory.detach(); + return S_OK; +} + // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index b40318f381537..a60118d5e466f 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -327,7 +327,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< auto model_proto_inner = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, - !model_proto_inner->ParseFromZeroCopyStream(&stream) == false, + model_proto_inner->ParseFromZeroCopyStream(&stream) == false, "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); From b4047a0aad7dcdeffa5731b917ae58f0c4903fd1 Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Mon, 18 Nov 2019 09:50:25 -0800 Subject: [PATCH 4/6] fixed some lifetime management. fixed the debug build. squeezenet passes using winmlrunner for CPU and GPU --- cmake/winml.cmake | 4 -- .../src/ErrorHandling.cpp | 10 ---- winml/lib/Api.Core/CpuOrtSessionBuilder.cpp | 14 +++--- winml/lib/Api.Core/CpuOrtSessionBuilder.h | 4 +- winml/lib/Api.Core/DmlOrtSessionBuilder.cpp | 15 +++--- winml/lib/Api.Core/DmlOrtSessionBuilder.h | 4 +- winml/lib/Api.Core/ModelInfo.cpp | 14 ------ winml/lib/Api.Core/OrtSessionBuilder.cpp | 5 -- winml/lib/Api.Core/WinMLAdapter.cpp | 28 ++++++----- winml/lib/Api.Core/inc/IOrtSessionBuilder.h | 5 -- winml/lib/Api.Core/inc/ModelInfo.h | 9 ---- winml/lib/Api.Core/inc/WinMLAdapter.h | 47 ++++++++++++++----- winml/lib/Api/LearningModel.cpp | 1 - winml/lib/Api/LearningModelSession.cpp | 14 ++---- 14 files changed, 75 insertions(+), 99 deletions(-) delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.cpp delete mode 100644 winml/lib/Api.Core/ModelInfo.cpp delete mode 100644 winml/lib/Api.Core/OrtSessionBuilder.cpp delete mode 100644 winml/lib/Api.Core/inc/IOrtSessionBuilder.h delete mode 100644 winml/lib/Api.Core/inc/ModelInfo.h diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 3fa0f4e72afcd..c0a370013e8b0 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -122,10 +122,8 @@ target_link_libraries(winml_lib_telemetry PRIVATE wil) add_library(winml_lib_core STATIC ${winml_lib_api_core_dir}/inc/AbiCustomRegistryImpl.h ${winml_lib_api_core_dir}/inc/CustomRegistryHelper.h - ${winml_lib_api_core_dir}/inc/IOrtSessionBuilder.h ${winml_lib_api_core_dir}/inc/LotusEnvironment.h ${winml_lib_api_core_dir}/inc/MLValueHelpers.h - ${winml_lib_api_core_dir}/inc/ModelInfo.h ${winml_lib_api_core_dir}/inc/TensorBaseHelpers.h ${winml_lib_api_core_dir}/inc/WinMLAdapter.h ${winml_lib_api_core_dir}/CpuOrtSessionBuilder.h @@ -138,8 +136,6 @@ add_library(winml_lib_core STATIC ${winml_lib_api_core_dir}/CpuOrtSessionBuilder.cpp ${winml_lib_api_core_dir}/DmlOrtSessionBuilder.cpp ${winml_lib_api_core_dir}/LotusEnvironment.cpp - ${winml_lib_api_core_dir}/ModelInfo.cpp - ${winml_lib_api_core_dir}/OrtSessionBuilder.cpp ${winml_lib_api_core_dir}/WinMLAdapter.cpp ${winml_lib_api_core_dir}/ZeroCopyInputStreamWrapper.cpp ) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.cpp deleted file mode 100644 index 88943399be823..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "precomp.h" - -namespace Dml -{ - - -} \ No newline at end of file diff --git a/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp b/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp index f57bed7d9670a..7b7002a2a39ce 100644 --- a/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp +++ b/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp @@ -32,24 +32,26 @@ CpuOrtSessionBuilder::CpuOrtSessionBuilder() { HRESULT CpuOrtSessionBuilder::CreateSessionOptions( - onnxruntime::SessionOptions* p_options) { + ISessionOptions** p_options) { RETURN_HR_IF_NULL(E_POINTER, p_options); - *p_options = onnxruntime::SessionOptions(); - p_options->graph_optimization_level = onnxruntime::TransformerLevel::Level3; + auto options = wil::MakeOrThrow(); + options.CopyTo(__uuidof(ISessionOptions), (void**)p_options); + + (*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3; // Onnxruntime will use half the number of concurrent threads supported on the system // by default. This causes MLAS to not exercise every logical core. // We force the thread pool size to be maxxed out to ensure that WinML always // runs the fastest. - p_options->intra_op_num_threads = std::thread::hardware_concurrency(); + (*p_options)->get().intra_op_num_threads = std::thread::hardware_concurrency(); return S_OK; } HRESULT CpuOrtSessionBuilder::CreateSession( - const onnxruntime::SessionOptions& options, + ISessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) { RETURN_HR_IF_NULL(E_POINTER, p_session); @@ -57,7 +59,7 @@ CpuOrtSessionBuilder::CreateSession( RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); // Create the inference session - auto session = std::make_unique(options); + auto session = std::make_unique(options->get()); // Create the cpu execution provider onnxruntime::CPUExecutionProviderInfo xpInfo; diff --git a/winml/lib/Api.Core/CpuOrtSessionBuilder.h b/winml/lib/Api.Core/CpuOrtSessionBuilder.h index 222541fb044ca..5612656970189 100644 --- a/winml/lib/Api.Core/CpuOrtSessionBuilder.h +++ b/winml/lib/Api.Core/CpuOrtSessionBuilder.h @@ -15,10 +15,10 @@ class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < CpuOrtSessionBuilder(); HRESULT STDMETHODCALLTYPE CreateSessionOptions( - onnxruntime::SessionOptions* p_options) override; + ISessionOptions** p_options) override; HRESULT STDMETHODCALLTYPE CreateSession( - const onnxruntime::SessionOptions& options, + ISessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) override; diff --git a/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp b/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp index 85111aeacafc0..b9143fb2a8a97 100644 --- a/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp +++ b/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp @@ -39,15 +39,16 @@ DmlOrtSessionBuilder::DmlOrtSessionBuilder( HRESULT DmlOrtSessionBuilder::CreateSessionOptions( - onnxruntime::SessionOptions* p_options) { + ISessionOptions** p_options) { RETURN_HR_IF_NULL(E_POINTER, p_options); - *p_options = onnxruntime::SessionOptions(); - - p_options->graph_optimization_level = onnxruntime::TransformerLevel::Level3; + auto options = wil::MakeOrThrow(); + options.CopyTo(__uuidof(ISessionOptions), (void**)p_options); + + (*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3; // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated. - p_options->enable_mem_pattern = false; + (*p_options)->get().enable_mem_pattern = false; return S_OK; } @@ -100,7 +101,7 @@ Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { } HRESULT DmlOrtSessionBuilder::CreateSession( - const onnxruntime::SessionOptions& options, + ISessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) { RETURN_HR_IF_NULL(E_POINTER, p_session); @@ -113,7 +114,7 @@ HRESULT DmlOrtSessionBuilder::CreateSession( Microsoft::WRL::ComPtr dmlDevice = CreateDmlDevice(p_d3d_device); std::unique_ptr gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue); - auto session = std::make_unique(options); + auto session = std::make_unique(options->get()); // Cache the provider's raw pointer *pp_provider = gpu_provider.get(); diff --git a/winml/lib/Api.Core/DmlOrtSessionBuilder.h b/winml/lib/Api.Core/DmlOrtSessionBuilder.h index 6e0cffadb7c8d..2788a41f7b2e7 100644 --- a/winml/lib/Api.Core/DmlOrtSessionBuilder.h +++ b/winml/lib/Api.Core/DmlOrtSessionBuilder.h @@ -15,10 +15,10 @@ class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue); HRESULT STDMETHODCALLTYPE CreateSessionOptions( - onnxruntime::SessionOptions* p_options) override; + ISessionOptions** p_options) override; HRESULT STDMETHODCALLTYPE CreateSession( - const onnxruntime::SessionOptions& options, + ISessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) override; diff --git a/winml/lib/Api.Core/ModelInfo.cpp b/winml/lib/Api.Core/ModelInfo.cpp deleted file mode 100644 index 46e9bc29bfa3d..0000000000000 --- a/winml/lib/Api.Core/ModelInfo.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" - -#include "inc/ModelInfo.h" - - -#include "FeatureDescriptorFactory.h" - -using namespace Windows::AI::MachineLearning; - - - diff --git a/winml/lib/Api.Core/OrtSessionBuilder.cpp b/winml/lib/Api.Core/OrtSessionBuilder.cpp deleted file mode 100644 index 770eac443d7c6..0000000000000 --- a/winml/lib/Api.Core/OrtSessionBuilder.cpp +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" - diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index a60118d5e466f..6bff0e6c97011 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -44,8 +44,6 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi } }; - - // class AbiSafeTensor // class AbiSafeTensor : public Microsoft::WRL::RuntimeClass < @@ -112,8 +110,7 @@ class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) override { auto tensor_inner = ort_value_.GetMutable(); auto tensor_outer = wil::MakeOrThrow(tensor_inner, this); - *tensor = tensor_outer.Detach(); - return S_OK; + return tensor_outer.CopyTo(__uuidof(ITensor), (void**)tensor); } }; // class AbiSafeOrtValue @@ -130,8 +127,12 @@ class ModelProto : public Microsoft::WRL::RuntimeClass< return model_proto_.get(); } + onnx::ModelProto* STDMETHODCALLTYPE detach() override { + return model_proto_.release(); + } + private: - std::shared_ptr model_proto_; + std::unique_ptr model_proto_; }; // class ModelProto @@ -155,28 +156,28 @@ class ModelInfo : public Microsoft::WRL::RuntimeClass< Initialize(model_proto); } - std::string STDMETHODCALLTYPE author() override { + std::string& STDMETHODCALLTYPE author() override { return author_; } - std::string STDMETHODCALLTYPE name() override { + std::string& STDMETHODCALLTYPE name() override { return name_; } - std::string STDMETHODCALLTYPE domain() override { + std::string& STDMETHODCALLTYPE domain() override { return domain_; } - std::string STDMETHODCALLTYPE description() override { + std::string& STDMETHODCALLTYPE description() override { return description_; } int64_t STDMETHODCALLTYPE version() override { return version_; } - std::unordered_map STDMETHODCALLTYPE model_metadata() override { + std::unordered_map& STDMETHODCALLTYPE model_metadata() override { return model_metadata_; } - wfc::IVector STDMETHODCALLTYPE input_features() override { + wfc::IVector& STDMETHODCALLTYPE input_features() override { return input_features_; } - wfc::IVector STDMETHODCALLTYPE output_features() override { + wfc::IVector& STDMETHODCALLTYPE output_features() override { return output_features_; } @@ -804,7 +805,8 @@ InferenceSession::LoadModel( IModelProto* model_proto) { auto session_protected_load_accessor = static_cast(session_.get()); - std::unique_ptr model_proto_ptr(model_proto->get()); + // session's like to have their very own copy of the model_proto, use detach() + std::unique_ptr model_proto_ptr(model_proto->detach()); ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); return S_OK; } diff --git a/winml/lib/Api.Core/inc/IOrtSessionBuilder.h b/winml/lib/Api.Core/inc/IOrtSessionBuilder.h deleted file mode 100644 index ccc824cef09af..0000000000000 --- a/winml/lib/Api.Core/inc/IOrtSessionBuilder.h +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - diff --git a/winml/lib/Api.Core/inc/ModelInfo.h b/winml/lib/Api.Core/inc/ModelInfo.h deleted file mode 100644 index 3546aee5416e5..0000000000000 --- a/winml/lib/Api.Core/inc/ModelInfo.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -namespace Windows::AI::MachineLearning { - - -} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index 3c74a3726ba0c..65cf1a51ce2b1 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -3,21 +3,18 @@ #pragma once -#include "IOrtSessionBuilder.h" -#include "ModelInfo.h" - namespace Windows::AI::MachineLearning::Adapter { MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ // model metadata - virtual std::string STDMETHODCALLTYPE author() = 0; - virtual std::string STDMETHODCALLTYPE name() = 0; - virtual std::string STDMETHODCALLTYPE domain() = 0; - virtual std::string STDMETHODCALLTYPE description() = 0; + virtual std::string& STDMETHODCALLTYPE author() = 0; + virtual std::string& STDMETHODCALLTYPE name() = 0; + virtual std::string& STDMETHODCALLTYPE domain() = 0; + virtual std::string& STDMETHODCALLTYPE description() = 0; virtual int64_t STDMETHODCALLTYPE version() = 0; - virtual std::unordered_map STDMETHODCALLTYPE model_metadata() = 0; - virtual wfc::IVector STDMETHODCALLTYPE input_features() = 0; - virtual wfc::IVector STDMETHODCALLTYPE output_features() = 0; + virtual std::unordered_map& STDMETHODCALLTYPE model_metadata() = 0; + virtual wfc::IVector& STDMETHODCALLTYPE input_features() = 0; + virtual wfc::IVector& STDMETHODCALLTYPE output_features() = 0; }; MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{ @@ -53,7 +50,10 @@ MIDL_INTERFACE("438e7719-554a-4058-84d9-eb6226c34887") IIOBinding : IUnknown{ }; MIDL_INTERFACE("a848faf6-5a2e-4a7f-b622-cc036f71e28a") IModelProto : IUnknown{ + // this returns a weak ref virtual onnx::ModelProto* STDMETHODCALLTYPE get() = 0; + // this returns the ownership without touching the reference and forgets about the object + virtual onnx::ModelProto* STDMETHODCALLTYPE detach() = 0; }; MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnknown { @@ -70,15 +70,22 @@ MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnkn virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0; }; +MIDL_INTERFACE("55a956a7-c20e-440d-b2d2-a77acf35de10") ISessionOptions : IUnknown{ + // this returns a weak ref + virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() = 0; + // end + virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) = 0; +}; + // The IOrtSessionBuilder offers an abstraction over the creation of // InferenceSession, that enables the creation of the session based on a device (CPU/DML). MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown { virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( - onnxruntime::SessionOptions* options) = 0; + ISessionOptions** options) = 0; virtual HRESULT STDMETHODCALLTYPE CreateSession( - const onnxruntime::SessionOptions& options, + ISessionOptions* options, IInferenceSession** session, onnxruntime::IExecutionProvider** provider) = 0; @@ -184,6 +191,22 @@ class InferenceSession : public Microsoft::WRL::RuntimeClass < std::shared_ptr session_; }; +class AbiSafeSessionOptions : public Microsoft::WRL::RuntimeClass < + Microsoft::WRL::RuntimeClassFlags, + ISessionOptions> { +private: + onnxruntime::SessionOptions options_; +public: + virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() override { + return options_; + } + virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) override { + onnxruntime::FreeDimensionOverride overrideOption = {}; + overrideOption.dimension_denotation = onnx::DATA_BATCH; + overrideOption.dimension_override = batch_size; + options_.free_dimension_overrides.emplace_back(overrideOption); + } +}; // header only code to enable smart pointers on abstract ort objects template diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 78f821e26e8be..7636d33fd667b 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -6,7 +6,6 @@ #include "LearningModel.h" #include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" -#include "ModelInfo.h" #include "TelemetryEvent.h" #include "LotusEnvironment.h" diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 81e9e6abba69e..21f04c2bc68e2 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -6,7 +6,6 @@ #include "LearningModelSession.h" #include "ImageFeatureDescriptor.h" -#include "IOrtSessionBuilder.h" #include "WinMLAdapter.h" #include "LearningModel.h" #include "LearningModelBinding.h" @@ -110,21 +109,18 @@ void LearningModelSession::Initialize() { device_impl->GetDeviceQueue(), session_builder.put())); - onnxruntime::SessionOptions options = {}; - WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(&options)); + com_ptr<_winmla::ISessionOptions> options; + WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put())); // Make onnxruntime apply the batch size override, if any if (session_options_ && session_options_.BatchSizeOverride() != 0) { - onnxruntime::FreeDimensionOverride overrideOption = {}; - overrideOption.dimension_denotation = onnx::DATA_BATCH; - overrideOption.dimension_override = session_options_.BatchSizeOverride(); - options.free_dimension_overrides.emplace_back(overrideOption); + options->SetBatchOverride(session_options_.BatchSizeOverride()); } com_ptr<_winmla::IInferenceSession> session; WINML_THROW_IF_FAILED(session_builder->CreateSession( - options, session.put(), &cached_execution_provider_)); + options.get(), session.put(), &cached_execution_provider_)); // Register the custom operator registry auto model = model_.as(); @@ -136,7 +132,7 @@ void LearningModelSession::Initialize() { // Load the model into the session WINML_THROW_IF_FAILED(session->LoadModel(model_proto.get())); - // the session owns the model_proto now + // the session owns the model_proto now, it used detach() model_proto = nullptr; // Initialize the session From a3542e112871f4bb0262ef0d77bcb174c3d44cb9 Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Mon, 18 Nov 2019 11:06:08 -0800 Subject: [PATCH 5/6] PR feedback. --- cmake/onnxruntime.cmake | 2 +- cmake/winml.cmake | 16 +--- winml/lib/Api.Core/WinMLAdapter.cpp | 27 ++++-- winml/lib/Api/LearningModelSession.cpp | 114 ++++++++++++------------- 4 files changed, 78 insertions(+), 81 deletions(-) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 17abdf86377f8..ca91602aa8cbf 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -66,7 +66,7 @@ target_link_libraries(onnxruntime PRIVATE ${PROVIDERS_NUPHAR} ${PROVIDERS_DML} ${PROVIDERS_ACL} - ${onnxruntime_winml} + ${onnxruntime_winml} onnxruntime_optimizer onnxruntime_providers onnxruntime_util diff --git a/cmake/winml.cmake b/cmake/winml.cmake index c0a370013e8b0..f4748c8b8b473 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -467,26 +467,14 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") endif("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") # Link libraries -#target_link_libraries(winml_dll PRIVATE libprotobuf) -#target_link_libraries(winml_dll PRIVATE onnx) -#target_link_libraries(winml_dll PRIVATE onnxruntime_common) -#target_link_libraries(winml_dll PRIVATE onnxruntime_graph) -#target_link_libraries(winml_dll PRIVATE onnxruntime_framework) -#target_link_libraries(winml_dll PRIVATE onnxruntime_mlas) -#target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer) -#target_link_libraries(winml_dll PRIVATE onnxruntime_providers) -#target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml) -#target_link_libraries(winml_dll PRIVATE onnxruntime_session) -#target_link_libraries(winml_dll PRIVATE onnxruntime_util) -#target_link_libraries(winml_dll PRIVATE onnx_proto) target_link_libraries(winml_dll PRIVATE onnxruntime) target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) -target_link_libraries(winml_dll PRIVATE windowsapp.lib) +#target_link_libraries(winml_dll PRIVATE windowsapp.lib) target_link_libraries(winml_dll PRIVATE winml_lib_api) -#target_link_libraries(winml_dll PRIVATE winml_lib_core) target_link_libraries(winml_dll PRIVATE winml_lib_image) target_link_libraries(winml_dll PRIVATE winml_lib_telemetry) +target_link_libraries(winml_dll PRIVATE onecoreuap_apiset.lib) target_link_libraries(winml_dll PRIVATE ${DBGHELP}) # 1 of 3 projects that fail in link with 'failed to do memory mapped file I/O' (Only release) diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index 6bff0e6c97011..8646ca314d9af 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -110,7 +110,7 @@ class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) override { auto tensor_inner = ort_value_.GetMutable(); auto tensor_outer = wil::MakeOrThrow(tensor_inner, this); - return tensor_outer.CopyTo(__uuidof(ITensor), (void**)tensor); + return tensor_outer.CopyTo(__uuidof(ITensor), reinterpret_cast(tensor)); } }; // class AbiSafeOrtValue @@ -310,6 +310,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< const char* path, IModelProto** model_proto) override { int file_descriptor; + _set_errno(0); // clear errno _sopen_s( &file_descriptor, path, @@ -317,6 +318,14 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< _SH_DENYWR, _S_IREAD | _S_IWRITE); + errno_t err = 0; + _get_errno(&err); + THROW_HR_IF_MSG( + __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), + err == ENOENT, + "File not found: %s", + path); + THROW_HR_IF_MSG( E_FAIL, 0 > file_descriptor, @@ -332,7 +341,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); } // factory methods for creating an ort model from a stream @@ -349,19 +358,19 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); } // factory methods for creating an ort model from a model_proto HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) override { auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); } HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) override { auto model_info_outer = wil::MakeOrThrow(model_proto->get()); - return model_info_outer.CopyTo(__uuidof(IModelInfo), (void**)model_info); + return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast(model_info)); } @@ -471,10 +480,10 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< if (device == nullptr) { auto builder = wil::MakeOrThrow(); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), (void**)session_builder); + return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); } else { auto builder = wil::MakeOrThrow(device, queue); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), (void**)session_builder); + return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); } } @@ -713,7 +722,7 @@ extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) { // make an adapter instance Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); - return adapterptr.CopyTo(__uuidof(IWinMLAdapter), (void **)adapter); + return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast(adapter)); } @@ -782,7 +791,7 @@ HRESULT STDMETHODCALLTYPE InferenceSession::NewIOBinding(IIOBinding** io_binding std::unique_ptr binding; ORT_THROW_IF_ERROR(this->session_->NewIOBinding(&binding)); auto io_binding_outer = wil::MakeOrThrow(binding.release()); - return io_binding_outer.CopyTo(__uuidof(IIOBinding), (void**)io_binding); + return io_binding_outer.CopyTo(__uuidof(IIOBinding), reinterpret_cast(io_binding)); } HRESULT STDMETHODCALLTYPE InferenceSession::Run(const onnxruntime::RunOptions* run_options, IIOBinding* io_binding) { diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 21f04c2bc68e2..fbedef52106c9 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -29,23 +29,23 @@ static const GUID WINML_PIX_EVAL_CAPTURABLE_WORK_GUID = __uuidof(guid_details::W namespace winrt::Windows::AI::MachineLearning::implementation { LearningModelSession::LearningModelSession( - winml::LearningModel const& model) try : LearningModelSession(model, - make(LearningModelDeviceKind::Default)) {} + winml::LearningModel const& model) try : LearningModelSession(model, + make(LearningModelDeviceKind::Default)) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, - deviceToRunOn, - nullptr) {} + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, + deviceToRunOn, + nullptr) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn, - winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), - device_(deviceToRunOn), - session_options_(learningModelSessionOptions) { + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn, + winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), + device_(deviceToRunOn), + session_options_(learningModelSessionOptions) { Initialize(); } WINML_CATCH_ALL @@ -55,8 +55,8 @@ LearningModelSession::GetOptimizedModel() { // Get the model proto auto should_close_model = - session_options_ != nullptr && - session_options_.CloseModelOnSessionCreation(); + session_options_ != nullptr && + session_options_.CloseModelOnSessionCreation(); return GetOptimizedModel(should_close_model); } @@ -91,7 +91,7 @@ LearningModelSession::GetOptimizedModel(bool should_close_model) { void LearningModelSession::Initialize() { // Begin recording session creation telemetry _winmlt::TelemetryEvent session_creation_event( - _winmlt::EventCategory::kSessionCreation); + _winmlt::EventCategory::kSessionCreation); // Get the optimized model proto from the learning model com_ptr<_winmla::IModelProto> model_proto; @@ -105,9 +105,9 @@ void LearningModelSession::Initialize() { com_ptr<_winmla::IOrtSessionBuilder> session_builder; WINML_THROW_IF_FAILED(adapter->CreateOrtSessionBuilder( - device_impl->GetD3DDevice(), - device_impl->GetDeviceQueue(), - session_builder.put())); + device_impl->GetD3DDevice(), + device_impl->GetDeviceQueue(), + session_builder.put())); com_ptr<_winmla::ISessionOptions> options; WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put())); @@ -115,7 +115,7 @@ void LearningModelSession::Initialize() { // Make onnxruntime apply the batch size override, if any if (session_options_ && session_options_.BatchSizeOverride() != 0) { - options->SetBatchOverride(session_options_.BatchSizeOverride()); + options->SetBatchOverride(session_options_.BatchSizeOverride()); } com_ptr<_winmla::IInferenceSession> session; @@ -142,9 +142,9 @@ void LearningModelSession::Initialize() { inference_session_ = session; telemetry_helper.LogSessionCreation( - WinML::Strings::UTF8FromHString(model_.Name()), - device_impl->IsCpuDevice(), - device_impl->GetDeviceLuid()); + WinML::Strings::UTF8FromHString(model_.Name()), + device_impl->IsCpuDevice(), + device_impl->GetDeviceLuid()); } wfc::IPropertySet @@ -169,8 +169,8 @@ LearningModelSession::Device() try { WINML_CATCH_ALL auto CreateBinding( - LearningModelSession& session, - wfc::IMap const features) { + LearningModelSession& session, + wfc::IMap const features) { auto binding = winrt::make(session); for (auto feature : features.GetView()) { @@ -181,8 +181,8 @@ auto CreateBinding( winml::LearningModelEvaluationResult LearningModelSession::EvaluateFeatures( - wfc::IMap const features, - hstring const correlation_id) try { + wfc::IMap const features, + hstring const correlation_id) try { auto binding = CreateBinding(*this, features); return Evaluate(binding, correlation_id); } @@ -190,53 +190,53 @@ WINML_CATCH_ALL wf::IAsyncOperation LearningModelSession::EvaluateFeaturesAsync( - wfc::IMap const features, - hstring const correlation_id) { + wfc::IMap const features, + hstring const correlation_id) { auto binding = CreateBinding(*this, features); return EvaluateAsync(binding, correlation_id); } static _winmla::IIOBinding* GetIOBinding( - winrt::com_ptr binding_impl, - winml::LearningModel& model) { + winrt::com_ptr binding_impl, + winml::LearningModel& model) { // Get the IOBinding Collection, and bound outputs - com_ptr<_winmla::IIOBinding> io_binding; - io_binding.attach(binding_impl->BindingCollection()); + com_ptr<_winmla::IIOBinding> io_binding; + io_binding.attach(binding_impl->BindingCollection()); auto& bound_output_names = io_binding->GetOutputNames(); std::unordered_set bound_output_names_set( - bound_output_names.begin(), - bound_output_names.end()); + bound_output_names.begin(), + bound_output_names.end()); // Get model output feature names auto model_impl = model.as(); auto output_features = model_impl->OutputFeatures(); std::vector output_descriptors( - begin(output_features), - end(output_features)); + begin(output_features), + end(output_features)); // Convert all output features to their feature names std::vector output_feature_names; std::transform( - std::begin(output_descriptors), - std::end(output_descriptors), - std::back_inserter(output_feature_names), - [&](auto& descriptor) { - auto descriptor_native = descriptor.as(); - const wchar_t* p_name; - uint32_t size; - WINML_THROW_IF_FAILED(descriptor_native->GetName(&p_name, &size)); - return WinML::Strings::UTF8FromUnicode(p_name, size); - }); + std::begin(output_descriptors), + std::end(output_descriptors), + std::back_inserter(output_feature_names), + [&](auto& descriptor) { + auto descriptor_native = descriptor.as(); + const wchar_t* p_name; + uint32_t size; + WINML_THROW_IF_FAILED(descriptor_native->GetName(&p_name, &size)); + return WinML::Strings::UTF8FromUnicode(p_name, size); + }); // Find the set difference to determine if there are any unbound output features std::vector unbound_output_names; std::copy_if( - std::begin(output_feature_names), std::end(output_feature_names), - std::inserter(unbound_output_names, std::begin(unbound_output_names)), - [&](const auto& outputFeatureName) { - return bound_output_names_set.find(outputFeatureName) == bound_output_names_set.end(); - }); + std::begin(output_feature_names), std::end(output_feature_names), + std::inserter(unbound_output_names, std::begin(unbound_output_names)), + [&](const auto& outputFeatureName) { + return bound_output_names_set.find(outputFeatureName) == bound_output_names_set.end(); + }); // Add all unbound outputs to the iobinding collection for (const auto& unbound_output : unbound_output_names) { @@ -274,9 +274,9 @@ LearningModelSession::Run( winml::LearningModelEvaluationResult LearningModelSession::GetResults( - winrt::com_ptr binding_impl, - hstring const& correlation_id, - uint64_t evaluation_complete_fence) { + winrt::com_ptr binding_impl, + hstring const& correlation_id, + uint64_t evaluation_complete_fence) { // First wait on the fence value for the expected frame. This is passed in so that // the fence value is added to the queue in a thread safe manor. auto device = device_.as(); @@ -418,10 +418,10 @@ LearningModelSession::CreateSessionBinding() { void LearningModelSession::ApplyEvaluationProperties() try { if (evaluation_properties_) { auto is_debug_output_enabled = evaluation_properties_.HasKey(c_enable_debug_output); - if (is_debug_output_enabled) { - com_ptr<_winmla::IWinMLAdapter> adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - adapter->EnableDebugOutput(); + if (is_debug_output_enabled) { + com_ptr<_winmla::IWinMLAdapter> adapter; + WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); + adapter->EnableDebugOutput(); } } } From acc6ea525b4b6f903ef392d2a7cbcbcab829f01f Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Mon, 18 Nov 2019 18:31:04 -0800 Subject: [PATCH 6/6] couple of fixes and coded getmutabledata() --- winml/lib/Api.Core/WinMLAdapter.cpp | 1384 +++++++++++++------------ winml/lib/Api.Core/inc/WinMLAdapter.h | 8 +- winml/lib/Api/impl/SequenceBase.h | 20 +- 3 files changed, 717 insertions(+), 695 deletions(-) diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index 8646ca314d9af..ea8c1086c5873 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -29,7 +29,6 @@ #include "FeatureDescriptorFactory.h" - using namespace winrt::Windows::AI::MachineLearning; namespace Windows::AI::MachineLearning::Adapter { @@ -46,806 +45,821 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi // class AbiSafeTensor // -class AbiSafeTensor : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - ITensor> { -private: - onnxruntime::Tensor& tensor_; // weak ref - ComPtr value_; // strong ref - -public: +class AbiSafeTensor : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + ITensor> { + private: + onnxruntime::Tensor& tensor_; // weak ref + ComPtr value_; // strong ref - AbiSafeTensor(onnxruntime::Tensor* tensor, - IOrtValue * value_in) : tensor_(*tensor), value_(value_in) { - } - const onnxruntime::Tensor& STDMETHODCALLTYPE get() override { - return tensor_; - } - onnxruntime::Tensor* STDMETHODCALLTYPE getMutable() override { - return &tensor_; - } - onnxruntime::MLDataType STDMETHODCALLTYPE DataType() override { - return tensor_.DataType(); - } - const void* STDMETHODCALLTYPE DataRaw() override { - return tensor_.DataRaw(); - } - const std::vector& STDMETHODCALLTYPE ShapeGetDims() override { - return tensor_.Shape().GetDims(); - } - int64_t STDMETHODCALLTYPE ShapeSize() override { - return tensor_.Shape().Size(); - } - const char * STDMETHODCALLTYPE LocationName() override { - return tensor_.Location().name; - } - OrtMemType STDMETHODCALLTYPE LocationMemType() override { - return tensor_.Location().mem_type; - } + public: + AbiSafeTensor(onnxruntime::Tensor* tensor, + IOrtValue* value_in) : tensor_(*tensor), value_(value_in) { + } + const onnxruntime::Tensor& STDMETHODCALLTYPE get() override { + return tensor_; + } + onnxruntime::Tensor* STDMETHODCALLTYPE getMutable() override { + return &tensor_; + } + onnxruntime::MLDataType STDMETHODCALLTYPE DataType() override { + return tensor_.DataType(); + } + const void* STDMETHODCALLTYPE DataRaw() override { + return tensor_.DataRaw(); + } + const std::vector& STDMETHODCALLTYPE ShapeGetDims() override { + return tensor_.Shape().GetDims(); + } + int64_t STDMETHODCALLTYPE ShapeSize() override { + return tensor_.Shape().Size(); + } + const char* STDMETHODCALLTYPE LocationName() override { + return tensor_.Location().name; + } + OrtMemType STDMETHODCALLTYPE LocationMemType() override { + return tensor_.Location().mem_type; + } }; // class OrtValue // -class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - IOrtValue> { -private: - OrtValue ort_value_; - -public: - AbiSafeOrtValue() {} - AbiSafeOrtValue(OrtValue value_in) : ort_value_(value_in) {} - - OrtValue& STDMETHODCALLTYPE get() override { - return ort_value_; - } +class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IOrtValue> { + private: + OrtValue ort_value_; + OrtValue* ort_value_weak_; - onnxruntime::MLDataType STDMETHODCALLTYPE Type() override { - return ort_value_.Type(); - } - bool STDMETHODCALLTYPE IsTensor() override { - return ort_value_.IsTensor(); - } - // end - HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) override { - auto tensor_inner = ort_value_.GetMutable(); - auto tensor_outer = wil::MakeOrThrow(tensor_inner, this); - return tensor_outer.CopyTo(__uuidof(ITensor), reinterpret_cast(tensor)); - } -}; // class AbiSafeOrtValue + public: + AbiSafeOrtValue() : ort_value_weak_(nullptr) {} + AbiSafeOrtValue(OrtValue* weak_value_in) : ort_value_weak_(weak_value_in) { } + ~AbiSafeOrtValue() { + int foo = 3; + foo += 1; + } -class ModelProto : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelProto> { -public: - ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) { + OrtValue* STDMETHODCALLTYPE get() override { + if (ort_value_weak_ != nullptr) + return ort_value_weak_; + return &ort_value_; + } - } + onnxruntime::MLDataType STDMETHODCALLTYPE Type() override { + return get()->Type(); + } + bool STDMETHODCALLTYPE IsTensor() override { + return get()->IsTensor(); + } + // end + HRESULT STDMETHODCALLTYPE GetTensor(ITensor** tensor) override { + auto tensor_inner = get()->GetMutable(); + auto tensor_outer = wil::MakeOrThrow(tensor_inner, this); + return tensor_outer.CopyTo(__uuidof(ITensor), reinterpret_cast(tensor)); + } +}; // class AbiSafeOrtValue - onnx::ModelProto* STDMETHODCALLTYPE get() override { - return model_proto_.get(); - } +class ModelProto : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IModelProto> { + public: + ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) { + } - onnx::ModelProto* STDMETHODCALLTYPE detach() override { - return model_proto_.release(); - } + onnx::ModelProto* STDMETHODCALLTYPE get() override { + return model_proto_.get(); + } -private: - std::unique_ptr model_proto_; -}; // class ModelProto + onnx::ModelProto* STDMETHODCALLTYPE detach() override { + return model_proto_.release(); + } + private: + std::unique_ptr model_proto_; +}; // class ModelProto class ModelInfo : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelInfo> { - -private: - std::string author_; - std::string name_; - std::string domain_; - std::string description_; - int64_t version_; - std::unordered_map model_metadata_; - wfc::IVector input_features_; - wfc::IVector output_features_; - -public: - - ModelInfo(const onnx::ModelProto* model_proto) { - Initialize(model_proto); - } + Microsoft::WRL::RuntimeClassFlags, + IModelInfo> { + private: + std::string author_; + std::string name_; + std::string domain_; + std::string description_; + int64_t version_; + std::unordered_map model_metadata_; + wfc::IVector input_features_; + wfc::IVector output_features_; - std::string& STDMETHODCALLTYPE author() override { - return author_; - } - std::string& STDMETHODCALLTYPE name() override { - return name_; - } - std::string& STDMETHODCALLTYPE domain() override { - return domain_; - } - std::string& STDMETHODCALLTYPE description() override { - return description_; - } - int64_t STDMETHODCALLTYPE version() override { - return version_; - } - std::unordered_map& STDMETHODCALLTYPE model_metadata() override { - return model_metadata_; - } - wfc::IVector& STDMETHODCALLTYPE input_features() override { - return input_features_; - } - wfc::IVector& STDMETHODCALLTYPE output_features() override { - return output_features_; - } - - static std::vector - GetAllNodeOutputs(const onnx::ModelProto& model_proto) { - std::vector nodes_outputs; - auto& graph = model_proto.graph(); - auto& nodes = graph.node(); - for (auto& node : nodes) { - for (auto& node_output : node.output()) { - nodes_outputs.push_back(node_output.c_str()); - } - } - return nodes_outputs; - } + public: + ModelInfo(const onnx::ModelProto* model_proto) { + Initialize(model_proto); + } - static std::vector - GetInitializers(const onnx::ModelProto& model_proto) { - std::vector initializers; - auto& graph = model_proto.graph(); - auto& graph_initializers = graph.initializer(); - for (auto& initializer : graph_initializers) { - initializers.push_back(initializer.name().c_str()); - } - return initializers; - } + std::string& STDMETHODCALLTYPE author() override { + return author_; + } + std::string& STDMETHODCALLTYPE name() override { + return name_; + } + std::string& STDMETHODCALLTYPE domain() override { + return domain_; + } + std::string& STDMETHODCALLTYPE description() override { + return description_; + } + int64_t STDMETHODCALLTYPE version() override { + return version_; + } + std::unordered_map& STDMETHODCALLTYPE model_metadata() override { + return model_metadata_; + } + wfc::IVector& STDMETHODCALLTYPE input_features() override { + return input_features_; + } + wfc::IVector& STDMETHODCALLTYPE output_features() override { + return output_features_; + } - static std::vector - GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { - auto initializers = GetInitializers(model_proto); - - std::vector inputs_without_initializers; - auto& graph = model_proto.graph(); - auto& inputs = graph.input(); - for (auto& input : inputs) { - if (input.has_name() && input.has_type()) { - auto found_it = std::find_if( - std::begin(initializers), - std::end(initializers), - [&](auto& initializer) { - return std::strcmp(initializer, input.name().c_str()) == 0; - }); - - auto is_initializer = found_it != std::end(initializers); - if (!is_initializer) { - inputs_without_initializers.push_back(&input); - } - } - } - return inputs_without_initializers; - } + static std::vector + GetAllNodeOutputs(const onnx::ModelProto& model_proto) { + std::vector nodes_outputs; + auto& graph = model_proto.graph(); + auto& nodes = graph.node(); + for (auto& node : nodes) { + for (auto& node_output : node.output()) { + nodes_outputs.push_back(node_output.c_str()); + } + } + return nodes_outputs; + } - static - std::vector GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; - auto& graph = model_proto.graph(); - auto& outputs = graph.output(); - for (auto& output : outputs) { - if (output.has_name() && output.has_type()) { - outputs_with_name.push_back(&output); - } - } - return outputs_with_name; + static std::vector + GetInitializers(const onnx::ModelProto& model_proto) { + std::vector initializers; + auto& graph = model_proto.graph(); + auto& graph_initializers = graph.initializer(); + for (auto& initializer : graph_initializers) { + initializers.push_back(initializer.name().c_str()); } + return initializers; + } -private: - void Initialize(const onnx::ModelProto* model_proto) { - // metadata - for (auto& prop : model_proto->metadata_props()) { - model_metadata_[prop.key()] = prop.value(); + static std::vector + GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { + auto initializers = GetInitializers(model_proto); + + std::vector inputs_without_initializers; + auto& graph = model_proto.graph(); + auto& inputs = graph.input(); + for (auto& input : inputs) { + if (input.has_name() && input.has_type()) { + auto found_it = std::find_if( + std::begin(initializers), + std::end(initializers), + [&](auto& initializer) { + return std::strcmp(initializer, input.name().c_str()) == 0; + }); + + auto is_initializer = found_it != std::end(initializers); + if (!is_initializer) { + inputs_without_initializers.push_back(&input); } - - WinML::FeatureDescriptorFactory builder(model_metadata_); - - // Create inputs - auto inputs = GetInputsWithoutInitializers(*model_proto); - input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); - - // Create outputs - auto outputs = GetOutputs(*model_proto); - output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); - - // author - auto has_producer_name = model_proto->has_producer_name(); - author_ = has_producer_name - ? model_proto->producer_name() - : ""; - - // domain - auto has_domain = model_proto->has_domain(); - domain_ = has_domain - ? model_proto->domain() - : ""; - - // name - auto has_graph = model_proto->has_graph(); - auto graph_has_name = model_proto->graph().has_name(); - auto is_name_available = has_graph && graph_has_name; - name_ = is_name_available - ? model_proto->graph().name() - : ""; - - // description - auto has_description = model_proto->has_doc_string(); - description_ = has_description - ? model_proto->doc_string() - : ""; - - // version - auto has_version = model_proto->has_model_version(); - version_ = has_version - ? model_proto->model_version() - : 0; + } } -}; // class ModelInfo - -class WinMLAdapter : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IWinMLAdapter> { -private: - std::shared_ptr lotus_environment_; - -public: - WinMLAdapter() : lotus_environment_(PheonixSingleton()) { - - } - - // factory methods for creating an ort model from a path - HRESULT STDMETHODCALLTYPE CreateModelProto( - const char* path, - IModelProto** model_proto) override { - int file_descriptor; - _set_errno(0); // clear errno - _sopen_s( - &file_descriptor, - path, - O_RDONLY | _O_SEQUENTIAL | _O_BINARY, - _SH_DENYWR, - _S_IREAD | _S_IWRITE); - - errno_t err = 0; - _get_errno(&err); - THROW_HR_IF_MSG( - __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), - err == ENOENT, - "File not found: %s", - path); - - THROW_HR_IF_MSG( - E_FAIL, - 0 > file_descriptor, - "Failed"); //errno - - auto stream = google::protobuf::io::FileInputStream(file_descriptor); - stream.SetCloseOnDelete(true); - - auto model_proto_inner = new onnx::ModelProto(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&stream) == false, - "The stream failed to parse."); - - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - - // factory methods for creating an ort model from a stream - HRESULT STDMETHODCALLTYPE CreateModelProto( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, - IModelProto** model_proto) override { - - ZeroCopyInputStreamWrapper wrapper(stream_reference); - - auto model_proto_inner = new onnx::ModelProto(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false, - "The stream failed to parse."); - - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - - // factory methods for creating an ort model from a model_proto - HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) override { - auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - - HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) override { - auto model_info_outer = wil::MakeOrThrow(model_proto->get()); - return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast(model_info)); - } - + return inputs_without_initializers; + } - void STDMETHODCALLTYPE EnableDebugOutput() override { - WinML::CWinMLLogSink::EnableDebugOutput(); + static std::vector GetOutputs(const onnx::ModelProto& model_proto) { + std::vector outputs_with_name; + auto& graph = model_proto.graph(); + auto& outputs = graph.output(); + for (auto& output : outputs) { + if (output.has_name() && output.has_type()) { + outputs_with_name.push_back(&output); + } } + return outputs_with_name; + } - static bool IsFeatureDescriptorFp16( - winml::ILearningModelFeatureDescriptor descriptor) { - if (auto imageFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == imageFeatureDescriptor.TensorKind(); - } + private: + void Initialize(const onnx::ModelProto* model_proto) { + // metadata + for (auto& prop : model_proto->metadata_props()) { + model_metadata_[prop.key()] = prop.value(); + } + + WinML::FeatureDescriptorFactory builder(model_metadata_); + + // Create inputs + auto inputs = GetInputsWithoutInitializers(*model_proto); + input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); + + // Create outputs + auto outputs = GetOutputs(*model_proto); + output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); + + // author + auto has_producer_name = model_proto->has_producer_name(); + author_ = has_producer_name + ? model_proto->producer_name() + : ""; + + // domain + auto has_domain = model_proto->has_domain(); + domain_ = has_domain + ? model_proto->domain() + : ""; + + // name + auto has_graph = model_proto->has_graph(); + auto graph_has_name = model_proto->graph().has_name(); + auto is_name_available = has_graph && graph_has_name; + name_ = is_name_available + ? model_proto->graph().name() + : ""; + + // description + auto has_description = model_proto->has_doc_string(); + description_ = has_description + ? model_proto->doc_string() + : ""; + + // version + auto has_version = model_proto->has_model_version(); + version_ = has_version + ? model_proto->model_version() + : 0; + } +}; // class ModelInfo - if (auto tensorFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind(); - } +class WinMLAdapter : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IWinMLAdapter> { + private: + std::shared_ptr lotus_environment_; - return false; - } + public: + WinMLAdapter() : lotus_environment_(PheonixSingleton()) { + } - HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( - winml::LearningModel const& model, - IModelProto* p_model_proto, - bool is_float16_supported) override { - if (!is_float16_supported) { - auto& graph = p_model_proto->get()->graph(); - - // The model will not contain fp16 operations if: - // 1. The model has no fp16 inputs - // 2. The model has no fp16 initializers - // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator - // 4. The model does not have any fp16 outputs - - // 1. Ensure that The model has no fp16 inputs - for (auto descriptor : model.InputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } + // factory methods for creating an ort model from a path + HRESULT STDMETHODCALLTYPE CreateModelProto( + const char* path, + IModelProto** model_proto) override { + int file_descriptor; + _set_errno(0); // clear errno + _sopen_s( + &file_descriptor, + path, + O_RDONLY | _O_SEQUENTIAL | _O_BINARY, + _SH_DENYWR, + _S_IREAD | _S_IWRITE); + + errno_t err = 0; + _get_errno(&err); + THROW_HR_IF_MSG( + __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), + err == ENOENT, + "File not found: %s", + path); + + THROW_HR_IF_MSG( + E_FAIL, + 0 > file_descriptor, + "Failed"); //errno + + auto stream = google::protobuf::io::FileInputStream(file_descriptor); + stream.SetCloseOnDelete(true); + + auto model_proto_inner = new onnx::ModelProto(); + THROW_HR_IF_MSG( + E_INVALIDARG, + model_proto_inner->ParseFromZeroCopyStream(&stream) == false, + "The stream failed to parse."); + + auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); + } - // 2. Ensure that the model has no fp16 initializers - for (int i = 0; i < graph.node_size(); i++) { - auto node = graph.node(i); - if (node.op_type() == "Cast" && node.domain().empty()) { - for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { - auto attribute = node.attribute(attribIndex); - if (attribute.name() == "to") { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.", - node.name().c_str()); - } - } - } - } + // factory methods for creating an ort model from a stream + HRESULT STDMETHODCALLTYPE CreateModelProto( + ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, + IModelProto** model_proto) override { + ZeroCopyInputStreamWrapper wrapper(stream_reference); - // 3. Ensure that the model does not create any fp16 intermediary - // tensors via the Cast (to float16) operator - for (int i = 0; i < graph.initializer_size(); i++) { - auto initializer = graph.initializer(i); + auto model_proto_inner = new onnx::ModelProto(); + THROW_HR_IF_MSG( + E_INVALIDARG, + model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false, + "The stream failed to parse."); - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.", - initializer.name().c_str()); - } + auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); + } - // 4. Ensure that the model does not have any fp16 outputs - for (auto descriptor : model.OutputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } - } - return S_OK; - } + // factory methods for creating an ort model from a model_proto + HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override { + auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); + auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); + return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); + } - ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override { - auto d3dResource = - Dml::GetD3D12ResourceFromAllocation( - provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), - allocation); - return d3dResource; - } + HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override { + auto model_info_outer = wil::MakeOrThrow(model_proto->get()); + return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast(model_info)); + } - static onnxruntime::MLDataType GetType(winml::TensorKind kind) { - switch (kind) { - case winml::TensorKind::Float: - return onnxruntime::DataTypeImpl::GetType(); - case winml::TensorKind::Float16: - return onnxruntime::DataTypeImpl::GetType(); - }; - return nullptr; - } + void STDMETHODCALLTYPE EnableDebugOutput() override { + WinML::CWinMLLogSink::EnableDebugOutput(); + } - // factory method for creating an ortsessionbuilder from a device - HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue, - IOrtSessionBuilder** session_builder) override { - - if (device == nullptr) { - auto builder = wil::MakeOrThrow(); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } else { - auto builder = wil::MakeOrThrow(device, queue); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } + static bool IsFeatureDescriptorFp16( + winml::ILearningModelFeatureDescriptor descriptor) { + if (auto imageFeatureDescriptor = descriptor.try_as()) { + return TensorKind::Float16 == imageFeatureDescriptor.TensorKind(); } - onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() override { - return onnxruntime::DataTypeImpl::GetType(); + if (auto tensorFeatureDescriptor = descriptor.try_as()) { + return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind(); } - onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType(winml::TensorKind kind) override { - if (kind == TensorKind::Float) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Double) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::String) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::UInt8) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Int8) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::UInt16) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Int16) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::UInt32) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Int32) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::UInt64) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Int64) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Boolean) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (kind == TensorKind::Float16) { - return onnxruntime::DataTypeImpl::GetType(); - } - return nullptr; - } + return false; + } - onnxruntime::MLDataType STDMETHODCALLTYPE GetMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { - if (key_kind == TensorKind::String) { - if (value_kind == TensorKind::String) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Int64) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Float) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Double) { - return onnxruntime::DataTypeImpl::GetType(); - } - } - else if (key_kind == TensorKind::Int64) { - if (value_kind == TensorKind::String) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Int64) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Float) { - return onnxruntime::DataTypeImpl::GetType(); - } - else if (value_kind == TensorKind::Double) { - return onnxruntime::DataTypeImpl::GetType(); + HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( + winml::LearningModel const& model, + IModelProto* p_model_proto, + bool is_float16_supported) override { + if (!is_float16_supported) { + auto& graph = p_model_proto->get()->graph(); + + // The model will not contain fp16 operations if: + // 1. The model has no fp16 inputs + // 2. The model has no fp16 initializers + // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator + // 4. The model does not have any fp16 outputs + + // 1. Ensure that The model has no fp16 inputs + for (auto descriptor : model.InputFeatures()) { + THROW_HR_IF_MSG( + DXGI_ERROR_UNSUPPORTED, + IsFeatureDescriptorFp16(descriptor), + "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.", + descriptor.Name().c_str()); + } + + // 2. Ensure that the model has no fp16 initializers + for (int i = 0; i < graph.node_size(); i++) { + auto node = graph.node(i); + if (node.op_type() == "Cast" && node.domain().empty()) { + for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { + auto attribute = node.attribute(attribIndex); + if (attribute.name() == "to") { + THROW_HR_IF_MSG( + DXGI_ERROR_UNSUPPORTED, + attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, + "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.", + node.name().c_str()); } + } } - return nullptr; - } + } - onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { - if (key_kind == TensorKind::String) { - if (value_kind == TensorKind::Float) { - return onnxruntime::DataTypeImpl::GetType(); - } - } - else if (key_kind == TensorKind::Int64) { - if (value_kind == TensorKind::Float) { - return onnxruntime::DataTypeImpl::GetType(); - } - } - return nullptr; - } + // 3. Ensure that the model does not create any fp16 intermediary + // tensors via the Cast (to float16) operator + for (int i = 0; i < graph.initializer_size(); i++) { + auto initializer = graph.initializer(i); - // returns the raw mutable data. - void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_Value) override { - return nullptr; - } - void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { - return nullptr; - } - void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { - return nullptr; + THROW_HR_IF_MSG( + DXGI_ERROR_UNSUPPORTED, + initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, + "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.", + initializer.name().c_str()); + } + + // 4. Ensure that the model does not have any fp16 outputs + for (auto descriptor : model.OutputFeatures()) { + THROW_HR_IF_MSG( + DXGI_ERROR_UNSUPPORTED, + IsFeatureDescriptorFp16(descriptor), + "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.", + descriptor.Name().c_str()); + } } + return S_OK; + } - HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override { - auto impl = wil::MakeOrThrow(); - *registry = impl.Detach(); - return S_OK; - } + ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override { + auto d3dResource = + Dml::GetD3D12ResourceFromAllocation( + provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), + allocation); + return d3dResource; + } - void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override { - return Dml::CreateGPUAllocationFromD3DResource(pResource); - } + static onnxruntime::MLDataType GetType(winml::TensorKind kind) { + switch (kind) { + case winml::TensorKind::Float: + return onnxruntime::DataTypeImpl::GetType(); + case winml::TensorKind::Float16: + return onnxruntime::DataTypeImpl::GetType(); + }; + return nullptr; + } - void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override { - Dml::FreeGPUAllocation(ptr); - } - HRESULT STDMETHODCALLTYPE CopyTensor( - onnxruntime::IExecutionProvider* provider, - ITensor* src, - ITensor* dst) override { - ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, src->get(), *(dst->getMutable()))); - return S_OK; + // factory method for creating an ortsessionbuilder from a device + HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( + ID3D12Device* device, + ID3D12CommandQueue* queue, + IOrtSessionBuilder** session_builder) override { + if (device == nullptr) { + auto builder = wil::MakeOrThrow(); + return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); + } else { + auto builder = wil::MakeOrThrow(device, queue); + return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); } + } - HRESULT STDMETHODCALLTYPE CreateGPUMLValue( - void * execution_provider_allocated_resource, - onnxruntime::IExecutionProvider* provider, - std::vector* shape, - onnxruntime::MLDataType data_type, - IOrtValue ** gpu_value) override { + onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() override { + return onnxruntime::DataTypeImpl::GetType(); + } - THROW_HR_IF_MSG(WINML_ERR_INVALID_BINDING, - "DmlExecutionProvider" != provider->Type(), - "Cannot creat GPU tensor on CPU device"); + onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType(winml::TensorKind kind) override { + if (kind == TensorKind::Float) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Double) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::String) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::UInt8) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Int8) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::UInt16) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Int16) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::UInt32) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Int32) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::UInt64) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Int64) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Boolean) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (kind == TensorKind::Float16) { + return onnxruntime::DataTypeImpl::GetType(); + } + return nullptr; + } - onnxruntime::TensorShape tensor_shape(*shape); + onnxruntime::MLDataType STDMETHODCALLTYPE GetMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { + if (key_kind == TensorKind::String) { + if (value_kind == TensorKind::String) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Int64) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Float) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Double) { + return onnxruntime::DataTypeImpl::GetType(); + } + } else if (key_kind == TensorKind::Int64) { + if (value_kind == TensorKind::String) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Int64) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Float) { + return onnxruntime::DataTypeImpl::GetType(); + } else if (value_kind == TensorKind::Double) { + return onnxruntime::DataTypeImpl::GetType(); + } + } + return nullptr; + } - auto tensor = new onnxruntime::Tensor( - data_type, - tensor_shape, - execution_provider_allocated_resource, - provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)->Info()); + onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { + if (key_kind == TensorKind::String) { + if (value_kind == TensorKind::Float) { + return onnxruntime::DataTypeImpl::GetType(); + } + } else if (key_kind == TensorKind::Int64) { + if (value_kind == TensorKind::Float) { + return onnxruntime::DataTypeImpl::GetType(); + } + } + return nullptr; + } - auto ort_value = wil::MakeOrThrow(); - ort_value->get().Init(tensor, - onnxruntime::DataTypeImpl::GetType(), - onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); + // returns the raw mutable data. + void* STDMETHODCALLTYPE GetTensorData(IOrtValue* ort_value) override { + auto ml_value = ort_value->get(); + auto tensor = ml_value->GetMutable(); + return static_cast(tensor->MutableDataRaw()); + } - *gpu_value = ort_value.Detach(); - return S_OK; + void* STDMETHODCALLTYPE GetMapData(IOrtValue* ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { + auto ml_value = ort_value->get(); + if (key_kind == TensorKind::Int64) { + if (value_kind == TensorKind::Int64) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::Float) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::Double) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::String) { + return static_cast(ml_value->GetMutable>()); + } else { + THROW_HR(E_FAIL); + } + } + else if (key_kind == TensorKind::String) { + if (value_kind == TensorKind::Int64) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::Float) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::Double) { + return static_cast(ml_value->GetMutable>()); + } else if (value_kind == TensorKind::String) { + return static_cast(ml_value->GetMutable>()); + } else { + THROW_HR(E_FAIL); + } + } else { + THROW_HR(E_FAIL); } + } - HRESULT STDMETHODCALLTYPE CreateCPUMLValue( - std::vector* shape, - onnxruntime::MLDataType data_type, - onnxruntime::BufferNakedPtr buffer, - IOrtValue ** cpu_value) override { - auto registrations = onnxruntime::DeviceAllocatorRegistry::Instance().AllRegistrations(); - auto alloc = registrations[onnxruntime::CPU].factory(0); - - onnxruntime::TensorShape tensor_shape(*shape); - - // Unowned raw tensor pointer passed to engine - auto tensor = new onnxruntime::Tensor( - data_type, - tensor_shape, - buffer, - alloc->Info()); - - auto ort_value = wil::MakeOrThrow(); - ort_value->get().Init(tensor, - onnxruntime::DataTypeImpl::GetType(), - onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); - - *cpu_value = ort_value.Detach(); - return S_OK; + void* STDMETHODCALLTYPE GetVectorData(IOrtValue* ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { + auto ml_value = ort_value->get(); + if (key_kind == TensorKind::String) { + if (value_kind == TensorKind::Float) { + return static_cast(ml_value->GetMutable>>()); + } else { + THROW_HR(E_FAIL); + } + } else if (key_kind == TensorKind::Int64) { + if (value_kind == TensorKind::Float) { + return static_cast(ml_value->GetMutable>>()); + } else { + THROW_HR(E_FAIL); + } + } else { + THROW_HR(E_FAIL); } + } - HRESULT STDMETHODCALLTYPE CreateMLValue( - winml::TensorKind kind, - onnxruntime::MLDataType data_type, - const int64_t * shape, - uint32_t shape_count, - onnxruntime::IExecutionProvider* provider, - IOrtValue ** ort_value) override { - onnxruntime::TensorShape tensor_shape(shape, shape_count); - auto tensor = new onnxruntime::Tensor( - GetType(kind), - tensor_shape, - provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)); - auto ort_value_out = wil::MakeOrThrow(); - ort_value_out->get().Init(tensor, - data_type, - data_type->GetDeleteFunc()); - - *ort_value = ort_value_out.Detach();; - return S_OK; - } +HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override { + auto impl = wil::MakeOrThrow(); + *registry = impl.Detach(); + return S_OK; +} - HRESULT STDMETHODCALLTYPE CreateOrtValue( - void * data, - onnxruntime::MLDataType data_type, - IOrtValue ** ort_value) override { +void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override { + return Dml::CreateGPUAllocationFromD3DResource(pResource); +} - auto ort_value_out = wil::MakeOrThrow(); - ort_value_out->get().Init( - data, - data_type, - data_type->GetDeleteFunc()); +void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override { + Dml::FreeGPUAllocation(ptr); +} +HRESULT STDMETHODCALLTYPE CopyTensor( + onnxruntime::IExecutionProvider* provider, + ITensor* src, + ITensor* dst) override { + ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, src->get(), *(dst->getMutable()))); + return S_OK; +} - *ort_value = ort_value_out.Detach(); - return S_OK; - } +HRESULT STDMETHODCALLTYPE CreateGPUMLValue( + void* execution_provider_allocated_resource, + onnxruntime::IExecutionProvider* provider, + std::vector* shape, + onnxruntime::MLDataType data_type, + IOrtValue** gpu_value) override { + THROW_HR_IF_MSG(WINML_ERR_INVALID_BINDING, + "DmlExecutionProvider" != provider->Type(), + "Cannot creat GPU tensor on CPU device"); + + onnxruntime::TensorShape tensor_shape(*shape); + + auto tensor = new onnxruntime::Tensor( + data_type, + tensor_shape, + execution_provider_allocated_resource, + provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)->Info()); + + auto ort_value = wil::MakeOrThrow(); + ort_value->get()->Init(tensor, + onnxruntime::DataTypeImpl::GetType(), + onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); + + *gpu_value = ort_value.Detach(); + return S_OK; +} - // Override select shape inference functions which are incomplete in ONNX with versions that are complete, - // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being - // deferred until first evaluation. It also prevents a situation where inference functions in externally - // registered schema are reachable only after upstream schema have been revised in a later OS release, - // which would be a compatibility risk. - HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override { - static std::once_flag schema_override_once_flag; - std::call_once(schema_override_once_flag, []() { - SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); - }); - return S_OK; - } +HRESULT STDMETHODCALLTYPE CreateCPUMLValue( + std::vector* shape, + onnxruntime::MLDataType data_type, + onnxruntime::BufferNakedPtr buffer, + IOrtValue** cpu_value) override { + auto registrations = onnxruntime::DeviceAllocatorRegistry::Instance().AllRegistrations(); + auto alloc = registrations[onnxruntime::CPU].factory(0); + + onnxruntime::TensorShape tensor_shape(*shape); + + // Unowned raw tensor pointer passed to engine + auto tensor = new onnxruntime::Tensor( + data_type, + tensor_shape, + buffer, + alloc->Info()); + + auto ort_value = wil::MakeOrThrow(); + ort_value->get()->Init(tensor, + onnxruntime::DataTypeImpl::GetType(), + onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); + + *cpu_value = ort_value.Detach(); + return S_OK; +} +HRESULT STDMETHODCALLTYPE CreateMLValue( + winml::TensorKind kind, + onnxruntime::MLDataType data_type, + const int64_t* shape, + uint32_t shape_count, + onnxruntime::IExecutionProvider* provider, + IOrtValue** ort_value) override { + onnxruntime::TensorShape tensor_shape(shape, shape_count); + auto tensor = new onnxruntime::Tensor( + GetType(kind), + tensor_shape, + provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)); + auto ort_value_out = wil::MakeOrThrow(); + ort_value_out->get()->Init(tensor, + data_type, + data_type->GetDeleteFunc()); + + *ort_value = ort_value_out.Detach(); + ; + return S_OK; +} -}; +HRESULT STDMETHODCALLTYPE CreateOrtValue( + void* data, + onnxruntime::MLDataType data_type, + IOrtValue** ort_value) override { + auto ort_value_out = wil::MakeOrThrow(); + ort_value_out->get()->Init( + data, + data_type, + data_type->GetDeleteFunc()); + + *ort_value = ort_value_out.Detach(); + return S_OK; +} -extern "C" -HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) { - // make an adapter instance - Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); - return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast(adapter)); +// Override select shape inference functions which are incomplete in ONNX with versions that are complete, +// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being +// deferred until first evaluation. It also prevents a situation where inference functions in externally +// registered schema are reachable only after upstream schema have been revised in a later OS release, +// which would be a compatibility risk. +HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override { + static std::once_flag schema_override_once_flag; + std::call_once(schema_override_once_flag, []() { + SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); + }); + return S_OK; } +}; // namespace Windows::AI::MachineLearning::Adapter +extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) { + // make an adapter instance + Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); + return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast(adapter)); +} // class IOBinding // =============== -class IOBinding : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - IIOBinding> { -private: - std::shared_ptr binding_; - std::vector outputs_weak; - std::vector> outputs_; - -public: +class IOBinding : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IIOBinding> { + private: + std::shared_ptr binding_; + std::vector outputs_weak_; + std::vector> outputs_; - IOBinding(onnxruntime::IOBinding * binding) : binding_(binding) { - } + public: + IOBinding(onnxruntime::IOBinding* binding) : binding_(binding) { + } - onnxruntime::IOBinding* STDMETHODCALLTYPE get() override { - return binding_.get(); - } + onnxruntime::IOBinding* STDMETHODCALLTYPE get() override { + return binding_.get(); + } - HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, IOrtValue* ml_value) override { - ORT_THROW_IF_ERROR(binding_->BindInput(name, ml_value->get())); - return S_OK; - } + HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, IOrtValue* ml_value) override { + ORT_THROW_IF_ERROR(binding_->BindInput(name, *ml_value->get())); + return S_OK; + } - HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, IOrtValue* ml_value) override { - // this can be null for unbound outputs - if (ml_value == nullptr) { - OrtValue empty_value = {}; - ORT_THROW_IF_ERROR(binding_->BindOutput(name, empty_value)); - } else { - ORT_THROW_IF_ERROR(binding_->BindOutput(name, ml_value->get())); - } - return S_OK; + HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, IOrtValue* ml_value) override { + // this can be null for unbound outputs + if (ml_value == nullptr) { + OrtValue empty_value = {}; + ORT_THROW_IF_ERROR(binding_->BindOutput(name, empty_value)); + } else { + ORT_THROW_IF_ERROR(binding_->BindOutput(name, *ml_value->get())); } + return S_OK; + } - const std::vector& STDMETHODCALLTYPE GetOutputNames() override { - return binding_->GetOutputNames(); - } - std::vector& STDMETHODCALLTYPE GetOutputs() override { - auto output_inner = binding_->GetOutputs(); - outputs_.clear(); - for (unsigned i = 0; i < output_inner.size(); i++) { - auto ort_value = wil::MakeOrThrow(output_inner[i]); - outputs_.push_back(ort_value); - outputs_weak.push_back(ort_value.Get()); - } - return outputs_weak; - } + const std::vector& STDMETHODCALLTYPE GetOutputNames() override { + return binding_->GetOutputNames(); + } + std::vector& STDMETHODCALLTYPE GetOutputs() override { + auto output_inner = binding_->GetOutputs(); + outputs_weak_.clear(); + outputs_.clear(); + for (unsigned i = 0; i < output_inner.size(); i++) { + auto ort_value = wil::MakeOrThrow(&(output_inner[i])); + outputs_.push_back(ort_value); + outputs_weak_.push_back(ort_value.Get()); + } + return outputs_weak_; + } }; // InferenceSession // ================ -InferenceSession::InferenceSession(onnxruntime::InferenceSession * session) : session_(session) { - +InferenceSession::InferenceSession(onnxruntime::InferenceSession* session) : session_(session) { } void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers(bool registerLotusTransforms) { - GraphTransformerHelpers::RegisterGraphTransformers(session_.get(), registerLotusTransforms); + GraphTransformerHelpers::RegisterGraphTransformers(session_.get(), registerLotusTransforms); } HRESULT STDMETHODCALLTYPE InferenceSession::NewIOBinding(IIOBinding** io_binding) { - std::unique_ptr binding; - ORT_THROW_IF_ERROR(this->session_->NewIOBinding(&binding)); - auto io_binding_outer = wil::MakeOrThrow(binding.release()); - return io_binding_outer.CopyTo(__uuidof(IIOBinding), reinterpret_cast(io_binding)); + std::unique_ptr binding; + ORT_THROW_IF_ERROR(this->session_->NewIOBinding(&binding)); + auto io_binding_outer = wil::MakeOrThrow(binding.release()); + return io_binding_outer.CopyTo(__uuidof(IIOBinding), reinterpret_cast(io_binding)); } HRESULT STDMETHODCALLTYPE InferenceSession::Run(const onnxruntime::RunOptions* run_options, IIOBinding* io_binding) { - ORT_THROW_IF_ERROR(this->session_->Run(*run_options, *(io_binding->get()))); - return S_OK; + ORT_THROW_IF_ERROR(this->session_->Run(*run_options, *(io_binding->get()))); + return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() { - this->session_->StartProfiling(PheonixSingleton()->GetDefaultLogger()); - return S_OK; - + this->session_->StartProfiling(PheonixSingleton()->GetDefaultLogger()); + return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() { - this->session_->EndProfiling(); - return S_OK; - + this->session_->EndProfiling(); + return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::LoadModel( - IModelProto* model_proto) { - auto session_protected_load_accessor = - static_cast(session_.get()); - // session's like to have their very own copy of the model_proto, use detach() - std::unique_ptr model_proto_ptr(model_proto->detach()); - ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); - return S_OK; + IModelProto* model_proto) { + auto session_protected_load_accessor = + static_cast(session_.get()); + // session's like to have their very own copy of the model_proto, use detach() + std::unique_ptr model_proto_ptr(model_proto->detach()); + ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); + return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::RegisterCustomRegistry( - IMLOperatorRegistry* registry) { - RETURN_HR_IF(S_OK, registry == nullptr); + IMLOperatorRegistry* registry) { + RETURN_HR_IF(S_OK, registry == nullptr); - auto custom_registries = GetLotusCustomRegistries(registry); + auto custom_registries = GetLotusCustomRegistries(registry); - // Register - for (auto& custom_registry : custom_registries) { - ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry)); - } + // Register + for (auto& custom_registry : custom_registries) { + ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry)); + } - return S_OK; + return S_OK; } void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) { - Dml::FlushContext(dml_provider); + Dml::FlushContext(dml_provider); } - void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) { - Dml::TrimUploadHeap(dml_provider); + Dml::TrimUploadHeap(dml_provider); } void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) { - Dml::ReleaseCompletedReferences(dml_provider); + Dml::ReleaseCompletedReferences(dml_provider); } } // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index 65cf1a51ce2b1..43f6b44df65c6 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -32,7 +32,7 @@ MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{ MIDL_INTERFACE("72aa5eee-100c-4146-9008-4643d3b8af23") IOrtValue : IUnknown{ // these all return weak pointers - virtual OrtValue& STDMETHODCALLTYPE get() = 0; + virtual OrtValue* STDMETHODCALLTYPE get() = 0; virtual onnxruntime::MLDataType STDMETHODCALLTYPE Type() = 0; virtual bool STDMETHODCALLTYPE IsTensor() = 0; // end @@ -126,9 +126,9 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) = 0; // Data getter - virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_Value) = 0; - virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0; - virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0; + virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_value) = 0; + virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0; + virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0; // custom ops virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0; diff --git a/winml/lib/Api/impl/SequenceBase.h b/winml/lib/Api/impl/SequenceBase.h index d6b921ac8f890..9faf10b7a41ab 100644 --- a/winml/lib/Api/impl/SequenceBase.h +++ b/winml/lib/Api/impl/SequenceBase.h @@ -206,21 +206,24 @@ struct SequenceBase : public winrt::implements< template static TRawType ConvertToABIType( - typename ValidLotusType::Type lotus_value) { - return lotus_value; + const typename ValidLotusType::Type& lotus_value) { + // make a copy + TRawType copy = lotus_value; + return copy; } template <> static winrt::hstring ConvertToABIType( - typename ValidLotusType::Type lotus_value) { + const typename ValidLotusType::Type& lotus_value) { return WinML::Strings::HStringFromUTF8(lotus_value); } template <> static AbiMapStringToFloat ConvertToABIType( - typename ValidLotusType::Type lotus_value) { + const typename ValidLotusType::Type& lotus_value) { + // need to make a copy to convert std::string to hstring std::map copy; for (const auto& pair : lotus_value) { auto key = WinML::Strings::HStringFromUTF8(pair.first); @@ -233,9 +236,14 @@ struct SequenceBase : public winrt::implements< template <> static AbiMapInt64BitToFloat ConvertToABIType( - typename ValidLotusType::Type lotus_value) { + const typename ValidLotusType::Type& lotus_value) { + // need to make a copy since stl objects are not ABI safe. + std::map copy; + for (const auto& pair : lotus_value) { + copy[pair.first] = pair.second; + } return winrt::single_threaded_map( - std::move(lotus_value)); + std::move(copy)); } STDMETHOD(UpdateSourceResourceData)(