Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,21 @@ void Communicator::barrier(std::optional<CommunicatorBackend> backend) {
getWorld(backend)->barrier(options)->wait();
}

#ifdef NVFUSER_DISTRIBUTED
// Returns the communicator's store as an owning c10::intrusive_ptr so it can
// be handed to PyTorch c10d APIs that take shared ownership (e.g.
// c10d::symmetric_memory::set_group_info).
// NOTE(review): this assumes `store_` is itself a c10::intrusive_ptr (the
// header shows `store_.get()`, consistent with that), so the conversion below
// is a refcount-bumping upcast TCPStore -> Store. If `store_` were ever a raw
// pointer, this constructor would adopt it without incref and double-free —
// confirm against the member declaration.
c10::intrusive_ptr<c10d::Store> Communicator::getStore() const {
  return c10::intrusive_ptr<c10d::Store>(store_);
}

// Returns the world (all-ranks) backend as an owning intrusive_ptr, creating
// and caching it on first use, so it can be registered with
// c10d::register_process_group (e.g. for symmetric-memory NCCL rendezvous).
//
// backend: requested communicator backend; falls back to default_backend_
//          when not provided.
c10::intrusive_ptr<c10d::Backend> Communicator::getWorldBackendIntrusivePtr(
    std::optional<CommunicatorBackend> backend) {
  // The "world" team is every rank in [0, size_).
  std::vector<RankType> all_ranks(size_);
  std::iota(all_ranks.begin(), all_ranks.end(), 0);
  CommunicatorBackend b = backend.value_or(default_backend_);
  std::string team_key = getTeamKey(all_ranks, b);
  // Ensure the backend exists in backends_ before the .at() lookup; the raw
  // pointer return value is deliberately discarded — we need the cached
  // intrusive_ptr, not a non-owning pointer.
  // NOTE(review): team_key is derived from the *resolved* backend `b`, while
  // getBackendForTeam receives the unresolved optional and an empty prefix.
  // Confirm both resolve to the same cache key; otherwise .at() below throws
  // std::out_of_range.
  (void)getBackendForTeam(all_ranks, backend, "");
  return backends_.at(team_key);
}
#endif

} // namespace nvfuser

13 changes: 13 additions & 0 deletions csrc/multidevice/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#else
Expand Down Expand Up @@ -124,6 +125,18 @@ class NVF_API Communicator {
return store_.get();
}

#ifdef NVFUSER_DISTRIBUTED
// Returns the store as an intrusive_ptr for use with PyTorch symmetric
// memory (c10d::symmetric_memory::set_group_info).
c10::intrusive_ptr<c10d::Store> getStore() const;

// Returns the world backend as an intrusive_ptr so it can be registered with
// c10d::register_process_group (e.g. for PyTorch symmetric memory NCCL
// rendezvous, which resolves the group by name).
c10::intrusive_ptr<c10d::Backend> getWorldBackendIntrusivePtr(
std::optional<CommunicatorBackend> backend = std::nullopt);
#endif

private:
Communicator(
CommunicatorBackend backend = comm_backend_default,
Expand Down
18 changes: 18 additions & 0 deletions csrc/multidevice/ipc_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,22 @@ MulticastProtocol getMulticastProtocol() {
return MulticastProtocol::BatchMemcpy;
}

// Resolves the symmetric-memory backend from the
// NVFUSER_ENABLE=symmetric_memory_backend(...) option. Returns Native when
// the option is absent or its argument names no known PyTorch transport.
SymmetricMemoryBackend getSymmetricMemoryBackend() {
  // Guard clause: without the enable option, use Fuser's native VMM+IPC path.
  if (!isOptionEnabled(EnableOption::SymmetricMemoryBackend)) {
    return SymmetricMemoryBackend::Native;
  }
  const auto has_arg = [](const char* arg) {
    return hasEnableOptionArgument(EnableOption::SymmetricMemoryBackend, arg);
  };
  if (has_arg("pytorch_nccl")) {
    return SymmetricMemoryBackend::PyTorchNccl;
  }
  if (has_arg("pytorch_nvshmem")) {
    return SymmetricMemoryBackend::PyTorchNvshmem;
  }
  if (has_arg("pytorch_cuda")) {
    return SymmetricMemoryBackend::PyTorchCuda;
  }
  // Option enabled but with no recognized transport argument: keep the
  // native backend.
  return SymmetricMemoryBackend::Native;
}

} // namespace nvfuser
13 changes: 13 additions & 0 deletions csrc/multidevice/ipc_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ enum class MulticastProtocol { Memcpy, Multimem, BatchMemcpy };

MulticastProtocol getMulticastProtocol();

// Selects how symmetric memory is allocated and rendezvoused across ranks.
// - Native: Fuser's own CUDA VMM + IPC implementation (the default,
//   actively maintained path).
// - PyTorch*: delegate to PyTorch's symmetric memory
//   (torch.distributed._symmetric_memory) with the given transport
//   (NCCL, NVSHMEM, or CUDA).
enum class SymmetricMemoryBackend {
  Native,
  PyTorchNccl,
  PyTorchNvshmem,
  PyTorchCuda,
};

// Reads NVFUSER_ENABLE=symmetric_memory_backend(...) and returns the selected
// backend; Native when the option is unset or its argument is unrecognized.
SymmetricMemoryBackend getSymmetricMemoryBackend();

// Creates a listening Unix domain socket bound to path.
// If path starts with '@', it uses the abstract namespace (replaced with \0).
// Returns the socket file descriptor.
Expand Down
140 changes: 139 additions & 1 deletion csrc/multidevice/symmetric_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,51 @@
#include "multidevice/ipc_utils.h"
#include "multidevice/utils.h"

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#endif

namespace nvfuser {

namespace {

#ifdef NVFUSER_DISTRIBUTED
// Name under which Fuser's world group is published to PyTorch's
// symmetric-memory machinery.
const char* kPyTorchSymmMemGroupName = "nvfuser_symm";

// One-time, process-global initialization of PyTorch symmetric memory for the
// selected transport `backend`:
//  - selects the c10d symmetric-memory backend ("NCCL"/"NVSHMEM"/"CUDA"),
//  - publishes Fuser's rank / world size / store under
//    kPyTorchSymmMemGroupName via c10d::symmetric_memory::set_group_info.
// Thread-safe and idempotent. The transport choice is process-global, so a
// later call requesting a *different* transport is rejected explicitly
// instead of being silently ignored (previously the call_once swallowed it).
void ensurePyTorchSymmMemBackend(SymmetricMemoryBackend backend) {
  static std::once_flag once;
  // Latches the transport selected by the first caller. Reads are ordered
  // after the write by std::call_once's synchronization guarantee.
  static SymmetricMemoryBackend initialized_backend =
      SymmetricMemoryBackend::Native;
  std::call_once(once, [backend]() {
    const char* name = nullptr;
    switch (backend) {
      case SymmetricMemoryBackend::PyTorchNccl:
        name = "NCCL";
        break;
      case SymmetricMemoryBackend::PyTorchNvshmem:
        name = "NVSHMEM";
        break;
      case SymmetricMemoryBackend::PyTorchCuda:
        name = "CUDA";
        break;
      default:
        NVF_ERROR(false, "Unexpected PyTorch symmetric memory backend");
    }
    c10d::symmetric_memory::set_backend(name);
    Communicator& comm = Communicator::getInstance();
    NVF_CHECK(
        comm.is_available(), "Communicator not available for symmetric memory");
    c10d::symmetric_memory::set_group_info(
        kPyTorchSymmMemGroupName,
        static_cast<int>(comm.deviceId()),
        static_cast<int>(comm.size()),
        comm.getStore());
    // TODO: register the world backend so NCCL rendezvous can resolve the
    // group by name:
    // c10d::register_process_group(
    //     kPyTorchSymmMemGroupName,
    //     comm.getWorldBackendIntrusivePtr(CommunicatorBackend::kNccl));
    initialized_backend = backend;
  });
  NVF_CHECK(
      initialized_backend == backend,
      "PyTorch symmetric memory was already initialized with a different "
      "transport backend; mixing transports in one process is not supported.");
}
#endif


// Returns the allocation granularity for symmetric memory.
// - query_mcast_granularity: if true, considers multicast granularity
// - query_mcast_recommended_granularity: if true, uses recommended (larger)
Expand Down Expand Up @@ -88,6 +129,39 @@ at::Tensor SymmetricTensor::allocate(
at::IntArrayRef sizes,
at::ScalarType dtype,
at::Device device) {
SymmetricMemoryBackend backend = getSymmetricMemoryBackend();

#ifdef NVFUSER_DISTRIBUTED
if (backend != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(backend);
std::vector<int64_t> strides(sizes.size());
strides.back() = 1;
for (int64_t i = (int64_t)strides.size() - 2; i >= 0; --i) {
Comment on lines +137 to +139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undefined behavior when sizes is empty (0-dim tensor)

std::vector<int64_t> strides(sizes.size());
strides.back() = 1;   // UB if sizes is empty

std::vector::back() on an empty vector is undefined behaviour. The same guard-free pattern also exists in the native path further down in the same function (~line 225). While allocating a 0-dimensional symmetric tensor is unusual, the PyTorch path that was just added adds a new callsite where callers may pass {} as sizes. A simple check is sufficient:

NVF_CHECK(!sizes.empty(), "Cannot allocate a 0-dim symmetric tensor");

or initialise strides defensively (matching the standard row-major convention for 0-dim tensors, which is an empty strides vector) and skip the loop entirely when sizes is empty.

strides[i] = strides[i + 1] * sizes[i + 1];
}
// NCCLSymmetricMemoryAllocator::alloc must not be called with a group_name.
c10::optional<std::string> alloc_group_name =
(backend == SymmetricMemoryBackend::PyTorchNccl)
? c10::nullopt
: c10::optional<std::string>(kPyTorchSymmMemGroupName);
return c10d::symmetric_memory::empty_strided_p2p(
sizes,
strides,
dtype,
device,
alloc_group_name,
c10::nullopt);
}
#else
if (backend != SymmetricMemoryBackend::Native) {
NVF_ERROR(
false,
"PyTorch symmetric memory backend requires a build with "
"NVFUSER_DISTRIBUTED. Use NVFUSER_ENABLE=symmetric_memory_backend(native) "
"or do not set symmetric_memory_backend.");
}
#endif

int is_vmm_supported;
NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute(
&is_vmm_supported,
Expand Down Expand Up @@ -212,6 +286,18 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor)
"Expected CUDA tensor, got: ",
local_tensor.device());

#ifdef NVFUSER_DISTRIBUTED
SymmetricMemoryBackend backend = getSymmetricMemoryBackend();
if (backend != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(backend);
Communicator& comm = Communicator::getInstance();
world_size_ = comm.size();
my_device_id_ = comm.deviceId();
requested_size_ = local_tensor.numel() * local_tensor.element_size();
return; // Rendezvous runs in setupRemoteHandles()
}
#endif

std::string error = SymmetricTensor::validate(local_tensor);
NVF_CHECK(error.empty(), "Invalid symmetric allocation: ", error);

Expand Down Expand Up @@ -253,6 +339,11 @@ SymmetricTensor::SymmetricTensor(const at::Tensor& local_tensor)
}

SymmetricTensor::~SymmetricTensor() {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return; // PyTorch backend: no native VMM cleanup
}
#endif
#if (CUDA_VERSION >= 13000)
if (is_multicast_setup_) {
if (mc_base_ptr_) {
Expand Down Expand Up @@ -302,6 +393,20 @@ void SymmetricTensor::setupRemoteHandles(const std::string& tag) {
if (are_remote_tensors_setup_ == true) {
return;
}
#ifdef NVFUSER_DISTRIBUTED
// PyTorch backend: perform rendezvous here (lazy, on first setupRemoteHandles).
if (getSymmetricMemoryBackend() != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(getSymmetricMemoryBackend());
Comment on lines +398 to +399
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getSymmetricMemoryBackend() is invoked twice in back-to-back lines, which redundantly re-parses the option string on each call. A single local variable should be used:

Suggested change
if (getSymmetricMemoryBackend() != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(getSymmetricMemoryBackend());
SymmetricMemoryBackend backend = getSymmetricMemoryBackend();
if (backend != SymmetricMemoryBackend::Native) {
ensurePyTorchSymmMemBackend(backend);

py_symm_handle_ = c10d::symmetric_memory::rendezvous(
local_tensor_, c10::optional<std::string>(kPyTorchSymmMemGroupName));
are_remote_tensors_setup_ = true;
if (py_symm_handle_->has_multicast_support()) {
is_multicast_setup_ = true;
mc_ptr_ = py_symm_handle_->get_multicast_ptr();
}
return;
}
#endif
Communicator& comm = Communicator::getInstance();
CUmemGenericAllocationHandle local_handle = alloc_handles_[my_device_id_];
CUdeviceptr local_ptr = remote_ptrs_[my_device_id_];
Expand Down Expand Up @@ -379,6 +484,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const {
return local_tensor_;
}

#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return py_symm_handle_->get_remote_tensor(
rank, local_tensor_.sizes(), local_tensor_.scalar_type());
}
#endif

NVF_CHECK(are_remote_tensors_setup_ == true, "Remote tensors not setup");
return at::from_blob(
reinterpret_cast<void*>(remote_ptrs_[rank]),
Expand All @@ -390,6 +502,13 @@ at::Tensor SymmetricTensor::remoteTensor(int64_t rank) const {
}

void* SymmetricTensor::multicastPtr() const {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return py_symm_handle_->has_multicast_support()
? py_symm_handle_->get_multicast_ptr()
: nullptr;
}
#endif
Comment on lines 504 to +511
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multicastPtr() silently returns nullptr for PyTorch backend when multicast is not supported, which is inconsistent with the native path (which calls NVF_CHECK(is_multicast_setup_, "Multicast not setup")).

Any caller that does not check for nullptr before using the pointer will trigger a null pointer dereference / silent GPU fault rather than a clear diagnostic error.

Consider throwing or at least asserting instead of silently returning nullptr:

Suggested change
void* SymmetricTensor::multicastPtr() const {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return py_symm_handle_->has_multicast_support()
? py_symm_handle_->get_multicast_ptr()
: nullptr;
}
#endif
void* SymmetricTensor::multicastPtr() const {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
NVF_CHECK(
py_symm_handle_->has_multicast_support(),
"Multicast not supported by the selected PyTorch symmetric memory backend.");
return py_symm_handle_->get_multicast_ptr();
}
#endif
NVF_CHECK(is_multicast_setup_, "Multicast not setup");
return mc_ptr_;
}

This brings the error contract in line with the native path, where multicastPtr() always either returns a valid pointer or throws.

NVF_CHECK(is_multicast_setup_, "Multicast not setup");
return mc_ptr_;
}
Expand All @@ -398,7 +517,14 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) {
if (is_contiguous_view_setup_) {
return;
}

#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
NVF_ERROR(
false,
"Contiguous view is not yet supported for PyTorch symmetric memory backend. "
"Use native backend for SymmetricContiguousView.");
}
#endif
NVF_CHECK(
are_remote_tensors_setup_ == true,
"Remote tensors must be setup before setupContiguousView");
Expand Down Expand Up @@ -462,13 +588,25 @@ void SymmetricTensor::setupContiguousView(const std::string& tag) {
}

// Returns the contiguous cross-rank view previously created by
// setupContiguousView(); throws if that setup has not run.
at::Tensor SymmetricTensor::getContiguousView() const {
#ifdef NVFUSER_DISTRIBUTED
  // Contiguous views rely on Fuser's native VMM remapping, which the PyTorch
  // symmetric-memory handle does not expose.
  if (py_symm_handle_) {
    NVF_ERROR(
        false,
        "Contiguous view is not yet supported for PyTorch symmetric memory backend.");
  }
#endif
  NVF_CHECK(is_contiguous_view_setup_, "Contiguous view not setup");
  return contiguous_view_;
}

void SymmetricTensor::setupMulticast(
int64_t exporter_rank,
const std::string& tag) {
#ifdef NVFUSER_DISTRIBUTED
if (py_symm_handle_) {
return; // PyTorch backend: multicast handled by backend if supported
}
#endif
#if (CUDA_VERSION >= 13000)
if (is_multicast_setup_) {
return;
Expand Down
21 changes: 15 additions & 6 deletions csrc/multidevice/symmetric_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <ATen/core/Tensor.h>
#include <cuda.h>

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
#endif

namespace nvfuser {

// SymmetricTensor wraps a local symmetric memory allocation and enables:
Expand All @@ -18,13 +22,14 @@ namespace nvfuser {
// - Contiguous view creation across all ranks
//
// Design: Decouples local allocation from IPC handle exchange for better
// interoperability and support for pre-allocated user buffers
// interoperability and support for pre-allocated user buffers.
//
// TODO: Long term plan is to integrate pytorch's native symmetric memory as a
// possible backend. One important reason to use pytorch's allocator is to use
// pytorch's memory pool to let the framework own the memory stack and not
// further fragment the memory. On the other hand, having our own implementation
// allows us to experiment more advanced features like contigous view creation.
// Backends (see SymmetricMemoryBackend in ipc_utils.h):
// - Native (default): Fuser's own CUDA VMM + IPC implementation; maintained.
// - PyTorch (Nccl, Nvshmem, Cuda): Use PyTorch's symmetric memory
// (torch.distributed._symmetric_memory) with the chosen transport backend.
// Select via NVFUSER_ENABLE=symmetric_memory_backend(pytorch_nccl|pytorch_nvshmem|pytorch_cuda).
// Native remains the default when the option is not set.
class SymmetricTensor {
public:
// Wrap pre-allocated symmetric tensor (must use allocate())
Expand Down Expand Up @@ -79,6 +84,10 @@ class SymmetricTensor {
int peer_fd_{-1};
bool is_contiguous_view_setup_ = false;
at::Tensor contiguous_view_;
#ifdef NVFUSER_DISTRIBUTED
// When set, remote/multicast APIs delegate to PyTorch symmetric memory.
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> py_symm_handle_;
#endif
};

} // namespace nvfuser
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"fast_math", EnableOption::FastMath},
{"p2p_protocol", EnableOption::P2pProtocol},
{"multicast_protocol", EnableOption::MulticastProtocol},
{"symmetric_memory_backend", EnableOption::SymmetricMemoryBackend},
{"parallel_serde", EnableOption::ParallelSerde},
};
return available_options;
Expand Down
2 changes: 2 additions & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ enum class EnableOption {
P2pProtocol, //! Prescribe P2P protocol: put|get
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
SymmetricMemoryBackend, //! Prescribe symmetric memory backend:
//! native|pytorch_nccl|pytorch_nvshmem|pytorch_cuda
ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
EndOfOption //! Placeholder for counting the number of elements
};
Expand Down
24 changes: 24 additions & 0 deletions fbuild.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# Developer build script for nvfuser: configures a clang-20 + mold + HPC-X/UCC
# debug build and installs the python package in editable mode.
#
# NOTE(review): toolchain paths below (clang-20, mold, /opt/hpcx) are
# machine-specific; adjust for your environment. Consider keeping this script
# out of version control.

# Fail fast on command errors, unset variables, and pipeline failures.
set -euo pipefail

# BUILD_DIRECTORY must be provided by the caller; without this guard the
# install/build dirs would silently expand to empty strings and break the
# build in confusing ways.
: "${BUILD_DIRECTORY:?BUILD_DIRECTORY must be set to the build output directory}"

export CC=clang-20
export CXX=clang++-20
export LDFLAGS="-fuse-ld=mold"

# Exports the variable with its current value (empty if unset).
# NOTE(review): likely intended NVFUSER_BUILD_ENABLE_PCH=1 — confirm.
export NVFUSER_BUILD_ENABLE_PCH

export UCC_HOME="/opt/hpcx/ucc"
export UCC_DIR="/opt/hpcx/ucc/lib/cmake/ucc"
export UCX_HOME="/opt/hpcx/ucx"
export UCX_DIR="/opt/hpcx/ucx/lib/cmake/ucx"

# export TORCH_CUDA_ARCH_LIST="9.0"

export NVFUSER_BUILD_WITH_UCC=1
export NVFUSER_BUILD_INSTALL_DIR="$BUILD_DIRECTORY/nvfuser"
export NVFUSER_BUILD_DIR="$BUILD_DIRECTORY"

# Enable debug mode, leave empty for non-debug compilation
export NVFUSER_BUILD_BUILD_TYPE=Debug
export RUN_CMAKE=""

pip install -v -e ./python --no-build-isolation
Comment on lines +1 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personal developer build script committed to repository

This script contains machine-specific, hardcoded toolchain paths that are unlikely to work anywhere except the author's development machine:

  • clang-20 and clang++-20 — not a standard compiler version available broadly
  • -fuse-ld=mold — requires the mold linker to be installed
  • /opt/hpcx/ucc and /opt/hpcx/ucx — HPC-X installation path specific to the author's environment
  • $BUILD_DIRECTORY is used but never validated; if it is unset, NVFUSER_BUILD_INSTALL_DIR and NVFUSER_BUILD_DIR will silently be empty strings, likely breaking the build

This kind of personal convenience script should live outside version control (e.g., in a .gitignore-d directory or in the author's home directory). Committing it to the main repo risks confusing other contributors and cluttering the root directory.

Loading
Loading