From 9db3c804ab9e727d158ea5aa09ee7e73d429d3e5 Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Sun, 29 Sep 2024 23:08:33 +0000
Subject: [PATCH] [NVSHMEM] Enable nvshmem memory allocation

This PR adds support for NVSHMEM memory allocation and integrates it into
Disco.
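A minimal usage sketch of the new packed functions from a Disco session
(imports and session setup assumed from tests/python/disco/test_nvshmem.py;
requires a CUDA build with NVSHMEM enabled and at least two local GPUs):

    import tvm
    from tvm.runtime import ShapeTuple
    from tvm.runtime import disco as di

    num_workers = 2
    sess = di.ProcessSession(num_workers=num_workers)
    # Obtain the NVSHMEM unique id on the controller, then initialize NVSHMEM
    # on every worker through the Disco session.
    uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")()
    sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")(uid, num_workers)
    sess.sync_worker_0()
    # Allocate an NDArray on each worker from the NVSHMEM memory pool.
    empty = sess.get_global_func("runtime.disco.nvshmem.empty")
    arr = empty(ShapeTuple((32, 64)), "float32", tvm.cuda())
    # Release pooled NVSHMEM memory and finalize NVSHMEM.
    sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")()
    sess.sync_worker_0()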
---
 .../contrib/nvshmem/{nvshmem.cc => init.cc}   |   2 +
 .../contrib/nvshmem/memory_allocator.cc       | 104 ++++++++++++++++++
 tests/python/disco/test_nvshmem.py            |  45 +++++++-
 3 files changed, 145 insertions(+), 6 deletions(-)
 rename src/runtime/contrib/nvshmem/{nvshmem.cc => init.cc} (96%)
 create mode 100644 src/runtime/contrib/nvshmem/memory_allocator.cc

diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/init.cc
similarity index 96%
rename from src/runtime/contrib/nvshmem/nvshmem.cc
rename to src/runtime/contrib/nvshmem/init.cc
index 985ba5510762..50fdde4c49d8 100644
--- a/src/runtime/contrib/nvshmem/nvshmem.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   }
   nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+  int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
+  CUDA_CALL(cudaSetDevice(mype_node));
   LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
            << ", npes=" << nvshmem_n_pes();
 }
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc
new file mode 100644
index 000000000000..89d56ed3dc81
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <string>
+
+#include "../../cuda/cuda_common.h"
+#include "../../memory/pooled_allocator.h"
+
+namespace tvm {
+namespace runtime {
+
+using tvm::runtime::memory::Buffer;
+using tvm::runtime::memory::PooledAllocator;
+
+/*!
+ * \brief The memory allocator of NVSHMEM.
+ * Overriding PooledAllocator for efficient memory management.
+ */
+class NVSHMEMAllocator final : public PooledAllocator {
+ public:
+  explicit NVSHMEMAllocator() : PooledAllocator() {}
+
+  ~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); }
+
+  void Clear() final { PooledAllocator::ReleaseAll(); }
+
+  bool AllowMemoryScope(const std::string& mem_scope) const final {
+    // The allowed memory scope of NVSHMEM is "nvshmem".
+    return mem_scope == "nvshmem";
+  }
+
+  /*! \brief Return the global NVSHMEM singleton allocator. */
+  static NVSHMEMAllocator* Global() {
+    static NVSHMEMAllocator* allocator = new NVSHMEMAllocator();
+    return allocator;
+  }
+
+  NDArray Empty(ShapeTuple shape, DataType dtype, Device device) {
+    NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device);
+    container->SetDeleter([](Object* obj) {
+      auto* ptr = static_cast<NDArray::Container*>(obj);
+      ICHECK(ptr->manager_ctx != nullptr);
+      Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
+      NVSHMEMAllocator::Global()->Free(*(buffer));
+      delete buffer;
+      delete ptr;
+    });
+    Buffer* buffer = new Buffer;
+    *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem"));
+    container->manager_ctx = reinterpret_cast<void*>(buffer);
+    container->dl_tensor.data = buffer->data;
+    return NDArray(GetObjectPtr<Object>(container));
+  }
+
+ private:
+  void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
+                             DLDataType type_hint) final {
+    ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA)
+        << "nvshmem can only allocate cuda device memory space.";
+    ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt ||
+           type_hint.code == DLDataTypeCode::kDLFloat)
+        << "nvshmem can only allocate tensor with int, unsigned int or float data types.";
+    return nvshmem_align(alignment, size);
+  }
+
+  void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); }
+};
+
+NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
+  return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
+
+void NVSHMEMFinalize() {
+  NVSHMEMAllocator::Global()->Clear();
+  nvshmem_finalize();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize);
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 0b16fe93612f..b304d145aa38 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -23,6 +23,9 @@
 import subprocess
 import threading
 import sys
+from multiprocessing import Process
+from typing import Any, Callable, List
+
 import tvm
 import tvm.testing

@@ -82,8 +85,6 @@ def start_server():
         thread.join()

     def __del__(self):
-        for node in self.remote_nodes:
-            node.kill()
         if self.sess is not None:
             self.sess.shutdown()
             del self.sess
@@ -98,17 +99,49 @@ def create_socket_session(num_workers):
     return _SOCKET_SESSION_TESTER.sess


-@pytest.mark.parametrize("num_workers", [2, 4])
-def test_nvshmem_init(num_workers):
+def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
     if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
         return
-    sess = create_socket_session(num_workers=num_workers)
+
+    sess = session_kind(num_workers=num_workers)
     f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
     init_dfunc(uid, num_workers)
     sess.sync_worker_0()
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    sess.sync_worker_0()
+
+
+def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
+    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+        return
+
+    device = tvm.cuda()
+    sess = session_kind(num_workers=num_workers)
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers)
+    sess.sync_worker_0()
+    empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
+    a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
+    b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
+    sess.sync_worker_0()
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    sess.sync_worker_0()


 if __name__ == "__main__":
-    tvm.testing.main()
+    # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
+    # or `nvshmem_init_thread` in the same program results in undefined behavior.
+    # Worker 0 may share the driver process, so each test runs in a fresh process
+    # to guarantee NVSHMEM is never initialized twice in the same process.
+    for session_kind in [create_socket_session, di.ProcessSession]:
+        for num_workers in [2, 4]:
+            for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
+                p = Process(target=test_func, args=[session_kind, num_workers])
+                p.start()
+                p.join()