From a27910c9f5cd0283a57b9639ac51c36b300da615 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 11 Sep 2025 18:46:39 -0700 Subject: [PATCH 1/5] Address edge cases when - input propagates to output - output is produced by an initializer. --- .../core/session/onnxruntime_cxx_inline.h | 32 ++++++----- onnxruntime/core/session/inference_session.cc | 51 +++++++++++++++--- onnxruntime/test/shared_lib/test_inference.cc | 19 +++++++ .../testdata/input_propagated_to_output.onnx | Bin 0 -> 854 bytes 4 files changed, 81 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/input_propagated_to_output.onnx diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 59979189eed0f..9c42bf34b5b0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1582,11 +1582,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( auto num_inputs = GetInputCount(); std::vector mem_infos; - mem_infos.resize(num_inputs); + if (num_inputs > 0) { + mem_infos.resize(num_inputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_inputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + } return mem_infos; } @@ -1598,11 +1600,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs auto num_outputs = GetOutputCount(); std::vector mem_infos; - mem_infos.resize(num_outputs); + if (num_outputs > 0) { + mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_outputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + } return mem_infos; } @@ -1631,12 +1635,12 @@ template inline std::vector 
ConstSessionImpl::GetEpDeviceForInputs() const { auto num_inputs = GetInputCount(); std::vector input_devices; - input_devices.resize(num_inputs); - - ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, - reinterpret_cast(input_devices.data()), - num_inputs)); - + if (num_inputs > 0) { + input_devices.resize(num_inputs); + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + } return input_devices; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 17b7f9af372bc..cc9a2413256b6 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3383,17 +3383,54 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType for (const auto* def : def_list) { InlinedVector node_info_vec; + Status status; if (type == SessionInputOutputType::kOutput) { - ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec); } else { - ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); } - // all entries are for the same OrtDevice so use the first one. - // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice - // from the session state and use its OrtMemoryInfo. - auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); - memory_info.push_back(&allocator->Info()); + if (!status.IsOK()) { + if (type == SessionInputOutputType::kInput) { + return status; + } + + // Check first if this output is produced by an input that directly + // propagates to output with the same name. 
+ status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); + if (status.IsOK()) { + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } else { + // Check if this output is produced by a constant initializer + // Pick the MemoryInfo from the initializer's OrtValue + const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap(); + + OrtValueIndex ort_value_index; + status = ort_value_map.GetIdx(def->Name(), ort_value_index); + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + + const auto& idx_to_ort_value = session_state_->GetInitializedTensors(); + auto it = idx_to_ort_value.find(ort_value_index); + if (it == idx_to_ort_value.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + const auto& tensor = it->second.Get(); + memory_info.push_back(&tensor.Location()); + } + } else { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. 
+ auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } } return Status::OK(); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b7a9da8e1b658..105758bf095eb 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -494,6 +494,17 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); +TEST(CApiTest, TestInputPassThroughToOutput) { + const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(1U, inputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(1U, inputs_epdevices.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(7U, outputs_meminfos.size()); +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(CApiTest, SparseOutputModel) { std::vector dense_shape{3, 3}; @@ -505,7 +516,15 @@ TEST(CApiTest, SparseOutputModel) { std::vector ort_inputs; std::vector input_names; const char* const output_names[] = {"values"}; + // This model produces a sparse output from a constant sparse initializer Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_TRUE(inputs_meminfos.empty()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(1U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_TRUE(inputs_epdevices.empty()); + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, 1); ASSERT_EQ(ort_outputs.size(), 1U); diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx 
b/onnxruntime/test/testdata/input_propagated_to_output.onnx new file mode 100644 index 0000000000000000000000000000000000000000..feeab10556cb06cfb9fc59c9e03d84a6b41f55f7 GIT binary patch literal 854 zcmd|_X(Hj-j>&d)0Y zN`ojaQJ}ODP}mrv%LqxAF;Eu}I~fC|O(42}(qLUMLrsCwra);^h_o?OS`28Y2@r#% z%%qqdQ&N;bgb+_jYH>+?dQoCQM!boZ2?sL}GIB5qFggL<=cTgl@-jP(d7^exgl_J8 zYb$OSK0U~;(|~RNn#qf8?Y@?T$@AERBwRirZPqsU_-1h}3i`aKLPTIG9S-L&kG-EHGhb|oZe#G|Dkl~$1_9n8c`mN}yu1p%%)GSJ WA|XK#yQnm;Br`Wvudp;RuLJ-Jnx9qx literal 0 HcmV?d00001 From c30f06beb416247d8157d05e633c0ad260dbc6be Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 15 Sep 2025 10:45:06 -0700 Subject: [PATCH 2/5] Account for unlikely dangling inputs (unconsumed) --- onnxruntime/core/session/inference_session.cc | 22 +++++++++++-------- onnxruntime/test/shared_lib/test_inference.cc | 18 +++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cc9a2413256b6..1fe1f6a57e13f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3459,15 +3459,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector node_info_vec; ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); - - // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map - // instead of doing a linear search each time. - const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); - auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { - return entry->ep_name == ep_name; - }); - - ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + assert(!node_info_vec.empty()); + // If we have an input that is not consumed by any node, + // including nodes in subgraphs, then we return nullptr. 
+ const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } } return Status::OK(); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 105758bf095eb..4defcfcd56b23 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -505,6 +505,24 @@ TEST(CApiTest, TestInputPassThroughToOutput) { ASSERT_EQ(7U, outputs_meminfos.size()); } +TEST(CApiTest, TestDanglingInput) { + // Here we test an issue with segments_ids that is an input not consumed by anything + // This kind of model is unlikely to be used in practice but we want to make sure it works + const ORTCHAR_T* model_uri = TSTR("test_embed_layer_norm_unit_test_batch1_empty_segment.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(2U, inputs_meminfos.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(2U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(2U, inputs_epdevices.size()); + // One of the devices returning is null since the input is not consumed + // there is not a device for it. 
+ const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(), + [](const auto& device) { return device == nullptr; }); + ASSERT_TRUE(null_present); +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(CApiTest, SparseOutputModel) { std::vector dense_shape{3, 3}; From 9f4381ec945d259bcbeaf16b5aa06bc4b6677796 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 15 Sep 2025 12:09:41 -0700 Subject: [PATCH 3/5] Fix test model path --- onnxruntime/test/shared_lib/test_inference.cc | 2 +- .../test_dangling_input_segment_ids.onnx | Bin 0 -> 1180 bytes 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4defcfcd56b23..8c2928670934a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -508,7 +508,7 @@ TEST(CApiTest, TestInputPassThroughToOutput) { TEST(CApiTest, TestDanglingInput) { // Here we test an issue with segments_ids that is an input not consumed by anything // This kind of model is unlikely to be used in practice but we want to make sure it works - const ORTCHAR_T* model_uri = TSTR("test_embed_layer_norm_unit_test_batch1_empty_segment.onnx"); + const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx"); Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); auto inputs_meminfos = session.GetMemoryInfoForInputs(); ASSERT_EQ(2U, inputs_meminfos.size()); diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a83c21030ad674d208a08da3bc5f7ce967204890 GIT binary patch literal 1180 zcmZ8heNdAH7-wtBc=>R=%#Umef~^B%p_dqJ@ADIK5e$Swur<-Su{o!{+t}a~bm9j` zS}KPr3T5$eVpk!kmnmdcXs1wWq+(N0z(X{pJF~*dOje!7AJ5(M-1B=r?z!jp6Z^*h 
zC1#2kKda4AkgvCzT#OIHth75#dNXe{n;1Wb-Sw~Q!;G>Rc;3K>jOKg;8@R$yXm;A{ zPF`;>$Y+CigKN3oYBQNv4aec*)x-6v!`U_5 zvEDiQ$IUEGgJWsbPLViFERqNkpI9bko#=$6h8FBF!8I6y=Xn*ie3Fci{W1)l&Ecve zI;o&D8fzk#qxiffvk1LGuCx zrzN20&3|TI6VY`bB*Hk_;7xQn9rvkXxis zqdZfInBPjM;=PSj(>(>}+nPvqM}s8(I?XY@in1~{R?I!*%H*%F{s#7hz&#G z7@M~isqOxV+94E6ynpfdF;mtFJ*}g$r-7jdlMhqosf(&wno81M!b$Oc)#i^{=;&8D z$eK4D6+JN!L?uvCLN6uuO0j6M8|~j@(7^^bZB>?1n*VHkJam&TDXd(mLWpqzuTxfD zDms+ANg(f4#r!xPcSdQkMcPWA{t8U<2=U!&2JSrtFrG`pbhC)FRLw+#bdbuoJMr>y zBYk*m6Lnm!=kB;JlTv+^`vKwbDN93%VFqmSIe2z{BOTqW#m+Y_Q^ws)1RBGzwu_~V z&DGp%iVD@_eVe%A_FHr|bbuSE>{C@y0xm_iQO+#jothhTGOwSaCV6l24?j~R1M?@N z>rAAosLu~KPpOqrGH=?;+Z2goy$k#)3*09(Sf0KNx@VrCKO4jZh)1(xMkp3MGLdc) z{ F;BR?lv5o)$ literal 0 HcmV?d00001 From 455b78621574fcc3b589711949ae86c1b055f372 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 15 Sep 2025 16:30:39 -0700 Subject: [PATCH 4/5] Address review comments --- onnxruntime/core/session/inference_session.cc | 6 +- .../test_dangling_input_segment_ids.py | 56 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/test_dangling_input_segment_ids.py diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 1fe1f6a57e13f..c424bc4264b0d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3399,6 +3399,9 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType // propagates to output with the same name. status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); if (status.IsOK()) { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. 
auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); memory_info.push_back(&allocator->Info()); } else { @@ -3422,7 +3425,8 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType def->Name(), "."); } const auto& tensor = it->second.Get(); - memory_info.push_back(&tensor.Location()); + auto allocator = session_state_->GetAllocator(tensor.Location()); + memory_info.push_back(&allocator->Info()); } } else { // all entries are for the same OrtDevice so use the first one. diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py new file mode 100644 index 0000000000000..fbc9e1cf87fd0 --- /dev/null +++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py @@ -0,0 +1,56 @@ +""" +Run this script to recreate the original onnx model. +Example usage: +python test_dangling_input_segment_ids.py out_model_path.onnx +""" + +from onnx import helper, numpy_helper, TensorProto + +import onnx +import numpy as np +import sys +import os + +DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dangling_input_segment_ids') + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + +def make_node( + op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs +): + node = helper.make_node( + op_type, inputs, outputs, name, doc_string, domain, **kwargs + ) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + +model = helper.make_model( + opset_imports=[helper.make_operatorsetid('', 14), helper.make_operatorsetid('com.microsoft', 1)], + ir_version=7, + graph=make_graph( 
+ name='embed_layernorm_graph', + inputs=[helper.make_tensor_value_info('input_ids', TensorProto.INT32, shape=[1, 4]), helper.make_tensor_value_info('segment_ids', TensorProto.INT32, shape=[1, 4])], + outputs=[helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, shape=[1, 4, 4]), helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, shape=[1])], + initializer=[ + numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const0_word_embed.npy')).astype('float32').reshape([32, 4]), name='word_embed'), + numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const1_pos_embed.npy')).astype('float32').reshape([16, 4]), name='pos_embed'), + numpy_helper.from_array(np.array([0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], dtype='float32'), name='gamma'), + numpy_helper.from_array(np.array([0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype='float32'), name='beta'), + ], + nodes=[make_node('EmbedLayerNormalization', inputs=['input_ids', '', 'word_embed', 'pos_embed', '', 'gamma', 'beta'], outputs=['layernorm_out', 'mask_index_out'], domain='com.microsoft')], + ), +) + +if __name__ == '__main__' and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path) From 2891999a1c8104fcf4103d9b3bd5955ffacdf49d Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 15 Sep 2025 18:19:15 -0700 Subject: [PATCH 5/5] Lint --- .../test_dangling_input_segment_ids.py | 72 +++++++++++++------ 1 file changed, 51 insertions(+), 21 deletions(-) diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py index fbc9e1cf87fd0..c5eb8a600d6b5 100644 --- a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py +++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py @@ -4,53 +4,83 @@ python test_dangling_input_segment_ids.py out_model_path.onnx """ -from onnx import helper, numpy_helper, 
TensorProto +import os +import sys -import onnx import numpy as np -import sys -import os +import onnx +from onnx import TensorProto, helper, numpy_helper + +DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids") -DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dangling_input_segment_ids') def order_repeated_field(repeated_proto, key_name, order): order = list(order) repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) -def make_node( - op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs -): - node = helper.make_node( - op_type, inputs, outputs, name, doc_string, domain, **kwargs - ) + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) if doc_string == "": node.doc_string = "" order_repeated_field(node.attribute, "name", kwargs.keys()) return node + def make_graph(*args, doc_string=None, **kwargs): graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) if doc_string == "": graph.doc_string = "" return graph + model = helper.make_model( - opset_imports=[helper.make_operatorsetid('', 14), helper.make_operatorsetid('com.microsoft', 1)], + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], ir_version=7, graph=make_graph( - name='embed_layernorm_graph', - inputs=[helper.make_tensor_value_info('input_ids', TensorProto.INT32, shape=[1, 4]), helper.make_tensor_value_info('segment_ids', TensorProto.INT32, shape=[1, 4])], - outputs=[helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, shape=[1, 4, 4]), helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, shape=[1])], + name="embed_layernorm_graph", + inputs=[ + helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]), + helper.make_tensor_value_info("segment_ids", 
TensorProto.INT32, shape=[1, 4]), + ], + outputs=[ + helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]), + helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]), + ], initializer=[ - numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const0_word_embed.npy')).astype('float32').reshape([32, 4]), name='word_embed'), - numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const1_pos_embed.npy')).astype('float32').reshape([16, 4]), name='pos_embed'), - numpy_helper.from_array(np.array([0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], dtype='float32'), name='gamma'), - numpy_helper.from_array(np.array([0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype='float32'), name='beta'), + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]), + name="word_embed", + ), + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]), + name="pos_embed", + ), + numpy_helper.from_array( + np.array( + [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], + dtype="float32", + ), + name="gamma", + ), + numpy_helper.from_array( + np.array( + [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32" + ), + name="beta", + ), + ], + nodes=[ + make_node( + "EmbedLayerNormalization", + inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"], + outputs=["layernorm_out", "mask_index_out"], + domain="com.microsoft", + ) ], - nodes=[make_node('EmbedLayerNormalization', inputs=['input_ids', '', 'word_embed', 'pos_embed', '', 'gamma', 'beta'], outputs=['layernorm_out', 'mask_index_out'], domain='com.microsoft')], ), ) -if __name__ == '__main__' and len(sys.argv) == 2: +if __name__ == "__main__" and len(sys.argv) == 2: _, out_path = sys.argv onnx.save(model, 
out_path)