From 5dc5313ecd7769b372dbfeed2e068ca105ebcd33 Mon Sep 17 00:00:00 2001 From: linkerzhang Date: Wed, 12 Dec 2018 16:41:01 -0800 Subject: [PATCH 1/3] add check before fusing sub-graph in greedy partitioning --- .../core/framework/graph_partitioner.cc | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 648f36c723d03..db8991d8f8167 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -55,6 +55,11 @@ KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxrunti } Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { + // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. + // 1. Execution providers' capabilities are checked one by one. + // 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet. + // 3. CPU execution provider is expected to be able to run any node and is the last one in execution provider preference. + if (providers_.Empty()) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } @@ -82,18 +87,31 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { } } else { // The can run a fused in the . - // - // Add fused node into ONNXRUNTIME_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); - std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); - auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); - fused_node.SetExecutionProviderType(provider->Type()); - auto fused_kernel_func = capability->fuse_kernel_function; - if (fused_kernel_func != nullptr) { - // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. - KernelDefBuilder builder; - BuildFusedKernelDef(builder, fused_node); - fused_kernel_registry->Register(builder, fused_kernel_func); + + // Check whether any node in the was already assigned. + bool sub_graph_available_for_assignment = true; + for (auto node_index : capability->sub_graph->nodes) { + auto node = graph.GetNode(node_index); + if (nullptr == node || !node->GetExecutionProviderType().empty()) { + // There's invalid node or a node was assigned. + sub_graph_available_for_assignment = false; + break; + } + } + + if (sub_graph_available_for_assignment) { + // Add fused node into + std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); + auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); + fused_node.SetExecutionProviderType(provider->Type()); + auto fused_kernel_func = capability->fuse_kernel_function; + if (fused_kernel_func != nullptr) { + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + KernelDefBuilder builder; + BuildFusedKernelDef(builder, fused_node); + fused_kernel_registry->Register(builder, fused_kernel_func); + } } } } @@ -126,10 +144,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { this->Partition(graph); } - //For some cases, like fp16 on cpu, right now we don't have any kernel support that. - //But we will insert cast op to run the model, so skip the error checking here. - //If after graph transform phase, the node still not assigned, we will report error - //during kernel creation phase. + //For some cases, like fp16 on cpu, right now we don't have any kernel support that. + //But we will insert cast op to run the model, so skip the error checking here. + //If after graph transform phase, the node still not assigned, we will report error + //during kernel creation phase. #ifdef COUNT_NON_CUDA_OPS for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType() != kCudaExecutionProvider && From b1452b535ea66968e8683adea68efd74d8a1005a Mon Sep 17 00:00:00 2001 From: linkerzhang Date: Wed, 12 Dec 2018 17:09:25 -0800 Subject: [PATCH 2/3] update the partitioning logic to 1) not fuse sub-graph if inner nodes were assigned 2) avoid resolving graph after each provider capability checking and assignment. --- .../core/framework/graph_partitioner.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index db8991d8f8167..912aa7a06c4a4 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -68,10 +68,17 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { std::shared_ptr fused_kernel_registry = std::make_shared(); // Partitioning based on provider preference and their capabilities. auto kernel_registries = kernel_registry_mgr_.GetAllKernelRegistries(); + + std::vector>> capabilities_of_all_providers; + GraphViewer graph_viewer(graph); + for (auto& provider : providers_) { + capabilities_of_all_providers.push_back(provider->GetCapability(graph_viewer, kernel_registries)); + } + + int i = 0; for (auto& provider : providers_) { - auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries); int count = 0; - for (auto& capability : capability_results) { + for (auto& capability : capabilities_of_all_providers[i++]) { if (nullptr == capability || nullptr == capability->sub_graph) { continue; } @@ -83,6 +90,7 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { auto node = graph.GetNode(capability->sub_graph->nodes[0]); if (nullptr != node && node->GetExecutionProviderType().empty()) { + // The node was not fused or assigned. Assign it to this . node->SetExecutionProviderType(provider->Type()); } } else { @@ -94,7 +102,8 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { for (auto node_index : capability->sub_graph->nodes) { auto node = graph.GetNode(node_index); if (nullptr == node || !node->GetExecutionProviderType().empty()) { - // There's invalid node or a node was assigned. + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this + // The assumption is that this can only run the sub-graph as a whole unit. sub_graph_available_for_assignment = false; break; } @@ -115,11 +124,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { } } } - // all done with this provider, resolve the graph before we move on to the next provider. - // This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability(). - ONNXRUNTIME_ENFORCE(graph.Resolve().IsOK()); } + ONNXRUNTIME_ENFORCE(graph.Resolve().IsOK()); + // To see if the node with no provider can be inlined. If one such nodes can be // successfully inlined, we re-run the partitioner on the modified graph. bool inline_flag = false; From 659433b64b3482012dac8c4994a04fd09665d389 Mon Sep 17 00:00:00 2001 From: linkerzhang Date: Mon, 17 Dec 2018 12:03:20 -0800 Subject: [PATCH 3/3] resolve conflicts --- .../core/framework/graph_partitioner.cc | 25 ++----------------- onnxruntime/core/graph/graph.cc | 2 +- 2 files changed, 3 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 1626f4b45db4d..7ca96978f9960 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -95,8 +95,7 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { } } else { // The can run a fused in the . -<<<<<<< HEAD - ONNXRUNTIME_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); + ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); // Check whether any node in the was already assigned. bool sub_graph_available_for_assignment = true; @@ -125,29 +124,9 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { } } } -======= - // - // Add fused node into - ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); - std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); - auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); - fused_node.SetExecutionProviderType(provider->Type()); - auto fused_kernel_func = capability->fuse_kernel_function; - if (fused_kernel_func != nullptr) { - // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. - KernelDefBuilder builder; - BuildFusedKernelDef(builder, fused_node); - fused_kernel_registry->Register(builder, fused_kernel_func); - } - } - } - // all done with this provider, resolve the graph before we move on to the next provider. - // This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability(). - ORT_ENFORCE(graph.Resolve().IsOK()); ->>>>>>> b418adff42824a45ba96e9af272b986f995d3f2b } - ONNXRUNTIME_ENFORCE(graph.Resolve().IsOK()); + ORT_ENFORCE(graph.Resolve().IsOK()); // To see if the node with no provider can be inlined. If one such nodes can be // successfully inlined, we re-run the partitioner on the modified graph. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 1c0c84a03ffb7..b593e0b5bb656 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -830,7 +830,7 @@ void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_ // The output type of source node arg does not match the input type of destination node arg. ORT_THROW("Argument type mismatch when adding edge."); } else { - //src_arg->UpdateTypeAndShape(*dst_arg); + src_arg->UpdateTypeAndShape(*dst_arg); *dst_arg_pointer = src_arg; } }