Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 0 additions & 59 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@

#include <executorch/backends/vulkan/runtime/api/Context.h>

#include <cstdint>
#include <cstring>
#include <memory>
#include <sstream>

#ifndef VULKAN_DESCRIPTOR_POOL_SIZE
#define VULKAN_DESCRIPTOR_POOL_SIZE 1024u
#endif
Expand Down Expand Up @@ -220,59 +215,5 @@ Context* context() {
return context.get();
}

//
// UniformParamsBuffer
//

namespace {

void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);

MemoryMap src_mapping(src, MemoryAccessType::READ);
src_mapping.invalidate();

void* dst_ptr = dst_mapping.template data<void>();
void* src_ptr = src_mapping.template data<void>();

// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
memcpy(dst_ptr, src_ptr, src.mem_size());
}

} // namespace

UniformParamsBuffer::UniformParamsBuffer(const UniformParamsBuffer& other)
: context_p_(other.context_p_), vulkan_buffer_{} {
if (other.vulkan_buffer_) {
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
other.vulkan_buffer_.mem_size());

memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
}
}

UniformParamsBuffer& UniformParamsBuffer::operator=(
const UniformParamsBuffer& other) {
if (&other != this) {
context_p_ = other.context_p_;

// Move vulkan_buffer_ to another VulkanBuffer for cleanup
if (vulkan_buffer_) {
VulkanBuffer temp_buffer(std::move(vulkan_buffer_));
context_p_->register_buffer_cleanup(temp_buffer);
}
// vulkan_buffer_ should now be empty

if (other.vulkan_buffer_) {
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
other.vulkan_buffer_.mem_size());

memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
}
}

return *this;
}

} // namespace api
} // namespace vkcompute
104 changes: 0 additions & 104 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,12 @@

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <executorch/backends/vulkan/runtime/api/vk_api.h>

#include <executorch/backends/vulkan/runtime/api/Adapter.h>
#include <executorch/backends/vulkan/runtime/api/Command.h>
#include <executorch/backends/vulkan/runtime/api/Descriptor.h>
#include <executorch/backends/vulkan/runtime/api/Fence.h>
#include <executorch/backends/vulkan/runtime/api/Pipeline.h>
#include <executorch/backends/vulkan/runtime/api/QueryPool.h>
#include <executorch/backends/vulkan/runtime/api/Runtime.h>
#include <executorch/backends/vulkan/runtime/api/Shader.h>
#include <executorch/backends/vulkan/runtime/api/Utils.h>

#include <executorch/backends/vulkan/runtime/api/memory/Buffer.h>

namespace vkcompute {
namespace api {
Expand Down Expand Up @@ -218,103 +211,6 @@ class Context final {
void flush();
};

class UniformParamsBuffer final {
private:
Context* context_p_;
size_t nbytes_;
VulkanBuffer vulkan_buffer_;

public:
UniformParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}

template <typename Block>
UniformParamsBuffer(Context* context_p, const Block& block)
: context_p_(context_p),
nbytes_(sizeof(block)),
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

UniformParamsBuffer(const UniformParamsBuffer&);
UniformParamsBuffer& operator=(const UniformParamsBuffer&);

UniformParamsBuffer(UniformParamsBuffer&&) = default;
UniformParamsBuffer& operator=(UniformParamsBuffer&&) = default;

~UniformParamsBuffer() {
if (vulkan_buffer_) {
context_p_->register_buffer_cleanup(vulkan_buffer_);
}
}

const VulkanBuffer& buffer() const {
return vulkan_buffer_;
}

template <typename Block>
void update(const Block& block) {
if (sizeof(block) != nbytes_) {
VK_THROW(
"Attempted to update UniformParamsBuffer with data of different size");
}
// Fill the uniform buffer with data in block
{
MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
Block* data_ptr = mapping.template data<Block>();

*data_ptr = block;
}
}
};

class StorageBuffer final {
private:
Context* context_p_;
ScalarType dtype_;
size_t numel_;
size_t nbytes_;
VulkanBuffer vulkan_buffer_;

public:
StorageBuffer(
Context* context_p,
const ScalarType dtype,
const size_t numel,
const bool gpuonly = false)
: context_p_(context_p),
dtype_(dtype),
numel_(numel),
nbytes_(element_size(dtype_) * numel_),
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
nbytes_,
gpuonly)) {}

StorageBuffer(const StorageBuffer&) = delete;
StorageBuffer& operator=(const StorageBuffer&) = delete;

StorageBuffer(StorageBuffer&&) = default;
StorageBuffer& operator=(StorageBuffer&&) = default;

~StorageBuffer() {
context_p_->register_buffer_cleanup(vulkan_buffer_);
}

inline ScalarType dtype() {
return dtype_;
}

inline VulkanBuffer& buffer() {
return vulkan_buffer_;
}

inline size_t numel() {
return numel_;
}

inline size_t nbytes() {
return nbytes_;
}
};

bool available();

// The global runtime is retrieved using this function, where it is declared as
Expand Down
66 changes: 66 additions & 0 deletions backends/vulkan/runtime/api/ParamsBuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/api/ParamsBuffer.h>

#include <cstring>

namespace vkcompute {
namespace api {

namespace {

void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);

MemoryMap src_mapping(src, MemoryAccessType::READ);
src_mapping.invalidate();

void* dst_ptr = dst_mapping.template data<void>();
void* src_ptr = src_mapping.template data<void>();

// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
memcpy(dst_ptr, src_ptr, src.mem_size());
}

} // namespace

ParamsBuffer::ParamsBuffer(const ParamsBuffer& other)
: context_p_(other.context_p_), vulkan_buffer_{} {
if (other.vulkan_buffer_) {
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
other.vulkan_buffer_.mem_size());

memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
}
}

ParamsBuffer& ParamsBuffer::operator=(const ParamsBuffer& other) {
if (&other != this) {
context_p_ = other.context_p_;

// Move vulkan_buffer_ to another VulkanBuffer for cleanup
if (vulkan_buffer_) {
VulkanBuffer temp_buffer(std::move(vulkan_buffer_));
context_p_->register_buffer_cleanup(temp_buffer);
}
// vulkan_buffer_ should now be empty

if (other.vulkan_buffer_) {
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
other.vulkan_buffer_.mem_size());

memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
}
}

return *this;
}

} // namespace api
} // namespace vkcompute
68 changes: 68 additions & 0 deletions backends/vulkan/runtime/api/ParamsBuffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <executorch/backends/vulkan/runtime/api/Context.h>

#include <executorch/backends/vulkan/runtime/api/memory/Buffer.h>

namespace vkcompute {
namespace api {

class ParamsBuffer final {
private:
Context* context_p_;
size_t nbytes_;
VulkanBuffer vulkan_buffer_;

public:
ParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}

template <typename Block>
ParamsBuffer(Context* context_p, const Block& block)
: context_p_(context_p),
nbytes_(sizeof(block)),
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

ParamsBuffer(const ParamsBuffer&);
ParamsBuffer& operator=(const ParamsBuffer&);

ParamsBuffer(ParamsBuffer&&) = default;
ParamsBuffer& operator=(ParamsBuffer&&) = default;

~ParamsBuffer() {
if (vulkan_buffer_) {
context_p_->register_buffer_cleanup(vulkan_buffer_);
}
}

const VulkanBuffer& buffer() const {
return vulkan_buffer_;
}

template <typename Block>
void update(const Block& block) {
if (sizeof(block) != nbytes_) {
VK_THROW("Attempted to update ParamsBuffer with data of different size");
}
// Fill the uniform buffer with data in block
{
MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
Block* data_ptr = mapping.template data<Block>();

*data_ptr = block;
}
}
};

} // namespace api
} // namespace vkcompute
Loading