From eb3a3810ee198aca35dd93196145804335e40d69 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 4 Mar 2024 10:25:48 -0800 Subject: [PATCH] Implement global shader registry (#2222) Summary: X-link: https://github.com/pytorch/pytorch/pull/121088 ## Context This changeset updates Vulkan SPIR-V codegen to introduce a global SPIR-V shader registry and register shaders dynamically at static initialization time. This change makes it possible to define and link custom shader libraries to the ATen-Vulkan runtime. Before: * `gen_vulkan_spv.py` generated two files, `spv.h` and `spv.cpp` which would contain the definition and initialization of Vulkan shader registry variables. After: * Introduce the `ShaderRegistry` class in `api/`, which encapsulates functionality of the `ShaderRegistry` class previously defined in the generated `spv.h` file * Introduce a global shader registry (defined as a static variable in the `api::shader_registry() function` * Define a `ShaderRegisterInit` class (taking inspiration from `TorchLibraryInit`) that allows for dynamic shader registration * `gen_vulkan_spv.py` now only generates `spv.cpp`, which defines a static `ShaderRegisterInit` instance that triggers registration of the compiled shaders to the global shader registry. Benefits: * Cleaner code base; we no longer have `ShaderRegistry` defined in a generated file, and don't need a separate implementation file (`impl/Registry.*`) to handle shader lookup. All that logic now lives under `api/ShaderRegistry.*` * Makes it possible to compile and link separate shader libraries, providing similar flexibility as defining and linking custom ATen operators Differential Revision: D54447700 --- backends/vulkan/targets.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index f189e44c1a6..345f18801fe 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -53,8 +53,8 @@ def define_common_targets(): "@EXECUTORCH_CLIENTS", ], exported_deps = [ - "//caffe2:torch_vulkan_api", "//caffe2:torch_vulkan_ops", + "//caffe2:torch_vulkan_spv", ], define_static_target = False, )