diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 4d8f0cce628..8d3c8c24d36 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -82,6 +82,7 @@ runtime.python_library( "//executorch/backends/transforms:duplicate_dynamic_quant_chain", "//executorch/backends/xnnpack:xnnpack_backend", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/backends/vulkan/partitioner:vulkan_partitioner", "//executorch/examples/models:model_base", "//executorch/examples/models:models", "//executorch/examples/portable:utils", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index ede434927c5..0e036f66857 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -18,6 +18,7 @@ import pkg_resources import torch +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, ) @@ -359,6 +360,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument("-2", "--fairseq2", action="store_true") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("-X", "--xnnpack", action="store_true") + parser.add_argument("-V", "--vulkan", action="store_true") parser.add_argument( "--generate_etrecord", @@ -463,6 +465,17 @@ def _export_llama(modelname, args) -> str: # noqa: C901 # partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner() modelname = f"xnnpack_{modelname}" + if args.vulkan: + assert ( + args.dtype_override is None + ), "Vulkan backend does not support non fp32 dtypes at the moment" + assert ( + args.quantization_mode is None + ), "Vulkan backend does not support quantization at the moment" + + partitioners[VulkanPartitioner.__name__] = VulkanPartitioner() + modelname = f"vulkan_{modelname}" + builder_exported_to_edge = ( 
load_llama_model( checkpoint=checkpoint_path, diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index 95883037e70..72e8443b8bd 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -36,7 +36,11 @@ def define_common_targets(): "//executorch/extension/module:module" + aten_suffix, "//executorch/kernels/quantized:generated_lib" + aten_suffix, "//executorch/runtime/core/exec_aten:lib" + aten_suffix, - ] + (_get_operator_lib(aten)), + ] + (_get_operator_lib(aten)) + ([ + # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE) + # Therefore it must be enabled explicitly via the llama.use_vulkan config to avoid failing tests + "//executorch/backends/vulkan:vulkan_backend_lib", + ] if native.read_config("llama", "use_vulkan", "0") == "1" else []), external_deps = [ "libtorch", ] if aten else [],