-
Notifications
You must be signed in to change notification settings - Fork 79
Symmetric memory pytorch backends #6023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
14fd212
5646c03
14816aa
6996d05
49d669c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,10 +15,51 @@ | |||||||||||||||||||||||||||||||||||||||||
| #include "multidevice/ipc_utils.h" | ||||||||||||||||||||||||||||||||||||||||||
| #include "multidevice/utils.h" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| #include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp> | ||||||||||||||||||||||||||||||||||||||||||
| #include <torch/csrc/distributed/c10d/GroupRegistry.hpp> | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| namespace nvfuser { | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| const char* kPyTorchSymmMemGroupName = "nvfuser_symm"; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| void ensurePyTorchSymmMemBackend(SymmetricMemoryBackend backend) { | ||||||||||||||||||||||||||||||||||||||||||
| static std::once_flag once; | ||||||||||||||||||||||||||||||||||||||||||
| std::call_once(once, [backend]() { | ||||||||||||||||||||||||||||||||||||||||||
| const char* name = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
| switch (backend) { | ||||||||||||||||||||||||||||||||||||||||||
| case SymmetricMemoryBackend::PyTorchNccl: | ||||||||||||||||||||||||||||||||||||||||||
| name = "NCCL"; | ||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||
| case SymmetricMemoryBackend::PyTorchNvshmem: | ||||||||||||||||||||||||||||||||||||||||||
| name = "NVSHMEM"; | ||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||
| case SymmetricMemoryBackend::PyTorchCuda: | ||||||||||||||||||||||||||||||||||||||||||
| name = "CUDA"; | ||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||
| default: | ||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR(false, "Unexpected PyTorch symmetric memory backend"); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| c10d::symmetric_memory::set_backend(name); | ||||||||||||||||||||||||||||||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(comm.is_available(), "Communicator not available for symmetric memory"); | ||||||||||||||||||||||||||||||||||||||||||
| c10d::symmetric_memory::set_group_info( | ||||||||||||||||||||||||||||||||||||||||||
| kPyTorchSymmMemGroupName, | ||||||||||||||||||||||||||||||||||||||||||
| static_cast<int>(comm.deviceId()), | ||||||||||||||||||||||||||||||||||||||||||
| static_cast<int>(comm.size()), | ||||||||||||||||||||||||||||||||||||||||||
| comm.getStore()); | ||||||||||||||||||||||||||||||||||||||||||
| // c10d::register_process_group( | ||||||||||||||||||||||||||||||||||||||||||
| // kPyTorchSymmMemGroupName, | ||||||||||||||||||||||||||||||||||||||||||
| // comm.getWorldBackendIntrusivePtr(CommunicatorBackend::kNccl)); | ||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // Returns the allocation granularity for symmetric memory. | ||||||||||||||||||||||||||||||||||||||||||
| // - query_mcast_granularity: if true, considers multicast granularity | ||||||||||||||||||||||||||||||||||||||||||
| // - query_mcast_recommended_granularity: if true, uses recommended (larger) | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -88,6 +129,39 @@ at::Tensor SymmetricTensor::allocate( | |||||||||||||||||||||||||||||||||||||||||
| at::IntArrayRef sizes, | ||||||||||||||||||||||||||||||||||||||||||
| at::ScalarType dtype, | ||||||||||||||||||||||||||||||||||||||||||
| at::Device device) { | ||||||||||||||||||||||||||||||||||||||||||
| SymmetricMemoryBackend backend = getSymmetricMemoryBackend(); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (backend != SymmetricMemoryBackend::Native) { | ||||||||||||||||||||||||||||||||||||||||||
| ensurePyTorchSymmMemBackend(backend); | ||||||||||||||||||||||||||||||||||||||||||
| std::vector<int64_t> strides(sizes.size()); | ||||||||||||||||||||||||||||||||||||||||||
| strides.back() = 1; | ||||||||||||||||||||||||||||||||||||||||||
| for (int64_t i = (int64_t)strides.size() - 2; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||
| strides[i] = strides[i + 1] * sizes[i + 1]; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| // NCCLSymmetricMemoryAllocator::alloc must not be called with a group_name. | ||||||||||||||||||||||||||||||||||||||||||
| c10::optional<std::string> alloc_group_name = | ||||||||||||||||||||||||||||||||||||||||||
| (backend == SymmetricMemoryBackend::PyTorchNccl) | ||||||||||||||||||||||||||||||||||||||||||
| ? c10::nullopt | ||||||||||||||||||||||||||||||||||||||||||
| : c10::optional<std::string>(kPyTorchSymmMemGroupName); | ||||||||||||||||||||||||||||||||||||||||||
| return c10d::symmetric_memory::empty_strided_p2p( | ||||||||||||||||||||||||||||||||||||||||||
| sizes, | ||||||||||||||||||||||||||||||||||||||||||
| strides, | ||||||||||||||||||||||||||||||||||||||||||
| dtype, | ||||||||||||||||||||||||||||||||||||||||||
| device, | ||||||||||||||||||||||||||||||||||||||||||
| alloc_group_name, | ||||||||||||||||||||||||||||||||||||||||||
| c10::nullopt); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||||||||||||||||||
| if (backend != SymmetricMemoryBackend::Native) { | ||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR( | ||||||||||||||||||||||||||||||||||||||||||
| false, | ||||||||||||||||||||||||||||||||||||||||||
| "PyTorch symmetric memory backend requires a build with " | ||||||||||||||||||||||||||||||||||||||||||
| "NVFUSER_DISTRIBUTED. Use NVFUSER_ENABLE=symmetric_memory_backend(native) " | ||||||||||||||||||||||||||||||||||||||||||
| "or do not set symmetric_memory_backend."); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| int is_vmm_supported; | ||||||||||||||||||||||||||||||||||||||||||
| NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( | ||||||||||||||||||||||||||||||||||||||||||
| &is_vmm_supported, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -212,6 +286,18 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) | |||||||||||||||||||||||||||||||||||||||||
| "Expected CUDA tensor, got: ", | ||||||||||||||||||||||||||||||||||||||||||
| local_tensor.device()); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| SymmetricMemoryBackend backend = getSymmetricMemoryBackend(); | ||||||||||||||||||||||||||||||||||||||||||
| if (backend != SymmetricMemoryBackend::Native) { | ||||||||||||||||||||||||||||||||||||||||||
| ensurePyTorchSymmMemBackend(backend); | ||||||||||||||||||||||||||||||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||||||||||||||||||||||||||||||
| world_size_ = comm.size(); | ||||||||||||||||||||||||||||||||||||||||||
| my_device_id_ = comm.deviceId(); | ||||||||||||||||||||||||||||||||||||||||||
| requested_size_ = local_tensor.numel() * local_tensor.element_size(); | ||||||||||||||||||||||||||||||||||||||||||
| return; // Rendezvous runs in setupRemoteHandles() | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| std::string error = SymmetricTensor::validate(local_tensor); | ||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(error.empty(), "Invalid symmetric allocation: ", error); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -253,6 +339,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor) | |||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| SymmetricTensor::~SymmetricTensor() { | ||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| return; // PyTorch backend: no native VMM cleanup | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
| #if (CUDA_VERSION >= 13000) | ||||||||||||||||||||||||||||||||||||||||||
| if (is_multicast_setup_) { | ||||||||||||||||||||||||||||||||||||||||||
| if (mc_base_ptr_) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -302,6 +393,20 @@ void SymmetricTensor::setupRemoteHandles(const std::string& tag) { | |||||||||||||||||||||||||||||||||||||||||
| if (are_remote_tensors_setup_ == true) { | ||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| // PyTorch backend: perform rendezvous here (lazy, on first setupRemoteHandles). | ||||||||||||||||||||||||||||||||||||||||||
| if (getSymmetricMemoryBackend() != SymmetricMemoryBackend::Native) { | ||||||||||||||||||||||||||||||||||||||||||
| ensurePyTorchSymmMemBackend(getSymmetricMemoryBackend()); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+398
to
+399
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| py_symm_handle_ = c10d::symmetric_memory::rendezvous( | ||||||||||||||||||||||||||||||||||||||||||
| local_tensor_, c10::optional<std::string>(kPyTorchSymmMemGroupName)); | ||||||||||||||||||||||||||||||||||||||||||
| are_remote_tensors_setup_ = true; | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_->has_multicast_support()) { | ||||||||||||||||||||||||||||||||||||||||||
| is_multicast_setup_ = true; | ||||||||||||||||||||||||||||||||||||||||||
| mc_ptr_ = py_symm_handle_->get_multicast_ptr(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||||||||||||||||||||||||||||||
| CUmemGenericAllocationHandle local_handle = alloc_handles_[my_device_id_]; | ||||||||||||||||||||||||||||||||||||||||||
| CUdeviceptr local_ptr = remote_ptrs_[my_device_id_]; | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -379,6 +484,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { | |||||||||||||||||||||||||||||||||||||||||
| return local_tensor_; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| return py_symm_handle_->get_remote_tensor( | ||||||||||||||||||||||||||||||||||||||||||
| rank, local_tensor_.sizes(), local_tensor_.scalar_type()); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup"); | ||||||||||||||||||||||||||||||||||||||||||
| return at::from_blob( | ||||||||||||||||||||||||||||||||||||||||||
| reinterpret_cast<void*>(remote_ptrs_[rank]), | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -390,6 +502,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const { | |||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| void* SymmetricTensor::multicastPtr() const { | ||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| return py_symm_handle_->has_multicast_support() | ||||||||||||||||||||||||||||||||||||||||||
| ? py_symm_handle_->get_multicast_ptr() | ||||||||||||||||||||||||||||||||||||||||||
| : nullptr; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
504
to
+511
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Any caller that does not check for Consider throwing or at least asserting instead of silently returning
Suggested change
This brings the error contract in line with the native path, where |
||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(is_multicast_setup_, "Multicast not setup"); | ||||||||||||||||||||||||||||||||||||||||||
| return mc_ptr_; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -398,7 +517,14 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) { | |||||||||||||||||||||||||||||||||||||||||
| if (is_contiguous_view_setup_) { | ||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR( | ||||||||||||||||||||||||||||||||||||||||||
| false, | ||||||||||||||||||||||||||||||||||||||||||
| "Contiguous view is not yet supported for PyTorch symmetric memory backend. " | ||||||||||||||||||||||||||||||||||||||||||
| "Use native backend for SymmetricContiguousView."); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK( | ||||||||||||||||||||||||||||||||||||||||||
| are_remote_tensors_setup_ == true, | ||||||||||||||||||||||||||||||||||||||||||
| "Remote tensors must be setup before setupContiguousView"); | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -462,13 +588,25 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) { | |||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| at::Tensor SymmetricTensor::getContiguousView() const { | ||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR( | ||||||||||||||||||||||||||||||||||||||||||
| false, | ||||||||||||||||||||||||||||||||||||||||||
| "Contiguous view is not yet supported for PyTorch symmetric memory backend."); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(is_contiguous_view_setup_, "Contiguous view not setup"); | ||||||||||||||||||||||||||||||||||||||||||
| return contiguous_view_; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| void SymmetricTensor::setupMulticast( | ||||||||||||||||||||||||||||||||||||||||||
| int64_t exporter_rank, | ||||||||||||||||||||||||||||||||||||||||||
| const std::string& tag) { | ||||||||||||||||||||||||||||||||||||||||||
| #ifdef NVFUSER_DISTRIBUTED | ||||||||||||||||||||||||||||||||||||||||||
| if (py_symm_handle_) { | ||||||||||||||||||||||||||||||||||||||||||
| return; // PyTorch backend: multicast handled by backend if supported | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||||
| #if (CUDA_VERSION >= 13000) | ||||||||||||||||||||||||||||||||||||||||||
| if (is_multicast_setup_) { | ||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| #!/bin/bash | ||
|
|
||
| export CC=clang-20 | ||
| export CXX=clang++-20 | ||
| export LDFLAGS="-fuse-ld=mold" | ||
|
|
||
| export NVFUSER_BUILD_ENABLE_PCH | ||
|
|
||
| export UCC_HOME="/opt/hpcx/ucc" | ||
| export UCC_DIR="/opt/hpcx/ucc/lib/cmake/ucc" | ||
| export UCX_HOME="/opt/hpcx/ucx" | ||
| export UCX_DIR="/opt/hpcx/ucx/lib/cmake/ucx" | ||
|
|
||
| # export TORCH_CUDA_ARCH_LIST="9.0" | ||
|
|
||
| export NVFUSER_BUILD_WITH_UCC=1 | ||
| export NVFUSER_BUILD_INSTALL_DIR=$BUILD_DIRECTORY/nvfuser | ||
| export NVFUSER_BUILD_DIR=$BUILD_DIRECTORY | ||
|
|
||
| # Enable debug mode, leave empty for non-debug compilation | ||
| export NVFUSER_BUILD_BUILD_TYPE=Debug | ||
| export RUN_CMAKE="" | ||
|
|
||
| pip install -v -e ./python --no-build-isolation | ||
|
Comment on lines
+1
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Personal developer build script committed to repository This script contains machine-specific, hardcoded toolchain paths that are unlikely to work anywhere except the author's development machine:
This kind of personal convenience script should live outside version control (e.g., in a |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Undefined behavior when
sizesis empty (0-dim tensor)std::vector::back()on an empty vector is undefined behaviour. The same guard-free pattern also exists in the native path further down in the same function (~line 225). While allocating a 0-dimensional symmetric tensor is unusual, the PyTorch path that was just added adds a new callsite where callers may pass{}as sizes. A simple check is sufficient:NVF_CHECK(!sizes.empty(), "Cannot allocate a 0-dim symmetric tensor");or initialise strides defensively (matching the standard row-major convention for 0-dim tensors, which is an empty strides vector) and skip the loop entirely when
sizesis empty.