diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 203cd365275c7..7ca96978f9960 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -55,6 +55,11 @@ KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxrunti } Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { + // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. + // 1. Execution providers' capabilities are checked one by one. + // 2. All sub-graphs that an execution provider returns will be assigned to it if it's not assigned yet. + // 3. CPU execution provider is expected to be able to run any node and is the last one in execution provider preference. + if (providers_.Empty()) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } @@ -63,10 +68,17 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>(); // Partitioning based on provider preference and their capabilities. 
auto kernel_registries = kernel_registry_mgr_.GetAllKernelRegistries(); + + std::vector<std::vector<std::unique_ptr<ComputeCapability>>> capabilities_of_all_providers; + GraphViewer graph_viewer(graph); + for (auto& provider : providers_) { + capabilities_of_all_providers.push_back(provider->GetCapability(graph_viewer, kernel_registries)); + } + + int i = 0; for (auto& provider : providers_) { - auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries); int count = 0; - for (auto& capability : capability_results) { + for (auto& capability : capabilities_of_all_providers[i++]) { if (nullptr == capability || nullptr == capability->sub_graph) { continue; } @@ -78,30 +90,44 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { auto node = graph.GetNode(capability->sub_graph->nodes[0]); if (nullptr != node && node->GetExecutionProviderType().empty()) { + // The node was not fused or assigned. Assign it to this <provider>. node->SetExecutionProviderType(provider->Type()); } } else { // The <provider> can run a fused <sub-graph> in the <graph>. - // - // Add fused node into <graph> ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); - std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); - auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); - fused_node.SetExecutionProviderType(provider->Type()); - auto fused_kernel_func = capability->fuse_kernel_function; - if (fused_kernel_func != nullptr) { - // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. - KernelDefBuilder builder; - BuildFusedKernelDef(builder, fused_node); - fused_kernel_registry->Register(builder, fused_kernel_func); + + // Check whether any node in the <sub-graph> was already assigned. 
+ bool sub_graph_available_for_assignment = true; + for (auto node_index : capability->sub_graph->nodes) { + auto node = graph.GetNode(node_index); + if (nullptr == node || !node->GetExecutionProviderType().empty()) { + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider> + // The assumption is that this <provider> can only run the sub-graph as a whole unit. + sub_graph_available_for_assignment = false; + break; + } + } + + if (sub_graph_available_for_assignment) { + // Add fused node into <graph> + std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); + auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); + fused_node.SetExecutionProviderType(provider->Type()); + auto fused_kernel_func = capability->fuse_kernel_function; + if (fused_kernel_func != nullptr) { + // build the kernel definition on the fly, and register it to the fused_kernel_registry. + KernelDefBuilder builder; + BuildFusedKernelDef(builder, fused_node); + fused_kernel_registry->Register(builder, fused_kernel_func); + } } } } - // all done with this provider, resolve the graph before we move on to the next provider. - // This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability(). - ORT_ENFORCE(graph.Resolve().IsOK()); } + ORT_ENFORCE(graph.Resolve().IsOK()); + // To see if the node with no provider can be inlined. If one such nodes can be // successfully inlined, we re-run the partitioner on the modified graph. bool inline_flag = false; @@ -126,10 +152,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { this->Partition(graph); } - //For some cases, like fp16 on cpu, right now we don't have any kernel support that. - //But we will insert cast op to run the model, so skip the error checking here. - //If after graph transform phase, the node still not assigned, we will report error - //during kernel creation phase. 
+ //For some cases, like fp16 on cpu, right now we don't have any kernel that supports it. + //But we will insert cast op to run the model, so skip the error checking here. + //If after graph transform phase, the node is still not assigned, we will report error + //during kernel creation phase. #ifdef COUNT_NON_CUDA_OPS for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType() != kCudaExecutionProvider &&