From d240175d2d1def83fd3c7ecec159c9c348d215f2 Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 25 Apr 2024 12:18:34 +0530 Subject: [PATCH] Take advantage of OpenCL host ptr for improved copy --- src/runtime/relax_vm/paged_kv_cache.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 64759d465b72..efedac235bfc 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -31,6 +31,9 @@ #include #include "kv_state.h" +#if defined(OPENCL_ENABLE_HOST_PTR) +#include "../opencl/opencl_common.h" +#endif namespace tvm { namespace runtime { @@ -384,6 +387,22 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { return; } DLTensor copy_dst = *array.operator->(); +#if defined(OPENCL_ENABLE_HOST_PTR) + tvm::runtime::cl::OpenCLWorkspace* workspace = tvm::runtime::cl::OpenCLWorkspace::Global(); + if (workspace->IsOpenCLDevice(copy_dst.device)) { + void* nptr = workspace->GetNativePtr(array); + uint64_t copy_size; + if (shape.defined()) { + ICHECK_EQ(shape.value().size(), 1); + copy_size = shape.value()->data[0] * sizeof(int32_t); + } else { + copy_size = DeviceAPI::Get(array->device)->GetDataSize(*array.operator->()); + } + memcpy(static_cast(nptr) + dst_elem_offset * sizeof(int32_t), vec_data, copy_size); + return; + } +#endif + if (shape.defined()) { ICHECK_EQ(shape.value().size(), 1); copy_dst.ndim = 1;