From 358bf8b7bbc68054e5cba62f9ac2d3b87196dee2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 28 May 2024 16:54:48 -0400 Subject: [PATCH] [Runtime][ROCm] Enable ROCm host memory support This PR enables the ROCMHost memory support in ROCm device API. --- src/runtime/ndarray.cc | 3 ++- src/runtime/rocm/rocm_device_api.cc | 40 +++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c2efa79c0c83..c2cf5f388a21 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -316,7 +316,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU || to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost || - to->device.device_type == kDLCUDAHost) + to->device.device_type == kDLCUDAHost || from->device.device_type == kDLROCMHost || + to->device.device_type == kDLROCMHost) << "Can not copy across different device types directly. From device type: " << from->device.device_type << " to device type: " << to->device.device_type; diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index f3cc46f92723..e2a5048ca030 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI { *rv = value; } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - ROCM_CALL(hipSetDevice(dev.device_id)); ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; void* ret; - ROCM_CALL(hipMalloc(&ret, nbytes)); + if (dev.device_type == kDLROCMHost) { + VLOG(1) << "allocating " << nbytes << "bytes on host"; + ROCM_CALL(hipHostMalloc(&ret, nbytes)); + } else { + ROCM_CALL(hipSetDevice(dev.device_id)); + VLOG(1) << "allocating " << nbytes << " bytes on device"; + ROCM_CALL(hipMalloc(&ret, nbytes)); + } return ret; } void FreeDataSpace(Device dev, void* ptr) final { - ROCM_CALL(hipSetDevice(dev.device_id)); - ROCM_CALL(hipFree(ptr)); + if (dev.device_type == kDLROCMHost) { + ROCM_CALL(hipHostFree(ptr)); + } else { + ROCM_CALL(hipSetDevice(dev.device_id)); + ROCM_CALL(hipFree(ptr)); + } } void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, @@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; to = static_cast(to) + to_offset; + + if (dev_from.device_type == kDLROCMHost) { + dev_from.device_type = kDLCPU; + } + + if (dev_to.device_type == kDLROCMHost) { + dev_to.device_type = kDLCPU; + } + + // In case there is a copy from host mem to host mem */ + if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) { + memcpy(to, from, size); + return; + } + if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(dev_from.device_id)); if (dev_from.device_id == dev_to.device_id) { @@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI { private: static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, hipStream_t stream) { - if (stream != 0) { + if (stream != nullptr) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { ROCM_CALL(hipMemcpy(to, from, size, kind)); @@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv *rv = static_cast(ptr); }); +TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast(ptr); +}); + class ROCMTimerNode : public TimerNode { public: virtual void Start() {