From e4bec078188fe9ab20c8e69b209ad7931845ca26 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 12 Mar 2024 08:59:22 -0700 Subject: [PATCH] [ET-VK] Enable Partial GPU lowering via Vulkan in stories model export ## Context Simple change to add Vulkan Partitioner as a dependency for the llama exporter and runner, and provide a command line flag to invoke the vulkan partitioner during export. Included a small change to the Vulkan serializer which was needed for everything to work (i.e. enable serializing multiple graph outputs). Differential Revision: [D54805831](https://our.internmc.facebook.com/intern/diff/D54805831/) [ghstack-poisoned] --- .../vulkan/serialization/vulkan_graph_builder.py | 15 ++++++++------- examples/models/llama2/TARGETS | 1 + examples/models/llama2/export_llama_lib.py | 5 +++++ examples/models/llama2/runner/targets.bzl | 1 + 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 572ef018bc2..4bd0c527605 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -218,13 +218,14 @@ def process_getattr_node(self, node: Node) -> None: self.create_tensor_values(node) def process_output_node(self, node: Node) -> None: - if node.all_input_nodes[0] not in self.node_to_value_ids: - raise AssertionError( - "Cannot find input to output node in node_to_value_ids. This means the " - "output node is being serialized before its corresponding internal node " - "which is not allowed." - ) - self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]]) + for out_node in node.all_input_nodes: + if out_node not in self.node_to_value_ids: + raise AssertionError( + "Cannot find input to output node in node_to_value_ids. This means " + "the output node is being serialized before its corresponding " + "internal node which is not allowed." + ) + self.output_ids.append(self.node_to_value_ids[out_node]) def process_node(self, node: Node) -> None: if node.op == "placeholder": diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index ddd3feab54b..ec6f97a2813 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -82,6 +82,7 @@ runtime.python_library( "//executorch/backends/transforms:duplicate_dynamic_quant_chain", "//executorch/backends/xnnpack:xnnpack_backend", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/vulkan/partitioner:vulkan_partitioner", "//executorch/examples/models:model_base", "//executorch/examples/models:models", "//executorch/examples/portable:utils", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 398c1ac1c7c..5c56b3ea247 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -17,6 +17,7 @@ import pkg_resources import torch +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, ) @@ -356,6 +357,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument("-2", "--fairseq2", action="store_true") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("-X", "--xnnpack", action="store_true") + parser.add_argument("-V", "--vulkan", action="store_true") return parser @@ -451,6 +453,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901 ) # partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner() modelname = f"xnnpack_{modelname}" + if args.vulkan: + partitioners[VulkanPartitioner.__name__] = VulkanPartitioner() + modelname = f"vulkan_{modelname}" builder = ( load_llama_model( diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index 2f943d73ec4..ab87393d109 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -29,6 +29,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/backends/vulkan:vulkan_backend_lib", "//executorch/examples/models/llama2/sampler:sampler" + aten_suffix, "//executorch/examples/models/llama2/tokenizer:tokenizer", "//executorch/extension/evalue_util:print_evalue" + aten_suffix,