diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp
index 4af9df7c0..0a780bcc0 100644
--- a/ark/api/executor.cpp
+++ b/ark/api/executor.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 
 #include "ark/data_type.hpp"
 #include "ark/model.hpp"
@@ -24,6 +25,7 @@
 #include "gpu/gpu_manager.h"
 #include "logging.h"
 #include "model/model_buffer.hpp"
+#include "model_buffer_manager.hpp"
 #include "model/model_data_type.hpp"
 #include "model/model_tensor.hpp"
 #include "utils/utils_net.hpp"
@@ -234,8 +236,15 @@ void Executor::Impl::init(const std::string &plan) {
                 std::to_string(kv.first) + ": " +
                 std::to_string(kv.second) + ", ";
     }
-    codegen_ =
-        std::make_shared<CodeGenerator>(plan_json_, buffer_id_to_offset_, name_);
+    ModelBufferManager &buffer_manager = ModelBufferManager::get_instance();
+
+    if (!buffer_manager.is_empty()) {
+        codegen_ = std::make_shared<CodeGenerator>(
+            plan_json_, buffer_id_to_offset_, name_, &buffer_manager);
+    } else {
+        codegen_ = std::make_shared<CodeGenerator>(plan_json_,
+                                                   buffer_id_to_offset_, name_);
+    }
 
     auto gpu_manager = GpuManager::get_instance(gpu_id_);
     timer_begin_ = gpu_manager->create_event();
@@ -367,7 +376,16 @@ std::map<size_t, size_t> Executor::Impl::init_buffers(const Json &plan_json) {
             }
             continue;
         }
-        buffer_id_to_offset[buf_info->buffer->id()] = offset;
+        if (buf_info->buffer->is_external()) {
+            if (buf_info->buffer->device_id() != gpu_id_) {
+                ERR(InvalidUsageError,
+                    "PyTorch tensor and model execution are on different GPUs");
+            }
+            continue;
+        } else {
+            buffer_id_to_offset[buf_info->buffer->id()] = offset;
+            offset += buf_info->bytes;
+        }
         for (const auto &tag_info : buf_info->buffer->send_tags()) {
             remote_rank_to_send_tags_and_offsets[tag_info.first]
                 .first.push_back(tag_info.second);
@@ -380,7 +398,6 @@ std::map<size_t, size_t> Executor::Impl::init_buffers(const Json &plan_json) {
             remote_rank_to_recv_tags_and_offsets[tag_info.first]
                 .second.push_back(offset);
         }
-        offset += buf_info->bytes;
     }
 
     total_bytes_ = offset;
@@ -456,7 +473,11 @@ std::map<size_t, size_t> Executor::Impl::init_buffers(const Json &plan_json) {
         bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1);
         bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2);
         for (int i = 0; i < len; ++i) {
-            buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i];
+            if (!buffer_id_to_info[send_tag_to_buffer_id[tags[i]]]
+                     ->buffer->is_external()) {
+                buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] =
+                    offsets[i];
+            }
         }
     }
     for (auto &kv : remote_rank_to_recv_tag_to_buffer_id) {
@@ -472,10 +493,13 @@ std::map<size_t, size_t> Executor::Impl::init_buffers(const Json &plan_json) {
         bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4);
         bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5);
         for (int i = 0; i < len; ++i) {
-            buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i];
+            if (!buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]]
+                     ->buffer->is_external()) {
+                buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] =
+                    offsets[i];
+            }
         }
     }
-
     return buffer_id_to_offset;
 }
@@ -742,6 +766,11 @@ uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const {
 void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes,
                                  bool is_d2d) const {
     GLOG(gpuSetDevice(gpu_id_));
+    if (tensor.ref()->buffer()->is_external()) {
+        ERR(InvalidUsageError,
+            "Reading data from a tensor preallocated by PyTorch is not "
+            "supported. Use PyTorch's native methods.");
+    }
     size_t tensor_data_bytes =
         tensor.shape().nelems() * tensor.data_type().bytes();
     if (bytes != tensor_data_bytes) {
@@ -779,6 +808,11 @@ void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes,
 void Executor::Impl::tensor_write(const Tensor tensor, const void *data,
                                   size_t bytes, bool is_d2d) const {
     GLOG(gpuSetDevice(gpu_id_));
+    if (tensor.ref()->buffer()->is_external()) {
+        ERR(InvalidUsageError,
+            "Writing data to a tensor preallocated by PyTorch is not "
+            "supported. Use PyTorch's native methods.");
+    }
     size_t tensor_data_bytes =
         tensor.shape().nelems() * tensor.data_type().bytes();
     if (bytes != tensor_data_bytes) {
@@ -843,7 +877,10 @@ float Executor::stop(int64_t max_spin_count) {
 
 void Executor::barrier() { impl_->barrier(); }
 
-void Executor::destroy() { impl_.reset(nullptr); }
+void Executor::destroy() {
+    ModelBufferManager::get_instance().clear_buffers();
+    impl_.reset(nullptr);
+}
 
 bool Executor::destroyed() const { return impl_.get() == nullptr; }
diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp
index 4b03c3ac8..4d33bd9f1 100644
--- a/ark/api/tensor.cpp
+++ b/ark/api/tensor.cpp
@@ -3,11 +3,28 @@
 
 #include "ark/tensor.hpp"
 
+#include <functional>
+#include <numeric>
+
+#include "model/model_buffer.hpp"
 #include "model/model_data_type.hpp"
 #include "model/model_tensor.hpp"
 
 namespace ark {
 
+Tensor::Tensor(void* data_ptr, int32_t device_id,
+               const std::vector<int64_t>& shape, const DataType& dtype) {
+    size_t external_data_size =
+        std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
+                        std::multiplies<int64_t>()) *
+        dtype.bytes();
+    auto buffer =
+        std::make_shared<ModelBuffer>(data_ptr, external_data_size, device_id);
+    auto tensor = std::make_shared<ModelTensor>(
+        dtype.ref(), buffer, Dims(shape), Dims(shape), Dims(), Dims());
+    ref_ = tensor;
+}
+
 size_t Tensor::id() const {
     if (ref_) {
         return ref_->id();
@@ -43,14 +60,14 @@ Dims Tensor::padded_shape() const {
     return Dims();
 }
 
-const DataType &Tensor::data_type() const {
+const DataType& Tensor::data_type() const {
     if (ref_) {
         return DataType::from_name(ref_->data_type()->type_name());
     }
     return NONE;
 }
 
-std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
+std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
     if (tensor.is_null()) {
         os << "null";
     } else {
diff --git a/ark/codegen.cpp b/ark/codegen.cpp
index 09ff28dd3..a97e5e45b 100644
--- a/ark/codegen.cpp
+++ b/ark/codegen.cpp
@@ -10,6 +10,7 @@
 #include "file_io.h"
 #include "logging.h"
 #include "model/model_buffer.hpp"
+#include "model_buffer_manager.hpp"
 #include "model/model_data_type.hpp"
 #include "model/model_op.hpp"
 #include "model/model_tensor.hpp"
@@ -43,7 +44,7 @@ class CodeGenerator::Impl {
    public:
     Impl(const PlanJson &plan,
          const std::map<size_t, size_t> &buffer_id_to_offset,
-         const std::string &name);
+         const std::string &name, ModelBufferManager *buffer_manager);
     ~Impl() = default;
 
    private:
@@ -64,6 +65,8 @@ class CodeGenerator::Impl {
 
     std::string sync_process_range(const Range<size_t> &ranges, int state_id);
 
+    ModelBufferManager *buffer_manager_;
+
    protected:
     friend class CodeGenerator;
@@ -78,14 +81,18 @@ CodeGenerator::Impl::Impl(const PlanJson &plan,
                          const std::map<size_t, size_t> &buffer_id_to_offset,
-                          const std::string &name)
-    : buffer_id_to_offset_(buffer_id_to_offset), name_(name) {
+                          const std::string &name,
+                          ModelBufferManager *buffer_manager)
+    : buffer_id_to_offset_(buffer_id_to_offset),
+      name_(name),
+      buffer_manager_(buffer_manager) {
     rank_ = plan.at("Rank");
     world_size_ = plan.at("WorldSize");
     num_procs_ = plan.at("NumProcessors");
     num_warps_per_proc_ = plan.at("NumWarpsPerProcessor");
 
     std::stringstream definitions_ss;
+
     for (auto &task_json : plan.at("TaskInfos")) {
         definitions_ss << this->def_task(task_json);
     }
@@ -224,11 +231,19 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
             auto &arg = impl_args[i];
             if (arg.type_name() == "TENSOR") {
                 auto tns = arg.value<ModelTensorRef>();
-                size_t buffer_offset =
-                    buffer_id_to_offset_.at(tns->buffer()->id());
-                size_t offset = buffer_offset + ModelOffset(tns).value();
-                ss << "(" << tns->data_type()->type_str() << "*)&_buf["
-                   << offset << "]";
+                if (tns->buffer()->is_external()) {
+                    void *buf_addr =
+                        ModelBufferManager::get_instance().get_buffer(
+                            tns->buffer()->id());
+                    ss << "(" << tns->data_type()->type_str() << "*)"
+                       << buf_addr;
+                } else {
+                    size_t buffer_offset =
+                        buffer_id_to_offset_.at(tns->buffer()->id());
+                    size_t offset = buffer_offset + ModelOffset(tns).value();
+                    ss << "(" << tns->data_type()->type_str() << "*)&_buf["
+                       << offset << "]";
+                }
             } else if (arg.type_name() == "OFFSET") {
                 auto moff = arg.value<ModelOffset>();
                 size_t buffer_offset =
@@ -431,8 +446,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range<size_t> &range,
 
 CodeGenerator::CodeGenerator(
     const PlanJson &plan,
     const std::map<size_t, size_t> &buffer_id_to_offset,
-    const std::string &name)
-    : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset, name)) {}
+    const std::string &name, ModelBufferManager *buffer_manager)
+    : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset, name,
+                                   buffer_manager)) {}
 
 std::string CodeGenerator::code() const { return impl_->code_; }
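For illustration, a minimal standalone sketch of the pointer-embedding behavior in `def_task` above: streaming a `void *` into the generated kernel source produces a hexadecimal address literal on typical platforms, so the external tensor's device pointer is baked into the compiled code at generation time (the address below is hypothetical).

```cpp
#include <iostream>
#include <sstream>

int main() {
    // Mirrors the external-buffer branch of CodeGenerator::Impl::def_task:
    // the raw device pointer is emitted as a literal cast expression.
    void *buf_addr = reinterpret_cast<void *>(0x7f8a40000000);
    std::stringstream ss;
    ss << "(" << "fp16" << "*)" << buf_addr;
    std::cout << ss.str() << std::endl;  // e.g. "(fp16*)0x7f8a40000000"
    return 0;
}
```

One consequence of this design is that the generated code is tied to the lifetime of the external allocation: if the PyTorch tensor is freed or reallocated after code generation, the baked-in address dangles.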
plan.at("NumWarpsPerProcessor"); std::stringstream definitions_ss; + for (auto &task_json : plan.at("TaskInfos")) { definitions_ss << this->def_task(task_json); } @@ -224,11 +231,19 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { auto &arg = impl_args[i]; if (arg.type_name() == "TENSOR") { auto tns = arg.value(); - size_t buffer_offset = - buffer_id_to_offset_.at(tns->buffer()->id()); - size_t offset = buffer_offset + ModelOffset(tns).value(); - ss << "(" << tns->data_type()->type_str() << "*)&_buf[" - << offset << "]"; + if (tns->buffer()->is_external()) { + void *buf_addr = + ModelBufferManager::get_instance().get_buffer( + tns->buffer()->id()); + ss << "(" << tns->data_type()->type_str() << "*)" + << buf_addr; + } else { + size_t buffer_offset = + buffer_id_to_offset_.at(tns->buffer()->id()); + size_t offset = buffer_offset + ModelOffset(tns).value(); + ss << "(" << tns->data_type()->type_str() << "*)&_buf[" + << offset << "]"; + } } else if (arg.type_name() == "OFFSET") { auto moff = arg.value(); size_t buffer_offset = @@ -431,8 +446,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range &range, CodeGenerator::CodeGenerator( const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : impl_(std::make_shared(plan, buffer_id_to_offset, name)) {} + const std::string &name, ModelBufferManager *buffer_manager) + : impl_(std::make_shared(plan, buffer_id_to_offset, name, + buffer_manager)) {} std::string CodeGenerator::code() const { return impl_->code_; } diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 4f8307e7e..a2976e644 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -8,6 +8,7 @@ #include #include +#include "model_buffer_manager.hpp" #include "model/model_json.hpp" namespace ark { @@ -16,7 +17,8 @@ class CodeGenerator { public: CodeGenerator(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name = "ark_kernel"); + const std::string &name = "ark_kernel", + ModelBufferManager *buffer_manager = nullptr); ~CodeGenerator() = default; diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 747ce5fea..d13748175 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,6 +31,8 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; + Tensor(void *data_ptr, int32_t device_id, const std::vector &shape, + const DataType &dtype); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index 4ce91b5e4..ce8f37727 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -4,13 +4,13 @@ #include "model_buffer.hpp" #include "logging.h" +#include "model_buffer_manager.hpp" namespace ark { -ModelBuffer::ModelBuffer(int rank) : rank_(rank) { - static size_t id = 0; - id_ = id++; -} +size_t ModelBuffer::curr_id = 0; + +ModelBuffer::ModelBuffer(int rank) : rank_(rank) { id_ = curr_id++; } ModelBuffer::ModelBuffer(size_t id, int rank, const std::vector &send_tags, @@ -24,6 +24,23 @@ ModelBuffer::ModelBuffer(size_t id, int rank, } } +ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id) + : rank_(-1), + external_data_(data), + external_data_size_(size), + device_id_(device_id), + is_external_(true) { + id_ = curr_id++; +} + +ModelBuffer::ModelBuffer(size_t id, void *data, size_t size, 
diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp
index 4ce91b5e4..ce8f37727 100644
--- a/ark/model/model_buffer.cpp
+++ b/ark/model/model_buffer.cpp
@@ -4,13 +4,13 @@
 #include "model_buffer.hpp"
 
 #include "logging.h"
+#include "model_buffer_manager.hpp"
 
 namespace ark {
 
-ModelBuffer::ModelBuffer(int rank) : rank_(rank) {
-    static size_t id = 0;
-    id_ = id++;
-}
+size_t ModelBuffer::curr_id = 0;
+
+ModelBuffer::ModelBuffer(int rank) : rank_(rank) { id_ = curr_id++; }
 
 ModelBuffer::ModelBuffer(size_t id, int rank,
                          const std::vector<TagInfo> &send_tags,
@@ -24,6 +24,23 @@ ModelBuffer::ModelBuffer(size_t id, int rank,
     }
 }
 
+ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id)
+    : rank_(-1),
+      external_data_(data),
+      external_data_size_(size),
+      device_id_(device_id),
+      is_external_(true) {
+    id_ = curr_id++;
+}
+
+ModelBuffer::ModelBuffer(size_t id, void *data, size_t size,
+                         int32_t device_id)
+    : id_(id),
+      rank_(-1),
+      external_data_(data),
+      external_data_size_(size),
+      device_id_(device_id),
+      is_external_(true) {}
+
 void ModelBuffer::tag_send(int remote_rank, int tag) {
     send_tags_.insert(TagInfo{remote_rank, tag});
 }
@@ -46,6 +63,14 @@ Json ModelBuffer::serialize() const {
     }
     j["SendTags"] = send_tags;
     j["RecvTags"] = recv_tags;
+    j["IsExternal"] = is_external_;
+    if (is_external_) {
+        ModelBufferManager::get_instance().register_buffer(id_, external_data_,
+                                                           external_data_size_);
+        j["ExternalDataSize"] = external_data_size_;
+        j["DeviceId"] = device_id_;
+    }
+    // external_data_ is not serialized; it is recovered from the manager.
     return j;
 }
@@ -62,6 +87,28 @@ std::shared_ptr<ModelBuffer> ModelBuffer::deserialize(const Json &serialized) {
     } else if (!serialized.contains("RecvTags")) {
         ERR(InvalidUsageError,
             "ModelBuffer deserialization failed: missing RecvTags");
+    } else if (!serialized.contains("IsExternal")) {
+        ERR(InvalidUsageError,
+            "ModelBuffer deserialization failed: missing IsExternal");
+    }
+    if (serialized["IsExternal"]) {
+        if (!serialized.contains("ExternalDataSize")) {
+            ERR(InvalidUsageError,
+                "ModelBuffer deserialization failed: missing ExternalDataSize");
+        } else if (!serialized.contains("DeviceId")) {
+            ERR(InvalidUsageError,
+                "ModelBuffer deserialization failed: missing DeviceId");
+        }
+        void *data_ptr =
+            ModelBufferManager::get_instance().get_buffer(serialized["Id"]);
+        if (!data_ptr) {
+            ERR(InvalidUsageError,
+                "ModelBuffer deserialization failed: external buffer not "
+                "found in ModelBufferManager");
+        }
+        return std::make_shared<ModelBuffer>(serialized["Id"], data_ptr,
+                                             serialized["ExternalDataSize"],
+                                             serialized["DeviceId"]);
+    }
     return std::make_shared<ModelBuffer>(serialized["Id"], serialized["Rank"],
                                          serialized["SendTags"],
diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp
index 7ad3db206..e7f1045b2 100644
--- a/ark/model/model_buffer.hpp
+++ b/ark/model/model_buffer.hpp
@@ -22,6 +22,10 @@ class ModelBuffer {
     ModelBuffer(size_t id, int rank, const std::vector<TagInfo> &send_tags,
                 const std::vector<TagInfo> &recv_tags);
 
+    // Externally managed buffer (e.g., preallocated by PyTorch).
+    ModelBuffer(void *data, size_t size, int32_t device_id);
+    ModelBuffer(size_t id, void *data, size_t size, int32_t device_id);
+
     size_t id() const { return id_; }
     int rank() const { return rank_; }
@@ -44,11 +48,22 @@ class ModelBuffer {
 
     static std::shared_ptr<ModelBuffer> deserialize(const Json &serialized);
 
+    // External buffer accessors.
+    size_t external_data_size() const { return external_data_size_; }
+    void *external_data() const { return external_data_; }
+    int32_t device_id() const { return device_id_; }
+    bool is_external() const { return is_external_; }
+
    private:
+    static size_t curr_id;
     size_t id_;
     int rank_;
     std::set<TagInfo> send_tags_;
     std::set<TagInfo> recv_tags_;
+    void *external_data_ = nullptr;
+    size_t external_data_size_ = 0;
+    int32_t device_id_ = -1;
+    bool is_external_ = false;
 };
 
}  // namespace ark
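For reference, a sketch of the JSON that `ModelBuffer::serialize()` would produce for an external buffer (field values are illustrative). The device pointer itself is never serialized; `deserialize()` recovers it from the `ModelBufferManager` by `Id`.

```json
{
    "Id": 3,
    "Rank": -1,
    "SendTags": [],
    "RecvTags": [],
    "IsExternal": true,
    "ExternalDataSize": 8192,
    "DeviceId": 0
}
```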
diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp
new file mode 100644
index 000000000..7b705f4c8
--- /dev/null
+++ b/ark/model_buffer_manager.hpp
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_
+#define ARK_MODEL_BUFFER_MANAGER_HPP_
+
+#include <cstddef>
+#include <tuple>
+#include <unordered_map>
+
+namespace ark {
+
+// Manages externally allocated buffers that live outside the ARK memory
+// space.
+class ModelBufferManager {
+   public:
+    static ModelBufferManager& get_instance() {
+        static ModelBufferManager instance;
+        return instance;
+    }
+
+    void register_buffer(size_t id, void* data, size_t size) {
+        buffers_[id] = std::make_tuple(data, size);
+    }
+
+    void* get_buffer(size_t id) {
+        auto it = buffers_.find(id);
+        if (it != buffers_.end()) {
+            return std::get<0>(it->second);
+        }
+        return nullptr;
+    }
+
+    size_t get_buffer_size(size_t id) {
+        auto it = buffers_.find(id);
+        if (it != buffers_.end()) {
+            return std::get<1>(it->second);
+        }
+        return 0;
+    }
+
+    const std::unordered_map<size_t, std::tuple<void*, size_t>>&
+    get_buffers() const {
+        return buffers_;
+    }
+
+    void clear_buffers() { buffers_.clear(); }
+
+    bool is_empty() const { return buffers_.empty(); }
+
+   private:
+    // Maps buffer IDs to (pointer, size) pairs.
+    std::unordered_map<size_t, std::tuple<void*, size_t>> buffers_;
+    size_t next_compact_id_ = 0;
+
+    ModelBufferManager() {}
+    ModelBufferManager(const ModelBufferManager&) = delete;
+    ModelBufferManager& operator=(const ModelBufferManager&) = delete;
+};
+
+}  // namespace ark
+
+#endif  // ARK_MODEL_BUFFER_MANAGER_HPP_
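A minimal usage sketch of the singleton above; the id and size are illustrative. In the normal flow, registration happens inside `ModelBuffer::serialize()` and `Executor::destroy()` calls `clear_buffers()`.

```cpp
#include "model_buffer_manager.hpp"

void buffer_manager_example(void *device_ptr) {
    auto &mgr = ark::ModelBufferManager::get_instance();
    mgr.register_buffer(/*id=*/42, device_ptr, /*size=*/4096);
    void *ptr = mgr.get_buffer(42);         // returns device_ptr
    size_t size = mgr.get_buffer_size(42);  // returns 4096
    (void)ptr;
    (void)size;
    mgr.clear_buffers();  // drops all registrations
}
```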
diff --git a/python/ark/tensor.py b/python/ark/tensor.py
index ac2886960..8f26dc96e 100644
--- a/python/ark/tensor.py
+++ b/python/ark/tensor.py
@@ -167,18 +167,20 @@ def from_numpy(self, ndarray: np.ndarray) -> "Tensor":
         return self
 
     @staticmethod
-    def from_torch(tensor: torch.Tensor):
-        return Tensor(
-            Model.get_model().tensor(
-                Dims(list(tensor.shape)),
-                DataType.from_torch(tensor.dtype).ctype(),
-                Dims(),
-                Dims(),
-                Dims(),
-                "",
-            ),
-            lambda: tensor,
-        )
+    def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor":
+        """
+        Returns an ARK tensor that shares device memory with the torch tensor; no data is copied.
+        """
+        if _no_torch:
+            raise ImportError("torch is not available")
+        elif not tensor.is_contiguous():
+            raise ValueError("Torch tensor must be contiguous.")
+        elif tensor.device.type == "cpu":
+            raise ValueError("Torch tensor must be on a GPU device.")
+        ark_dtype = DataType.from_torch(tensor.dtype)
+        dl_capsule = torch.utils.dlpack.to_dlpack(tensor)
+        ark_tensor = _Tensor(dl_capsule, ark_dtype.ctype())
+        return Tensor(ark_tensor, runtime_id=runtime_id)
 
     def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor":
         """
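An end-to-end usage sketch of the zero-copy path added above, mirroring the unit tests at the end of this patch; the ARK tensor is a view of the torch tensor's device memory.

```python
import torch
import ark

ark.init()
x = torch.randn(64, 64, dtype=torch.float16, device="cuda:0")
x_view = ark.Tensor.from_torch(x)  # shares x's device memory
y = ark.exp(x_view)

runtime = ark.Runtime()
runtime.launch()
runtime.run()
result = y.to_numpy()  # numpy copy of exp(x)
runtime.stop()
runtime.reset()
```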
diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp
index fbd909d3d..16eb03421 100644
--- a/python/tensor_py.cpp
+++ b/python/tensor_py.cpp
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT license.
 
+#include <dlpack/dlpack.h>
 #include
 #include
 #include
@@ -9,8 +10,51 @@
 
 namespace py = pybind11;
 
-void register_tensor(py::module &m) {
+struct DLTensorMetadata {
+    void* data_ptr;
+    int32_t device_id;
+    DLDeviceType device_type;
+    int32_t ndim;
+    DLDataType dtype;
+    std::vector<int64_t> shape;
+    std::vector<int64_t> strides;
+    uint64_t byte_offset;
+};
+
+static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) {
+    DLTensorMetadata metadata;
+    metadata.data_ptr = dl_tensor->dl_tensor.data;
+    metadata.device_id = dl_tensor->dl_tensor.device.device_id;
+    metadata.device_type = dl_tensor->dl_tensor.device.device_type;
+    metadata.ndim = dl_tensor->dl_tensor.ndim;
+    metadata.dtype = dl_tensor->dl_tensor.dtype;
+    metadata.shape.assign(
+        dl_tensor->dl_tensor.shape,
+        dl_tensor->dl_tensor.shape + dl_tensor->dl_tensor.ndim);
+    if (dl_tensor->dl_tensor.strides != nullptr) {
+        metadata.strides.assign(
+            dl_tensor->dl_tensor.strides,
+            dl_tensor->dl_tensor.strides + dl_tensor->dl_tensor.ndim);
+    }
+    metadata.byte_offset = dl_tensor->dl_tensor.byte_offset;
+    return metadata;
+}
+
+void register_tensor(py::module& m) {
     py::class_<ark::Tensor>(m, "_Tensor")
+        .def(py::init([](py::capsule capsule, const ark::DataType& dtype) {
+            DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule;
+            if (!dl_tensor) {
+                throw std::runtime_error(
+                    "Capsule does not contain a DLManagedTensor");
+            }
+            DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor);
+            int32_t device_id = metadata.device_id;
+            void* data_ptr = metadata.data_ptr;
+            auto shape = metadata.shape;
+
+            return new ark::Tensor(data_ptr, device_id, shape, dtype);
+        }))
         .def("id", &ark::Tensor::id)
         .def("shape", &ark::Tensor::shape, py::return_value_policy::reference)
         .def("strides", &ark::Tensor::strides,
diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py
index 5befa1c34..833b88662 100644
--- a/python/unittest/test_conversion.py
+++ b/python/unittest/test_conversion.py
@@ -1,6 +1,7 @@
 import pytest
 import numpy as np
 import ark
+from typing import Callable
 
 try:
     import torch
@@ -9,6 +10,8 @@
 except ImportError:
     _no_torch = True
 
+# ARK to Torch tests
+
 
 def initialize_tensor(dimensions, dtype):
     tensor = ark.tensor(dimensions, dtype)
@@ -69,7 +72,7 @@ def check_diff(input_tensor_host, input_view_numpy, value, index):
 
 # Test function to check if changes to the torch views are reflected in the original tensors
 @pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32])
-def test_aliasing(dtype: ark.DataType):
+def test_ark_to_torch_aliasing(dtype: ark.DataType):
     ark.init()
     dimensions = [4, 4]
     input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype)
@@ -126,3 +129,79 @@ def test_conversion_torch():
         torch_tensor = t.to_torch()
 
     assert torch.all(torch_tensor == 7)
+
+
+# Torch to ARK tests
+
+ArkBinOp = Callable[[ark.Tensor, ark.Tensor], ark.Tensor]
+TorchBinOp = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+ArkUnOp = Callable[[ark.Tensor], ark.Tensor]
+TorchUnOp = Callable[[torch.Tensor], torch.Tensor]
+
+
+# Verify the accuracy of binary operations involving ARK view tensors
+@pytest.mark.parametrize(
+    "dtype, ark_op, torch_op, tensor_dims",
+    [(torch.float16, ark.add, torch.add, (2, 3))],
+)
+def test_bin_op(dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims):
+    ark.init()
+    input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    expected_output = torch_op(input_tensor, other_tensor).cpu().numpy()
+    input_ark_view = ark.Tensor.from_torch(input_tensor)
+    other_ark_view = ark.Tensor.from_torch(other_tensor)
+    output = ark_op(input_ark_view, other_ark_view)
+    runtime = ark.Runtime()
+    runtime.launch()
+    runtime.run()
+    output_host = output.to_numpy()
+    runtime.stop()
+    runtime.reset()
+    assert np.allclose(output_host, expected_output)
+
+
+# Verify the accuracy of unary operations involving ARK view tensors
+@pytest.mark.parametrize(
+    "dtype, ark_op, torch_op, tensor_dims",
+    [(torch.float16, ark.exp, torch.exp, (3, 3))],
+)
+def test_unary_op(dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims):
+    ark.init()
+    input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    expected_output = torch_op(input_tensor).cpu().numpy()
+    input_ark_view = ark.Tensor.from_torch(input_tensor)
+    output = ark_op(input_ark_view)
+    runtime = ark.Runtime()
+    runtime.launch()
+    runtime.run()
+    output_host = output.to_numpy()
+    runtime.stop()
+    runtime.reset()
+    assert np.allclose(output_host, expected_output)
+
+
+# Test function to check if changes in torch tensors are reflected in ARK views
+@pytest.mark.parametrize("dtype, tensor_dims", [(torch.float16, (64, 64))])
+def test_torch_to_ark_aliasing(dtype, tensor_dims):
+    ark.init()
+    # Initialize PyTorch tensors on the GPU
+    input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+
+    input_ark_view = ark.Tensor.from_torch(input_tensor)
+    other_ark_view = ark.Tensor.from_torch(other_tensor)
+
+    output = ark.add(input_ark_view, other_ark_view)
+    # Mutate the torch tensors in place; the ARK views alias the same memory
+    input_tensor += other_tensor
+    other_tensor += input_tensor
+    expected_output = (input_tensor + other_tensor).cpu().numpy()
+
+    runtime = ark.Runtime()
+    runtime.launch()
+    runtime.run()
+    output_host = output.to_numpy()
+    runtime.stop()
+    runtime.reset()
+    assert np.allclose(output_host, expected_output)