Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ DeviceHandle::~DeviceHandle() {
Adapter::Adapter(
VkInstance instance,
PhysicalDevice physical_device,
const uint32_t num_queues)
const uint32_t num_queues,
const std::string& cache_data_path)
: queue_usage_mutex_{},
physical_device_(std::move(physical_device)),
queues_{},
Expand All @@ -307,7 +308,7 @@ Adapter::Adapter(
shader_layout_cache_(device_.handle_),
shader_cache_(device_.handle_),
pipeline_layout_cache_(device_.handle_),
compute_pipeline_cache_(device_.handle_),
compute_pipeline_cache_(device_.handle_, cache_data_path),
sampler_cache_(device_.handle_),
vma_(instance_, physical_device_.handle, device_.handle_) {}

Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/api/Adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class Adapter final {
explicit Adapter(
VkInstance instance,
PhysicalDevice physical_device,
const uint32_t num_queues);
const uint32_t num_queues,
const std::string& cache_data_path);

Adapter(const Adapter&) = delete;
Adapter& operator=(const Adapter&) = delete;
Expand Down
54 changes: 49 additions & 5 deletions backends/vulkan/runtime/api/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <executorch/backends/vulkan/runtime/api/Pipeline.h>

#include <fstream>

namespace vkcompute {
namespace api {

Expand Down Expand Up @@ -358,17 +360,24 @@ void PipelineLayoutCache::purge() {
// ComputePipelineCache
//

ComputePipelineCache::ComputePipelineCache(VkDevice device)
ComputePipelineCache::ComputePipelineCache(
VkDevice device,
const std::string& cache_data_path)
: cache_mutex_{},
device_(device),
pipeline_cache_{VK_NULL_HANDLE},
cache_{} {
const VkPipelineCacheCreateInfo pipeline_cache_create_info{
cache_{},
cache_data_path_(cache_data_path) {
VkPipelineCacheCreateInfo pipeline_cache_create_info{};

auto buffer = load_cache();

pipeline_cache_create_info = {
VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
0u, // initialDataSize
nullptr, // pInitialData
buffer.size(), // initialDataSize
buffer.data(), // pInitialData
};

VK_CHECK(vkCreatePipelineCache(
Expand All @@ -392,6 +401,9 @@ ComputePipelineCache::~ComputePipelineCache() {
if (VK_NULL_HANDLE == pipeline_cache_) {
return;
}

save_cache();

vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
pipeline_cache_ = VK_NULL_HANDLE;
}
Expand All @@ -416,5 +428,37 @@ void ComputePipelineCache::purge() {
cache_.clear();
}

// Reads the serialized pipeline cache blob stored at cache_data_path_.
//
// Returns the file contents, or an empty buffer when the path is empty, the
// file cannot be opened, its size cannot be determined, or it cannot be read
// in full. An empty buffer results in a pipeline cache that starts out empty.
std::vector<char> ComputePipelineCache::load_cache() {
  // Return if path is not specified; this means the optimization is disabled
  if (cache_data_path_.empty()) {
    return {};
  }

  // Return if file doesn't exist; this is expected on the first model-load
  std::ifstream file(cache_data_path_, std::ios::binary | std::ios::ate);
  if (file.fail()) {
    return {};
  }

  // The stream was opened at the end, so the current position is the file
  // size. tellg() reports -1 on failure, which must not be used as a size.
  const std::streamoff size = file.tellg();
  if (size <= 0) {
    return {};
  }

  file.seekg(0, std::ios::beg);

  std::vector<char> buffer(static_cast<size_t>(size));
  file.read(buffer.data(), static_cast<std::streamsize>(size));

  // A short read would leave the tail of the buffer zero-filled; discard the
  // data rather than hand a partially valid blob to the driver.
  if (file.fail() || file.gcount() != static_cast<std::streamsize>(size)) {
    return {};
  }

  return buffer;
}

void ComputePipelineCache::save_cache() {
size_t size{};
vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr);

std::vector<char> buffer(size);
vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data());

std::ofstream file(cache_data_path_, std::ios::binary);
file.write(buffer.data(), buffer.size());
}

} // namespace api
} // namespace vkcompute
8 changes: 7 additions & 1 deletion backends/vulkan/runtime/api/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ class PipelineLayoutCache final {

class ComputePipelineCache final {
public:
explicit ComputePipelineCache(VkDevice device);
explicit ComputePipelineCache(
VkDevice device,
const std::string& cache_data_path);

ComputePipelineCache(const ComputePipelineCache&) = delete;
ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;
Expand Down Expand Up @@ -266,13 +268,17 @@ class ComputePipelineCache final {
};

private:
std::vector<char> load_cache();
void save_cache();

// Multiple threads could potentially be adding entries into the cache, so use
// a mutex to manage access
std::mutex cache_mutex_;

VkDevice device_;
VkPipelineCache pipeline_cache_;
std::unordered_map<Key, Value, Hasher> cache_;
const std::string cache_data_path_;

public:
VkPipeline retrieve(const Key&);
Expand Down
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/api/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,14 @@ std::unique_ptr<Runtime> init_global_vulkan_runtime() {
#endif /* VULKAN_DEBUG */
const bool init_default_device = true;
const uint32_t num_requested_queues = 1; // TODO: raise this value
const std::string cache_data_path = ""; // TODO: expose to client

const RuntimeConfiguration default_config{
enable_validation_messages,
init_default_device,
AdapterSelector::First,
num_requested_queues,
cache_data_path,
};

try {
Expand Down Expand Up @@ -351,7 +353,10 @@ uint32_t Runtime::create_adapter(const Selector& selector) {
// Otherwise, create an adapter for the selected physical device
adapter_i = utils::safe_downcast<int32_t>(adapters_.size());
adapters_.emplace_back(new Adapter(
instance_, device_mapping.first, config_.num_requested_queues));
instance_,
device_mapping.first,
config_.num_requested_queues,
config_.cache_data_path));
device_mapping.second = adapter_i;

return adapter_i;
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/api/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct RuntimeConfiguration final {
bool init_default_device;
AdapterSelector default_selector;
uint32_t num_requested_queues;
std::string cache_data_path;
};

class Runtime final {
Expand Down