diff --git a/backends/vulkan/runtime/api/Adapter.cpp b/backends/vulkan/runtime/api/Adapter.cpp
index 5db2642e3ec..932678f18fc 100644
--- a/backends/vulkan/runtime/api/Adapter.cpp
+++ b/backends/vulkan/runtime/api/Adapter.cpp
@@ -292,7 +292,8 @@ DeviceHandle::~DeviceHandle() {
 Adapter::Adapter(
     VkInstance instance,
     PhysicalDevice physical_device,
-    const uint32_t num_queues)
+    const uint32_t num_queues,
+    const std::string& cache_data_path)
     : queue_usage_mutex_{},
       physical_device_(std::move(physical_device)),
       queues_{},
@@ -307,7 +308,7 @@ Adapter::Adapter(
       shader_layout_cache_(device_.handle_),
       shader_cache_(device_.handle_),
       pipeline_layout_cache_(device_.handle_),
-      compute_pipeline_cache_(device_.handle_),
+      compute_pipeline_cache_(device_.handle_, cache_data_path),
       sampler_cache_(device_.handle_),
       vma_(instance_, physical_device_.handle, device_.handle_) {}
 
diff --git a/backends/vulkan/runtime/api/Adapter.h b/backends/vulkan/runtime/api/Adapter.h
index ef246260021..fcbba281642 100644
--- a/backends/vulkan/runtime/api/Adapter.h
+++ b/backends/vulkan/runtime/api/Adapter.h
@@ -101,7 +101,8 @@ class Adapter final {
   explicit Adapter(
       VkInstance instance,
       PhysicalDevice physical_device,
-      const uint32_t num_queues);
+      const uint32_t num_queues,
+      const std::string& cache_data_path);
 
   Adapter(const Adapter&) = delete;
   Adapter& operator=(const Adapter&) = delete;
diff --git a/backends/vulkan/runtime/api/Pipeline.cpp b/backends/vulkan/runtime/api/Pipeline.cpp
index bc5d46af21c..a6bff47cac1 100644
--- a/backends/vulkan/runtime/api/Pipeline.cpp
+++ b/backends/vulkan/runtime/api/Pipeline.cpp
@@ -8,6 +8,8 @@
 
 #include <executorch/backends/vulkan/runtime/api/Pipeline.h>
 
+#include <fstream>
+
 namespace vkcompute {
 namespace api {
 
@@ -358,17 +360,24 @@ void PipelineLayoutCache::purge() {
 // ComputePipelineCache
 //
 
-ComputePipelineCache::ComputePipelineCache(VkDevice device)
+ComputePipelineCache::ComputePipelineCache(
+    VkDevice device,
+    const std::string& cache_data_path)
     : cache_mutex_{},
       device_(device),
       pipeline_cache_{VK_NULL_HANDLE},
-      cache_{} {
+      cache_{},
+      cache_data_path_(cache_data_path) {
+  // Seed the pipeline cache with data saved by a previous run, if available;
+  // the buffer is empty when there is nothing to load
+  const std::vector<char> initial_data = load_cache();
+
   const VkPipelineCacheCreateInfo pipeline_cache_create_info{
       VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
       nullptr, // pNext
       0u, // flags
-      0u, // initialDataSize
-      nullptr, // pInitialData
+      initial_data.size(), // initialDataSize
+      initial_data.data(), // pInitialData
   };
 
   VK_CHECK(vkCreatePipelineCache(
@@ -392,6 +401,9 @@ ComputePipelineCache::~ComputePipelineCache() {
   if (VK_NULL_HANDLE == pipeline_cache_) {
     return;
   }
+
+  save_cache();
+
   vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
   pipeline_cache_ = VK_NULL_HANDLE;
 }
@@ -416,5 +428,67 @@ void ComputePipelineCache::purge() {
   cache_.clear();
 }
 
+std::vector<char> ComputePipelineCache::load_cache() {
+  // Return if path is not specified; this means the optimization is disabled
+  if (cache_data_path_.empty()) {
+    return {};
+  }
+
+  // Return if file doesn't exist; this is expected on the first model-load
+  std::ifstream file(cache_data_path_, std::ios::binary | std::ios::ate);
+  if (file.fail()) {
+    return {};
+  }
+
+  // tellg() yields -1 on failure, which must not be used to size the buffer;
+  // an empty file carries no cache data either
+  const std::streamoff size = file.tellg();
+  if (size <= 0) {
+    return {};
+  }
+  file.seekg(0, std::ios::beg);
+
+  // Discard the data on a failed or short read so that a partially filled
+  // buffer is never passed to vkCreatePipelineCache
+  std::vector<char> buffer(static_cast<size_t>(size));
+  if (!file.read(buffer.data(), static_cast<std::streamsize>(size))) {
+    return {};
+  }
+
+  return buffer;
+}
+
+void ComputePipelineCache::save_cache() {
+  // Nothing to persist if path is not specified (optimization is disabled)
+  if (cache_data_path_.empty()) {
+    return;
+  }
+
+  // Persisting the cache is best-effort and runs from the destructor, so any
+  // failure below skips the save rather than being reported
+  size_t size{};
+  VkResult result =
+      vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr);
+  if (VK_SUCCESS != result || 0u == size) {
+    return;
+  }
+
+  // VK_INCOMPLETE means the data was truncated; don't save a partial cache
+  std::vector<char> buffer(size);
+  result =
+      vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data());
+  if (VK_SUCCESS != result) {
+    return;
+  }
+
+  std::ofstream file(cache_data_path_, std::ios::binary);
+  if (file.fail()) {
+    return;
+  }
+
+  // size now holds the number of bytes actually written into the buffer
+  file.write(buffer.data(), static_cast<std::streamsize>(size));
+}
+
 } // namespace api
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/Pipeline.h b/backends/vulkan/runtime/api/Pipeline.h
index 118a67e37d5..35b3b6275b4 100644
--- a/backends/vulkan/runtime/api/Pipeline.h
+++ b/backends/vulkan/runtime/api/Pipeline.h
@@ -216,7 +216,9 @@ class PipelineLayoutCache final {
 
 class ComputePipelineCache final {
  public:
-  explicit ComputePipelineCache(VkDevice device);
+  explicit ComputePipelineCache(
+      VkDevice device,
+      const std::string& cache_data_path);
 
   ComputePipelineCache(const ComputePipelineCache&) = delete;
   ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;
@@ -266,6 +268,9 @@ class ComputePipelineCache final {
   };
 
  private:
+  std::vector<char> load_cache();
+  void save_cache();
+
   // Multiple threads could potentially be adding entries into the cache, so use
   // a mutex to manage access
   std::mutex cache_mutex_;
@@ -273,6 +278,7 @@ class ComputePipelineCache final {
   VkDevice device_;
   VkPipelineCache pipeline_cache_;
   std::unordered_map<Key, VkPipeline, Hasher> cache_;
+  const std::string cache_data_path_;
 
  public:
   VkPipeline retrieve(const Key&);
diff --git a/backends/vulkan/runtime/api/Runtime.cpp b/backends/vulkan/runtime/api/Runtime.cpp
index ebed34162f3..432af326a53 100644
--- a/backends/vulkan/runtime/api/Runtime.cpp
+++ b/backends/vulkan/runtime/api/Runtime.cpp
@@ -253,12 +253,14 @@ std::unique_ptr<Runtime> init_global_vulkan_runtime() {
 #endif /* VULKAN_DEBUG */
   const bool init_default_device = true;
   const uint32_t num_requested_queues = 1; // TODO: raise this value
+  const std::string cache_data_path = ""; // TODO: expose to client
 
   const RuntimeConfiguration default_config{
       enable_validation_messages,
       init_default_device,
       AdapterSelector::First,
       num_requested_queues,
+      cache_data_path,
   };
 
   try {
@@ -351,7 +353,10 @@ uint32_t Runtime::create_adapter(const Selector& selector) {
   // Otherwise, create an adapter for the selected physical device
   adapter_i = utils::safe_downcast<uint32_t>(adapters_.size());
   adapters_.emplace_back(new Adapter(
-      instance_, device_mapping.first, config_.num_requested_queues));
+      instance_,
+      device_mapping.first,
+      config_.num_requested_queues,
+      config_.cache_data_path));
   device_mapping.second = adapter_i;
 
   return adapter_i;
diff --git a/backends/vulkan/runtime/api/Runtime.h b/backends/vulkan/runtime/api/Runtime.h
index 6cfcc0ca03a..e4cb6922ad8 100644
--- a/backends/vulkan/runtime/api/Runtime.h
+++ b/backends/vulkan/runtime/api/Runtime.h
@@ -39,6 +39,7 @@ struct RuntimeConfiguration final {
   bool init_default_device;
   AdapterSelector default_selector;
   uint32_t num_requested_queues;
+  std::string cache_data_path;
 };
 
 class Runtime final {