From 8f6cf94cc7c8369bc6da7e326862c50d06a730d8 Mon Sep 17 00:00:00 2001 From: noli Date: Thu, 27 Jun 2024 13:00:34 +0000 Subject: [PATCH 1/6] torch to ark --- ark/api/executor.cpp | 70 ++++++++++++++++++++--- ark/api/tensor.cpp | 23 +++++++- ark/codegen.cpp | 39 +++++++++---- ark/codegen.hpp | 4 +- ark/include/ark/tensor.hpp | 2 + ark/model/model_buffer.cpp | 55 ++++++++++++++++-- ark/model/model_buffer.hpp | 15 +++++ ark/model/model_buffer_manager.hpp | 90 ++++++++++++++++++++++++++++++ python/ark/data_type.py | 22 ++++++++ python/ark/tensor.py | 16 ++++++ python/tensor_py.cpp | 47 +++++++++++++++- python/unittest/test_conversion.py | 83 ++++++++++++++++++++++++++- 12 files changed, 438 insertions(+), 28 deletions(-) create mode 100644 ark/model/model_buffer_manager.hpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index ebfa7016d..243937ce1 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "ark/data_type.hpp" #include "ark/model.hpp" @@ -24,6 +25,7 @@ #include "gpu/gpu_manager.h" #include "logging.h" #include "model/model_buffer.hpp" +#include "model/model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" #include "utils/utils_net.hpp" @@ -229,8 +231,16 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", "; } - codegen_ = - std::make_shared(plan_json, buffer_id_to_offset_, name); + ModelBufferManager &buffer_manager = ModelBufferManager::getInstance(); + std::shared_ptr codegen_; + + if (!buffer_manager.isEmpty()) { + codegen_ = std::make_shared( + plan_json, buffer_id_to_offset_, name, &buffer_manager); + } else { + codegen_ = std::make_shared(plan_json, + buffer_id_to_offset_, name); + } auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); @@ -361,7 +371,16 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { } continue; } - buffer_id_to_offset[buf_info->buffer->id()] = offset; + if (buf_info->buffer->is_external()) { + if (buf_info->buffer->device_id() != gpu_id_) { + ERR(InvalidUsageError, + "PyTorch tensor and model execution are on different GPUs"); + } + continue; + } else { + buffer_id_to_offset[buf_info->buffer->id()] = offset; + offset += buf_info->bytes; + } for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tags_and_offsets[tag_info.first] .first.push_back(tag_info.second); @@ -374,7 +393,6 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { remote_rank_to_recv_tags_and_offsets[tag_info.first] .second.push_back(offset); } - offset += buf_info->bytes; } total_bytes_ = offset; @@ -450,7 +468,11 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2); for (int i = 0; i < len; ++i) { - buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; + if (!buffer_id_to_info[send_tag_to_buffer_id[tags[i]]] + ->buffer->is_external()) { + buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = + offsets[i]; + } } } for (auto &kv : remote_rank_to_recv_tag_to_buffer_id) { @@ -466,10 +488,13 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5); for (int i = 0; i < len; ++i) { - 
buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; + if (!buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]] + ->buffer->is_external()) { + buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = + offsets[i]; + } } } - return buffer_id_to_offset; } @@ -617,6 +642,22 @@ void Executor::Impl::launch(int64_t max_spin_count) { gpuMemcpyHostToDevice, copy_stream_->get())); GLOG(gpuMemcpyAsync(buf_ptr_addr, &buf_ptr_val, sizeof(gpuDeviceptr), gpuMemcpyHostToDevice, copy_stream_->get())); + + // Handle external buffers + ModelBufferManager &buffer_manager = ModelBufferManager::getInstance(); + if (!buffer_manager.isEmpty()) { + void *ext_buf_addr = get_global_rt("ARK_EXTERNAL_BUFFERS"); + std::vector ext_buffers(buffer_manager.getCompactIdSize()); + for (const auto &[id, buffer_info] : buffer_manager.getBuffers()) { + size_t compactId = buffer_manager.getCompactId(id); + void *buffer_address = std::get<0>(buffer_info); + ext_buffers[compactId] = buffer_address; + } + GLOG(gpuMemcpyAsync(ext_buf_addr, ext_buffers.data(), + ext_buffers.size() * sizeof(void *), + gpuMemcpyHostToDevice, copy_stream_->get())); + } + if (world_size_ > 1) { void *proxy_chan_addr = get_global_rt("ARK_PROXY_CHANS"); void *proxy_secondary_chan_addr = @@ -745,6 +786,11 @@ uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const { void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); + if (tensor.ref()->buffer()->is_external()) { + ERR(InvalidUsageError, + "Reading data from a tensor preallocated by PyTorch is not " + "supported. Use PyTorch's native methods."); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); if (bytes != tensor_data_bytes) { @@ -782,6 +828,11 @@ void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, void Executor::Impl::tensor_write(const Tensor tensor, const void *data, size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); + if (tensor.ref()->buffer()->is_external()) { + ERR(InvalidUsageError, + "Writing data to a tensor preallocated by PyTorch is not " + "supported. 
Use PyTorch's native methods."); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); if (bytes != tensor_data_bytes) { @@ -841,7 +892,10 @@ float Executor::stop(int64_t max_spin_count) { void Executor::barrier() { impl_->barrier(); } -void Executor::destroy() { impl_.reset(nullptr); } +void Executor::destroy() { + ModelBufferManager::getInstance().clearBuffers(); + impl_.reset(nullptr); +} bool Executor::destroyed() const { return impl_.get() == nullptr; } diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 4b03c3ac8..65ba09e36 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -3,11 +3,30 @@ #include "ark/tensor.hpp" +#include + +#include "ark/dims.hpp" +#include "logging.h" +#include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" namespace ark { +Tensor::Tensor(void* data_ptr, int32_t device_id, int8_t dtype_bytes, + const std::vector& shape, + const std::string& ark_type_str) { + size_t external_data_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()) * + dtype_bytes; + auto buffer = + std::make_shared(data_ptr, external_data_size, device_id); + ark::ModelDataType dtype = DataType::from_name(ark_type_str).ref(); + auto tensor = std::make_shared(dtype, buffer, Dims(shape), + Dims(shape), Dims(), Dims()); + ref_ = tensor; +} + size_t Tensor::id() const { if (ref_) { return ref_->id(); @@ -43,14 +62,14 @@ Dims Tensor::padded_shape() const { return Dims(); } -const DataType &Tensor::data_type() const { +const DataType& Tensor::data_type() const { if (ref_) { return DataType::from_name(ref_->data_type()->type_name()); } return NONE; } -std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { +std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { if (tensor.is_null()) { os << "null"; } else { diff --git a/ark/codegen.cpp b/ark/codegen.cpp index cd6206284..d86395d0c 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -10,6 +10,7 @@ #include "file_io.h" #include "logging.h" #include "model/model_buffer.hpp" +#include "model/model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_op.hpp" #include "model/model_tensor.hpp" @@ -43,7 +44,7 @@ class CodeGenerator::Impl { public: Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name); + const std::string &name, ModelBufferManager *buffer_manager); ~Impl() = default; private: @@ -64,6 +65,8 @@ class CodeGenerator::Impl { std::string sync_process_range(const Range &ranges, int state_id); + ModelBufferManager *buffer_manager_; + protected: friend class CodeGenerator; @@ -78,14 +81,22 @@ class CodeGenerator::Impl { CodeGenerator::Impl::Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : buffer_id_to_offset_(buffer_id_to_offset), name_(name) { + const std::string &name, + ModelBufferManager *buffer_manager) + : buffer_id_to_offset_(buffer_id_to_offset), + name_(name), + buffer_manager_(buffer_manager) { rank_ = plan.at("Rank"); world_size_ = plan.at("WorldSize"); num_procs_ = plan.at("NumProcessors"); num_warps_per_proc_ = plan.at("NumWarpsPerProcessor"); std::stringstream definitions_ss; + if (buffer_manager_) { + definitions_ss << "__device__ void* ARK_EXTERNAL_BUFFERS[" + << buffer_manager_->getCompactIdSize() << "];\n"; + } + for (auto &task_json : plan.at("TaskInfos")) { definitions_ss << this->def_task(task_json); } @@ -224,11 +235,18 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) 
{ auto &arg = impl_args[i]; if (arg.type_name() == "TENSOR") { auto tns = arg.value(); - size_t buffer_offset = - buffer_id_to_offset_.at(tns->buffer()->id()); - size_t offset = buffer_offset + ModelOffset(tns).value(); - ss << "(" << tns->data_type()->type_str() << "*)&_buf[" - << offset << "]"; + if (tns->buffer()->is_external()) { + size_t compactId = + buffer_manager_->getCompactId(tns->buffer()->id()); + ss << "(" << tns->data_type()->type_str() + << "*)ARK_EXTERNAL_BUFFERS[" << compactId << "]"; + } else { + size_t buffer_offset = + buffer_id_to_offset_.at(tns->buffer()->id()); + size_t offset = buffer_offset + ModelOffset(tns).value(); + ss << "(" << tns->data_type()->type_str() << "*)&_buf[" + << offset << "]"; + } } else if (arg.type_name() == "OFFSET") { auto moff = arg.value(); size_t buffer_offset = @@ -430,8 +448,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range &range, CodeGenerator::CodeGenerator( const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : impl_(std::make_shared(plan, buffer_id_to_offset, name)) {} + const std::string &name, ModelBufferManager *buffer_manager) + : impl_(std::make_shared(plan, buffer_id_to_offset, name, + buffer_manager)) {} std::string CodeGenerator::code() const { return impl_->code_; } diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 4f8307e7e..295aff919 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -8,6 +8,7 @@ #include #include +#include "model/model_buffer_manager.hpp" #include "model/model_json.hpp" namespace ark { @@ -16,7 +17,8 @@ class CodeGenerator { public: CodeGenerator(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name = "ark_kernel"); + const std::string &name = "ark_kernel", + ModelBufferManager *buffer_manager = nullptr); ~CodeGenerator() = default; diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 747ce5fea..7f65dff25 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,6 +31,8 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; + Tensor(void *data_ptr, int32_t device_id, int8_t dtype_bytes, + const std::vector &shape, const std::string &ark_type_str); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index 4ce91b5e4..a2ca423a8 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -4,13 +4,13 @@ #include "model_buffer.hpp" #include "logging.h" +#include "model_buffer_manager.hpp" namespace ark { -ModelBuffer::ModelBuffer(int rank) : rank_(rank) { - static size_t id = 0; - id_ = id++; -} +size_t ModelBuffer::curr_id = 0; + +ModelBuffer::ModelBuffer(int rank) : rank_(rank) { id_ = curr_id++; } ModelBuffer::ModelBuffer(size_t id, int rank, const std::vector &send_tags, @@ -24,6 +24,23 @@ ModelBuffer::ModelBuffer(size_t id, int rank, } } +ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id) + : rank_(-1), + external_data_(data), + external_data_size_(size), + device_id_(device_id), + is_external_(true) { + id_ = curr_id++; +} + +ModelBuffer::ModelBuffer(size_t id, void *data, size_t size, int32_t device_id) + : id_(id), + rank_(-1), + external_data_(data), + external_data_size_(size), + device_id_(device_id), + is_external_(true) {} + void ModelBuffer::tag_send(int remote_rank, int tag) { 
send_tags_.insert(TagInfo{remote_rank, tag}); } @@ -46,6 +63,14 @@ Json ModelBuffer::serialize() const { } j["SendTags"] = send_tags; j["RecvTags"] = recv_tags; + j["IsExternal"] = is_external_; + if (is_external_) { + ModelBufferManager::getInstance().registerBuffer(id_, external_data_, + external_data_size_); + j["ExternalDataSize"] = external_data_size_; + j["DeviceId"] = device_id_; + } + // external_data_ptr_ is not included in JSON return j; } @@ -62,6 +87,28 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { } else if (!serialized.contains("RecvTags")) { ERR(InvalidUsageError, "ModelBuffer deserialization failed: missing RecvTags"); + } else if (!serialized.contains("IsExternal")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing IsExternal"); + } + if (serialized["IsExternal"]) { + if (!serialized.contains("ExternalDataSize")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing ExternalDataSize"); + } else if (!serialized.contains("DeviceId")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing DeviceId"); + } + void *data_ptr = + ModelBufferManager::getInstance().getBuffer(serialized["Id"]); + if (!data_ptr) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: external buffer not found " + "in BufferManager"); + } + return std::make_shared(serialized["Id"], data_ptr, + serialized["ExternalDataSize"], + serialized["DeviceId"]); } return std::make_shared(serialized["Id"], serialized["Rank"], serialized["SendTags"], diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp index 7ad3db206..e7f1045b2 100644 --- a/ark/model/model_buffer.hpp +++ b/ark/model/model_buffer.hpp @@ -22,6 +22,10 @@ class ModelBuffer { ModelBuffer(size_t id, int rank, const std::vector &send_tags, const std::vector &recv_tags); + // externally managed buffer + ModelBuffer(void *data, size_t size, int32_t device_id); + ModelBuffer(size_t id, void *data, size_t size, int32_t device_id); + size_t id() const { return id_; } int rank() const { return rank_; } @@ -44,11 +48,22 @@ class ModelBuffer { static std::shared_ptr deserialize(const Json &serialized); + // external buffer management + size_t external_data_size() const { return external_data_size_; } + void *external_data() const { return external_data_; } + int32_t device_id() const { return device_id_; } + bool is_external() const { return is_external_; } + private: + static size_t curr_id; size_t id_; int rank_; std::set send_tags_; std::set recv_tags_; + void *external_data_ = nullptr; + size_t external_data_size_ = 0; + int32_t device_id_; + bool is_external_ = false; }; } // namespace ark diff --git a/ark/model/model_buffer_manager.hpp b/ark/model/model_buffer_manager.hpp new file mode 100644 index 000000000..c73ba51dd --- /dev/null +++ b/ark/model/model_buffer_manager.hpp @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MODEL_BUFFER_MANAGER_HPP +#define MODEL_BUFFER_MANAGER_HPP + +#include +#include + +namespace ark { +/** + * @brief Manages externally allocated buffers not in the ARK memory space. + * + * Details: + * - `buffers_`: Maps external buffer IDs to their pointers and sizes. + * - `externalIdMap_`: Maps external buffer IDs to their corresponding compact + * IDs. During the code generation phase, an array of buffer addresses + * (ARK_EXTERNAL_BUFFERS) is preallocated to hold addresses of external + * buffers. 
Accessing an external buffer utilizes its compact ID to index into + * this array, ensuring that each index *i* in ARK_EXTERNAL_BUFFERS corresponds + * to the buffer with compact ID *i*, facilitating mixed allocation patterns + * (e.g., internal, internal, external, internal). + */ +class ModelBufferManager { + public: + static ModelBufferManager& getInstance() { + static ModelBufferManager instance; + return instance; + } + + void registerBuffer(size_t id, void* data, size_t size) { + buffers_[id] = std::make_tuple(data, size); + if (externalIdMap_.find(id) == externalIdMap_.end()) { + externalIdMap_[id] = nextCompactId_++; + } + } + + void* getBuffer(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<0>(it->second); + } + return nullptr; + } + + size_t getBufferSize(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<1>(it->second); + } + return 0; + } + + size_t getCompactIdSize() { return nextCompactId_; } + + size_t getCompactId(size_t id) { + auto it = externalIdMap_.find(id); + if (it != externalIdMap_.end()) { + return it->second; + } + return 0; + } + + const std::unordered_map>& getBuffers() + const { + return buffers_; + } + + void clearBuffers() { + buffers_.clear(); + externalIdMap_.clear(); + nextCompactId_ = 0; + } + + bool isEmpty() const { return buffers_.empty(); } + + private: + std::unordered_map> + buffers_; // Maps buffer IDs to pointers and sizes. + std::unordered_map + externalIdMap_; // Maps original buffer IDs to compact IDs for external + // buffers. + size_t nextCompactId_ = 0; + ModelBufferManager() {} + ModelBufferManager(const ModelBufferManager&) = delete; + ModelBufferManager& operator=(const ModelBufferManager&) = delete; +}; +} // namespace ark + +#endif // MODEL_BUFFER_MANAGER_HPP \ No newline at end of file diff --git a/python/ark/data_type.py b/python/ark/data_type.py index 8ab982106..1d76200e8 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -86,6 +86,28 @@ def from_torch(torch_type: torch.dtype) -> "DataType": f" to ark data type." ) + @staticmethod + def torch_to_ark_dtype_name(torch_type: torch.dtype) -> str: + """ + Converts a PyTorch data type to its corresponding ARK data type string representation. + + Parameters: + torch_type (torch.dtype): The PyTorch data type to convert. + + Returns: + str: The corresponding ARK data type as a string in uppercase. + + Raises: + ValueError: If there is no defined conversion from the given PyTorch data type to an ARK data type. + """ + for type_name, reg in _REGISTRY_DATA_TYPE.items(): + if reg["torch"] == torch_type: + return type_name.upper() + raise ValueError( + f"Undefined conversion from torch data type {torch_type}" + f" to ark data type." + ) + @staticmethod def from_name(type_name: str) -> "DataType": """ diff --git a/python/ark/tensor.py b/python/ark/tensor.py index eff1bf20e..03d186e2d 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -180,6 +180,22 @@ def from_torch(tensor: torch.Tensor): lambda: tensor, ) + @staticmethod + def get_ark_view(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": + """ + Returns an ARK tensor that shares the same memory with the torch tensor. 
+ """ + if _no_torch: + raise ImportError("torch is not available") + elif not tensor.is_contiguous(): + raise ValueError("Torch tensor must be contiguous.") + elif tensor.device.type == "cpu": + raise ValueError("Torch tensor must be on a device.") + ark_dtype = DataType.torch_to_ark_dtype_name(tensor.dtype) + dl_capsule = torch.utils.dlpack.to_dlpack(tensor) + ark_tensor = _Tensor(dl_capsule, ark_dtype) + return Tensor(ark_tensor, runtime_id=runtime_id) + def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": """ Copies data into this tensor. The data type may differ, diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index fbd909d3d..ceaf9269b 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include #include #include #include @@ -9,8 +10,52 @@ namespace py = pybind11; -void register_tensor(py::module &m) { +struct DLTensorMetadata { + void* data_ptr; + int32_t device_id; + DLDeviceType device_type; + int32_t ndim; + DLDataType dtype; + std::vector shape; + std::vector strides; + uint64_t byte_offset; +}; + +DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { + DLTensorMetadata metadata; + metadata.data_ptr = dl_tensor->dl_tensor.data; + metadata.device_id = dl_tensor->dl_tensor.device.device_id; + metadata.device_type = dl_tensor->dl_tensor.device.device_type; + metadata.ndim = dl_tensor->dl_tensor.ndim; + metadata.dtype = dl_tensor->dl_tensor.dtype; + metadata.shape.assign( + dl_tensor->dl_tensor.shape, + dl_tensor->dl_tensor.shape + dl_tensor->dl_tensor.ndim); + if (dl_tensor->dl_tensor.strides != nullptr) { + metadata.strides.assign( + dl_tensor->dl_tensor.strides, + dl_tensor->dl_tensor.strides + dl_tensor->dl_tensor.ndim); + } + metadata.byte_offset = dl_tensor->dl_tensor.byte_offset; + return metadata; +} + +void register_tensor(py::module& m) { py::class_(m, "_Tensor") + .def(py::init([](py::capsule capsule, const std::string& ark_type_str) { + DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; + if (!dl_tensor) { + throw std::runtime_error( + "Capsule does not contain a DLManagedTensor"); + } + DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); + int32_t device_id = metadata.device_id; + void* data_ptr = metadata.data_ptr; + int8_t dtype_bytes = metadata.dtype.bits / 8; + auto shape = metadata.shape; + + return new ark::Tensor(data_ptr, device_id, dtype_bytes, shape, ark_type_str); + })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape, py::return_value_policy::reference) .def("strides", &ark::Tensor::strides, diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 5befa1c34..07c51b194 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -1,6 +1,7 @@ import pytest import numpy as np import ark +from typing import Callable try: import torch @@ -9,6 +10,8 @@ except ImportError: _no_torch = True +# ARK to Torch tests + def initialize_tensor(dimensions, dtype): tensor = ark.tensor(dimensions, dtype) @@ -69,7 +72,7 @@ def check_diff(input_tensor_host, input_view_numpy, value, index): # Test function to check if changes to the torch views are reflected in the original tensors @pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) -def test_aliasing(dtype: ark.DataType): +def test_ark_to_torch_aliasing(dtype: ark.DataType): ark.init() dimensions = [4, 4] input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) @@ -102,7 +105,7 
@@ def test_aliasing(dtype: ark.DataType): runtime.stop() runtime.reset() - +@pytest.mark.skip(reason="Not implemented") def test_conversion_torch(): if _no_torch: pytest.skip("PyTorch not available") @@ -126,3 +129,79 @@ def test_conversion_torch(): torch_tensor = t.to_torch() assert torch.all(torch_tensor == 7) + + +# Torch to ARK tests + +ArkBinOp = Callable[[ark.Tensor, ark.Tensor], ark.Tensor] +TorchBinOp = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] +ArkUnOp = Callable[[ark.Tensor], ark.Tensor] +TorchUnOp = Callable[[torch.Tensor], torch.Tensor] + + +# Verify the accuracy of binary operations involving ARK view tensors +@pytest.mark.parametrize( + "dtype, ark_op, torch_op, tensor_dims", + [(torch.float16, ark.add, torch.add, (2, 3))], +) +def test_bin_op(dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims): + ark.init() + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch_op(input_tensor, other_tensor).cpu().numpy() + input_ark_view = ark.Tensor.get_ark_view(input_tensor) + other_ark_view = ark.Tensor.get_ark_view(other_tensor) + output = ark_op(input_ark_view, other_ark_view) + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) + + +# Verify the accuracy of unary operations involving ARK view tensors +@pytest.mark.parametrize( + "dtype, ark_op, torch_op, tensor_dims", + [(torch.float16, ark.exp, torch.exp, (3, 3))], +) +def test_unary_op(dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims): + ark.init() + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch_op(input_tensor).cpu().numpy() + input_ark_view = ark.Tensor.get_ark_view(input_tensor) + output = ark_op(input_ark_view) + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) + + +# Test function to check if changes in torch tensors are reflected in ARK views +@pytest.mark.parametrize("dtype, tensor_dims", [(torch.float16, (64, 64))]) +def test_torch_to_ark_aliasing(dtype, tensor_dims): + ark.init() + # Initialize a PyTorch tensor + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + + input_ark_view = ark.Tensor.get_ark_view(input_tensor) + other_ark_view = ark.Tensor.get_ark_view(other_tensor) + + output = ark.add(input_ark_view, other_ark_view) + # Perform in place operations + input_tensor += other_tensor + other_tensor += input_tensor + expected_output = (input_tensor + other_tensor).cpu().numpy() + + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) From 60a5836cb1e6877b77863336d826a656cc1c9e38 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 27 Jun 2024 20:41:35 +0000 Subject: [PATCH 2/6] minor updates --- ark/api/executor.cpp | 1 - ark/api/tensor.cpp | 13 ++++--------- ark/include/ark/tensor.hpp | 4 ++-- ark/model/model_buffer_manager.hpp | 6 +++--- python/ark/data_type.py | 22 ---------------------- python/ark/tensor.py | 4 ++-- python/tensor_py.cpp | 7 +++---- python/unittest/test_conversion.py | 2 +- 8 files changed, 15 insertions(+), 44 deletions(-) diff --git 
a/ark/api/executor.cpp b/ark/api/executor.cpp index 243937ce1..c145241ba 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -232,7 +232,6 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, } ModelBufferManager &buffer_manager = ModelBufferManager::getInstance(); - std::shared_ptr codegen_; if (!buffer_manager.isEmpty()) { codegen_ = std::make_shared( diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 65ba09e36..4d33bd9f1 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -3,26 +3,21 @@ #include "ark/tensor.hpp" -#include - -#include "ark/dims.hpp" -#include "logging.h" #include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" namespace ark { -Tensor::Tensor(void* data_ptr, int32_t device_id, int8_t dtype_bytes, +Tensor::Tensor(void* data_ptr, int32_t device_id, const std::vector& shape, - const std::string& ark_type_str) { + const DataType& dtype) { size_t external_data_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * - dtype_bytes; + dtype.bytes(); auto buffer = std::make_shared(data_ptr, external_data_size, device_id); - ark::ModelDataType dtype = DataType::from_name(ark_type_str).ref(); - auto tensor = std::make_shared(dtype, buffer, Dims(shape), + auto tensor = std::make_shared(dtype.ref(), buffer, Dims(shape), Dims(shape), Dims(), Dims()); ref_ = tensor; } diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 7f65dff25..d13748175 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,8 +31,8 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; - Tensor(void *data_ptr, int32_t device_id, int8_t dtype_bytes, - const std::vector &shape, const std::string &ark_type_str); + Tensor(void *data_ptr, int32_t device_id, const std::vector &shape, + const DataType &dtype); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } diff --git a/ark/model/model_buffer_manager.hpp b/ark/model/model_buffer_manager.hpp index c73ba51dd..3768f2340 100644 --- a/ark/model/model_buffer_manager.hpp +++ b/ark/model/model_buffer_manager.hpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef MODEL_BUFFER_MANAGER_HPP -#define MODEL_BUFFER_MANAGER_HPP +#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_ +#define ARK_MODEL_BUFFER_MANAGER_HPP_ #include #include @@ -87,4 +87,4 @@ class ModelBufferManager { }; } // namespace ark -#endif // MODEL_BUFFER_MANAGER_HPP \ No newline at end of file +#endif // ARK_MODEL_BUFFER_MANAGER_HPP_ diff --git a/python/ark/data_type.py b/python/ark/data_type.py index 1d76200e8..8ab982106 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -86,28 +86,6 @@ def from_torch(torch_type: torch.dtype) -> "DataType": f" to ark data type." ) - @staticmethod - def torch_to_ark_dtype_name(torch_type: torch.dtype) -> str: - """ - Converts a PyTorch data type to its corresponding ARK data type string representation. - - Parameters: - torch_type (torch.dtype): The PyTorch data type to convert. - - Returns: - str: The corresponding ARK data type as a string in uppercase. - - Raises: - ValueError: If there is no defined conversion from the given PyTorch data type to an ARK data type. 
- """ - for type_name, reg in _REGISTRY_DATA_TYPE.items(): - if reg["torch"] == torch_type: - return type_name.upper() - raise ValueError( - f"Undefined conversion from torch data type {torch_type}" - f" to ark data type." - ) - @staticmethod def from_name(type_name: str) -> "DataType": """ diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 03d186e2d..b4a31ab90 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -191,9 +191,9 @@ def get_ark_view(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": raise ValueError("Torch tensor must be contiguous.") elif tensor.device.type == "cpu": raise ValueError("Torch tensor must be on a device.") - ark_dtype = DataType.torch_to_ark_dtype_name(tensor.dtype) + ark_dtype = DataType.from_torch(tensor.dtype) dl_capsule = torch.utils.dlpack.to_dlpack(tensor) - ark_tensor = _Tensor(dl_capsule, ark_dtype) + ark_tensor = _Tensor(dl_capsule, ark_dtype.ctype()) return Tensor(ark_tensor, runtime_id=runtime_id) def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index ceaf9269b..16eb03421 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -21,7 +21,7 @@ struct DLTensorMetadata { uint64_t byte_offset; }; -DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { +static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { DLTensorMetadata metadata; metadata.data_ptr = dl_tensor->dl_tensor.data; metadata.device_id = dl_tensor->dl_tensor.device.device_id; @@ -42,7 +42,7 @@ DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { void register_tensor(py::module& m) { py::class_(m, "_Tensor") - .def(py::init([](py::capsule capsule, const std::string& ark_type_str) { + .def(py::init([](py::capsule capsule, const ark::DataType& dtype) { DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; if (!dl_tensor) { throw std::runtime_error( @@ -51,10 +51,9 @@ void register_tensor(py::module& m) { DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); int32_t device_id = metadata.device_id; void* data_ptr = metadata.data_ptr; - int8_t dtype_bytes = metadata.dtype.bits / 8; auto shape = metadata.shape; - return new ark::Tensor(data_ptr, device_id, dtype_bytes, shape, ark_type_str); + return new ark::Tensor(data_ptr, device_id, shape, dtype); })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape, py::return_value_policy::reference) diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 07c51b194..1bbc6171d 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -105,7 +105,7 @@ def test_ark_to_torch_aliasing(dtype: ark.DataType): runtime.stop() runtime.reset() -@pytest.mark.skip(reason="Not implemented") + def test_conversion_torch(): if _no_torch: pytest.skip("PyTorch not available") From 9a2f92b3817f02f8bb0a4d91b1c38f3479308ebe Mon Sep 17 00:00:00 2001 From: noli Date: Sun, 30 Jun 2024 06:59:12 +0000 Subject: [PATCH 3/6] minor fixes --- ark/api/executor.cpp | 16 +++++----- ark/codegen.cpp | 4 +-- ark/model/model_buffer.cpp | 6 ++-- ark/model/model_buffer_manager.hpp | 50 ++++++++++++------------------ python/ark/tensor.py | 16 +--------- python/unittest/test_conversion.py | 10 +++--- 6 files changed, 38 insertions(+), 64 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index c145241ba..6c5d42929 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -231,9 +231,9 @@ Executor::Impl::Impl(int rank, int 
world_size, int gpu_id, std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", "; } - ModelBufferManager &buffer_manager = ModelBufferManager::getInstance(); + ModelBufferManager &buffer_manager = ModelBufferManager::get_instance(); - if (!buffer_manager.isEmpty()) { + if (!buffer_manager.is_empty()) { codegen_ = std::make_shared( plan_json, buffer_id_to_offset_, name, &buffer_manager); } else { @@ -643,12 +643,12 @@ void Executor::Impl::launch(int64_t max_spin_count) { gpuMemcpyHostToDevice, copy_stream_->get())); // Handle external buffers - ModelBufferManager &buffer_manager = ModelBufferManager::getInstance(); - if (!buffer_manager.isEmpty()) { + ModelBufferManager &buffer_manager = ModelBufferManager::get_instance(); + if (!buffer_manager.is_empty()) { void *ext_buf_addr = get_global_rt("ARK_EXTERNAL_BUFFERS"); - std::vector ext_buffers(buffer_manager.getCompactIdSize()); - for (const auto &[id, buffer_info] : buffer_manager.getBuffers()) { - size_t compactId = buffer_manager.getCompactId(id); + std::vector ext_buffers(buffer_manager.get_compact_id_size()); + for (const auto &[id, buffer_info] : buffer_manager.get_buffers()) { + size_t compactId = buffer_manager.get_compact_id(id); void *buffer_address = std::get<0>(buffer_info); ext_buffers[compactId] = buffer_address; } @@ -892,7 +892,7 @@ float Executor::stop(int64_t max_spin_count) { void Executor::barrier() { impl_->barrier(); } void Executor::destroy() { - ModelBufferManager::getInstance().clearBuffers(); + ModelBufferManager::get_instance().clear_buffers(); impl_.reset(nullptr); } diff --git a/ark/codegen.cpp b/ark/codegen.cpp index d86395d0c..7484997d2 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -94,7 +94,7 @@ CodeGenerator::Impl::Impl(const PlanJson &plan, std::stringstream definitions_ss; if (buffer_manager_) { definitions_ss << "__device__ void* ARK_EXTERNAL_BUFFERS[" - << buffer_manager_->getCompactIdSize() << "];\n"; + << buffer_manager_->get_compact_id_size() << "];\n"; } for (auto &task_json : plan.at("TaskInfos")) { @@ -237,7 +237,7 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { auto tns = arg.value(); if (tns->buffer()->is_external()) { size_t compactId = - buffer_manager_->getCompactId(tns->buffer()->id()); + buffer_manager_->get_compact_id(tns->buffer()->id()); ss << "(" << tns->data_type()->type_str() << "*)ARK_EXTERNAL_BUFFERS[" << compactId << "]"; } else { diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index a2ca423a8..ce8f37727 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -65,8 +65,8 @@ Json ModelBuffer::serialize() const { j["RecvTags"] = recv_tags; j["IsExternal"] = is_external_; if (is_external_) { - ModelBufferManager::getInstance().registerBuffer(id_, external_data_, - external_data_size_); + ModelBufferManager::get_instance().register_buffer(id_, external_data_, + external_data_size_); j["ExternalDataSize"] = external_data_size_; j["DeviceId"] = device_id_; } @@ -100,7 +100,7 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { "ModelBuffer deserialization failed: missing DeviceId"); } void *data_ptr = - ModelBufferManager::getInstance().getBuffer(serialized["Id"]); + ModelBufferManager::get_instance().get_buffer(serialized["Id"]); if (!data_ptr) { ERR(InvalidUsageError, "ModelBuffer deserialization failed: external buffer not found " diff --git a/ark/model/model_buffer_manager.hpp b/ark/model/model_buffer_manager.hpp index 3768f2340..f359e57b1 100644 --- a/ark/model/model_buffer_manager.hpp 
+++ b/ark/model/model_buffer_manager.hpp @@ -8,34 +8,22 @@ #include namespace ark { -/** - * @brief Manages externally allocated buffers not in the ARK memory space. - * - * Details: - * - `buffers_`: Maps external buffer IDs to their pointers and sizes. - * - `externalIdMap_`: Maps external buffer IDs to their corresponding compact - * IDs. During the code generation phase, an array of buffer addresses - * (ARK_EXTERNAL_BUFFERS) is preallocated to hold addresses of external - * buffers. Accessing an external buffer utilizes its compact ID to index into - * this array, ensuring that each index *i* in ARK_EXTERNAL_BUFFERS corresponds - * to the buffer with compact ID *i*, facilitating mixed allocation patterns - * (e.g., internal, internal, external, internal). - */ +// Manages externally allocated buffers not in the ARK memory space. class ModelBufferManager { public: - static ModelBufferManager& getInstance() { + static ModelBufferManager& get_instance() { static ModelBufferManager instance; return instance; } - void registerBuffer(size_t id, void* data, size_t size) { + void register_buffer(size_t id, void* data, size_t size) { buffers_[id] = std::make_tuple(data, size); - if (externalIdMap_.find(id) == externalIdMap_.end()) { - externalIdMap_[id] = nextCompactId_++; + if (external_id_map_.find(id) == external_id_map_.end()) { + external_id_map_[id] = next_compact_id_++; } } - void* getBuffer(size_t id) { + void* get_buffer(size_t id) { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<0>(it->second); @@ -43,7 +31,7 @@ class ModelBufferManager { return nullptr; } - size_t getBufferSize(size_t id) { + size_t get_buffer_size(size_t id) { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<1>(it->second); @@ -51,36 +39,36 @@ class ModelBufferManager { return 0; } - size_t getCompactIdSize() { return nextCompactId_; } + size_t get_compact_id_size() { return next_compact_id_; } - size_t getCompactId(size_t id) { - auto it = externalIdMap_.find(id); - if (it != externalIdMap_.end()) { + size_t get_compact_id(size_t id) { + auto it = external_id_map_.find(id); + if (it != external_id_map_.end()) { return it->second; } return 0; } - const std::unordered_map>& getBuffers() + const std::unordered_map>& get_buffers() const { return buffers_; } - void clearBuffers() { + void clear_buffers() { buffers_.clear(); - externalIdMap_.clear(); - nextCompactId_ = 0; + external_id_map_.clear(); + next_compact_id_ = 0; } - bool isEmpty() const { return buffers_.empty(); } + bool is_empty() const { return buffers_.empty(); } private: std::unordered_map> buffers_; // Maps buffer IDs to pointers and sizes. std::unordered_map - externalIdMap_; // Maps original buffer IDs to compact IDs for external - // buffers. - size_t nextCompactId_ = 0; + external_id_map_; // Maps original buffer IDs to compact IDs for + // external buffers. 
+ size_t next_compact_id_ = 0; ModelBufferManager() {} ModelBufferManager(const ModelBufferManager&) = delete; ModelBufferManager& operator=(const ModelBufferManager&) = delete; diff --git a/python/ark/tensor.py b/python/ark/tensor.py index b4a31ab90..42949cfb6 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -167,21 +167,7 @@ def from_numpy(self, ndarray: np.ndarray) -> "Tensor": return self @staticmethod - def from_torch(tensor: torch.Tensor): - return Tensor( - Model.get_model().tensor( - Dims(list(tensor.shape)), - DataType.from_torch(tensor.dtype).ctype(), - Dims(), - Dims(), - Dims(), - "", - ), - lambda: tensor, - ) - - @staticmethod - def get_ark_view(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": + def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": """ Returns an ARK tensor that shares the same memory with the torch tensor. """ diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 1bbc6171d..833b88662 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -149,8 +149,8 @@ def test_bin_op(dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims): input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") expected_output = torch_op(input_tensor, other_tensor).cpu().numpy() - input_ark_view = ark.Tensor.get_ark_view(input_tensor) - other_ark_view = ark.Tensor.get_ark_view(other_tensor) + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) output = ark_op(input_ark_view, other_ark_view) runtime = ark.Runtime() runtime.launch() @@ -170,7 +170,7 @@ def test_unary_op(dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims): ark.init() input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") expected_output = torch_op(input_tensor).cpu().numpy() - input_ark_view = ark.Tensor.get_ark_view(input_tensor) + input_ark_view = ark.Tensor.from_torch(input_tensor) output = ark_op(input_ark_view) runtime = ark.Runtime() runtime.launch() @@ -189,8 +189,8 @@ def test_torch_to_ark_aliasing(dtype, tensor_dims): input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") - input_ark_view = ark.Tensor.get_ark_view(input_tensor) - other_ark_view = ark.Tensor.get_ark_view(other_tensor) + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) output = ark.add(input_ark_view, other_ark_view) # Perform in place operations From 4d654ac90eca4334769212188a4316964e52ba0b Mon Sep 17 00:00:00 2001 From: noli Date: Sun, 30 Jun 2024 07:15:27 +0000 Subject: [PATCH 4/6] remove copying of external buffer array to device --- ark/api/executor.cpp | 15 --------------- ark/codegen.cpp | 13 +++++-------- ark/model/model_buffer_manager.hpp | 22 +--------------------- 3 files changed, 6 insertions(+), 44 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 6c5d42929..0b3a6e0ed 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -642,21 +642,6 @@ void Executor::Impl::launch(int64_t max_spin_count) { GLOG(gpuMemcpyAsync(buf_ptr_addr, &buf_ptr_val, sizeof(gpuDeviceptr), gpuMemcpyHostToDevice, copy_stream_->get())); - // Handle external buffers - ModelBufferManager &buffer_manager = ModelBufferManager::get_instance(); - if (!buffer_manager.is_empty()) { - void *ext_buf_addr = 
get_global_rt("ARK_EXTERNAL_BUFFERS"); - std::vector ext_buffers(buffer_manager.get_compact_id_size()); - for (const auto &[id, buffer_info] : buffer_manager.get_buffers()) { - size_t compactId = buffer_manager.get_compact_id(id); - void *buffer_address = std::get<0>(buffer_info); - ext_buffers[compactId] = buffer_address; - } - GLOG(gpuMemcpyAsync(ext_buf_addr, ext_buffers.data(), - ext_buffers.size() * sizeof(void *), - gpuMemcpyHostToDevice, copy_stream_->get())); - } - if (world_size_ > 1) { void *proxy_chan_addr = get_global_rt("ARK_PROXY_CHANS"); void *proxy_secondary_chan_addr = diff --git a/ark/codegen.cpp b/ark/codegen.cpp index 7484997d2..b1764640c 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -92,10 +92,6 @@ CodeGenerator::Impl::Impl(const PlanJson &plan, num_warps_per_proc_ = plan.at("NumWarpsPerProcessor"); std::stringstream definitions_ss; - if (buffer_manager_) { - definitions_ss << "__device__ void* ARK_EXTERNAL_BUFFERS[" - << buffer_manager_->get_compact_id_size() << "];\n"; - } for (auto &task_json : plan.at("TaskInfos")) { definitions_ss << this->def_task(task_json); @@ -236,10 +232,11 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { if (arg.type_name() == "TENSOR") { auto tns = arg.value(); if (tns->buffer()->is_external()) { - size_t compactId = - buffer_manager_->get_compact_id(tns->buffer()->id()); - ss << "(" << tns->data_type()->type_str() - << "*)ARK_EXTERNAL_BUFFERS[" << compactId << "]"; + void *buf_addr = + ModelBufferManager::get_instance().get_buffer( + tns->buffer()->id()); + ss << "(" << tns->data_type()->type_str() << "*)" + << buf_addr; } else { size_t buffer_offset = buffer_id_to_offset_.at(tns->buffer()->id()); diff --git a/ark/model/model_buffer_manager.hpp b/ark/model/model_buffer_manager.hpp index f359e57b1..7b705f4c8 100644 --- a/ark/model/model_buffer_manager.hpp +++ b/ark/model/model_buffer_manager.hpp @@ -18,9 +18,6 @@ class ModelBufferManager { void register_buffer(size_t id, void* data, size_t size) { buffers_[id] = std::make_tuple(data, size); - if (external_id_map_.find(id) == external_id_map_.end()) { - external_id_map_[id] = next_compact_id_++; - } } void* get_buffer(size_t id) { @@ -39,35 +36,18 @@ class ModelBufferManager { return 0; } - size_t get_compact_id_size() { return next_compact_id_; } - - size_t get_compact_id(size_t id) { - auto it = external_id_map_.find(id); - if (it != external_id_map_.end()) { - return it->second; - } - return 0; - } - const std::unordered_map>& get_buffers() const { return buffers_; } - void clear_buffers() { - buffers_.clear(); - external_id_map_.clear(); - next_compact_id_ = 0; - } + void clear_buffers() { buffers_.clear(); } bool is_empty() const { return buffers_.empty(); } private: std::unordered_map> buffers_; // Maps buffer IDs to pointers and sizes. - std::unordered_map - external_id_map_; // Maps original buffer IDs to compact IDs for - // external buffers. 
size_t next_compact_id_ = 0; ModelBufferManager() {} ModelBufferManager(const ModelBufferManager&) = delete; From dd8ac3d1272acbd3e4540da4c3318d39421eb88c Mon Sep 17 00:00:00 2001 From: noli Date: Wed, 3 Jul 2024 00:33:51 +0000 Subject: [PATCH 5/6] move buffer manager logic --- ark/api/executor.cpp | 2 +- ark/codegen.cpp | 2 +- ark/codegen.hpp | 2 +- ark/model_buffer_manager.hpp | 58 ++++++++++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 ark/model_buffer_manager.hpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 266893b38..deadd4582 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -25,7 +25,7 @@ #include "gpu/gpu_manager.h" #include "logging.h" #include "model/model_buffer.hpp" -#include "model/model_buffer_manager.hpp" +#include "model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" #include "utils/utils_net.hpp" diff --git a/ark/codegen.cpp b/ark/codegen.cpp index b1764640c..d4fe25cec 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -10,7 +10,7 @@ #include "file_io.h" #include "logging.h" #include "model/model_buffer.hpp" -#include "model/model_buffer_manager.hpp" +#include "model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_op.hpp" #include "model/model_tensor.hpp" diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 295aff919..a2976e644 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -8,7 +8,7 @@ #include #include -#include "model/model_buffer_manager.hpp" +#include "model_buffer_manager.hpp" #include "model/model_json.hpp" namespace ark { diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp new file mode 100644 index 000000000..7b705f4c8 --- /dev/null +++ b/ark/model_buffer_manager.hpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_ +#define ARK_MODEL_BUFFER_MANAGER_HPP_ + +#include +#include + +namespace ark { +// Manages externally allocated buffers not in the ARK memory space. +class ModelBufferManager { + public: + static ModelBufferManager& get_instance() { + static ModelBufferManager instance; + return instance; + } + + void register_buffer(size_t id, void* data, size_t size) { + buffers_[id] = std::make_tuple(data, size); + } + + void* get_buffer(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<0>(it->second); + } + return nullptr; + } + + size_t get_buffer_size(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<1>(it->second); + } + return 0; + } + + const std::unordered_map>& get_buffers() + const { + return buffers_; + } + + void clear_buffers() { buffers_.clear(); } + + bool is_empty() const { return buffers_.empty(); } + + private: + std::unordered_map> + buffers_; // Maps buffer IDs to pointers and sizes. 
+ size_t next_compact_id_ = 0; + ModelBufferManager() {} + ModelBufferManager(const ModelBufferManager&) = delete; + ModelBufferManager& operator=(const ModelBufferManager&) = delete; +}; +} // namespace ark + +#endif // ARK_MODEL_BUFFER_MANAGER_HPP_ From afef027d7e5606a6d7f1012957152f89557fd382 Mon Sep 17 00:00:00 2001 From: noli Date: Wed, 3 Jul 2024 00:36:20 +0000 Subject: [PATCH 6/6] remove old ebuffer file --- ark/model/model_buffer_manager.hpp | 58 ------------------------------ 1 file changed, 58 deletions(-) delete mode 100644 ark/model/model_buffer_manager.hpp diff --git a/ark/model/model_buffer_manager.hpp b/ark/model/model_buffer_manager.hpp deleted file mode 100644 index 7b705f4c8..000000000 --- a/ark/model/model_buffer_manager.hpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_ -#define ARK_MODEL_BUFFER_MANAGER_HPP_ - -#include -#include - -namespace ark { -// Manages externally allocated buffers not in the ARK memory space. -class ModelBufferManager { - public: - static ModelBufferManager& get_instance() { - static ModelBufferManager instance; - return instance; - } - - void register_buffer(size_t id, void* data, size_t size) { - buffers_[id] = std::make_tuple(data, size); - } - - void* get_buffer(size_t id) { - auto it = buffers_.find(id); - if (it != buffers_.end()) { - return std::get<0>(it->second); - } - return nullptr; - } - - size_t get_buffer_size(size_t id) { - auto it = buffers_.find(id); - if (it != buffers_.end()) { - return std::get<1>(it->second); - } - return 0; - } - - const std::unordered_map>& get_buffers() - const { - return buffers_; - } - - void clear_buffers() { buffers_.clear(); } - - bool is_empty() const { return buffers_.empty(); } - - private: - std::unordered_map> - buffers_; // Maps buffer IDs to pointers and sizes. - size_t next_compact_id_ = 0; - ModelBufferManager() {} - ModelBufferManager(const ModelBufferManager&) = delete; - ModelBufferManager& operator=(const ModelBufferManager&) = delete; -}; -} // namespace ark - -#endif // ARK_MODEL_BUFFER_MANAGER_HPP_
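
A minimal end-to-end usage sketch of the torch-to-ARK view API as it stands at the end of this series: ark.Tensor.from_torch wraps a contiguous GPU-resident torch tensor as an external buffer without copying, so ARK ops read the PyTorch memory directly. The device string "cuda:0" and the tensor sizes below are illustrative assumptions; the calls mirror the unit tests in python/unittest/test_conversion.py.

    import numpy as np
    import torch
    import ark

    ark.init()

    # Contiguous GPU tensors allocated and owned by PyTorch.
    a = torch.randn((64, 64), dtype=torch.float16, device="cuda:0")
    b = torch.randn((64, 64), dtype=torch.float16, device="cuda:0")

    # Zero-copy ARK views over the PyTorch memory (registered as external buffers).
    a_view = ark.Tensor.from_torch(a)
    b_view = ark.Tensor.from_torch(b)

    # Build an ARK graph that reads directly from the torch-owned buffers.
    c = ark.add(a_view, b_view)

    runtime = ark.Runtime()
    runtime.launch()
    runtime.run()
    out = c.to_numpy()  # c is ARK-managed (not external), so reading it back is allowed
    runtime.stop()
    runtime.reset()

    assert np.allclose(out, (a + b).cpu().numpy())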