From c2941a06eefe48873921d801c7bf05a3ca761d6e Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 24 Feb 2026 06:18:37 -0800 Subject: [PATCH 1/5] add multidevice tma p2p standalone tests --- CMakeLists.txt | 4 +- tests/cpp/test_multidevice_ipc.cpp | 438 +++++++++++++++++++++++++++++ tests/cpp/tma_test_kernels.cu | 118 ++++++++ tests/cpp/tma_test_kernels.h | 24 ++ 4 files changed, 583 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/tma_test_kernels.cu create mode 100644 tests/cpp/tma_test_kernels.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f5dad5c90d5..4113fdc5a95 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -919,7 +919,7 @@ function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK) if(NOT MSVC) target_compile_options(${TEST_NAME} PRIVATE - -Wall -Wno-unused-function -Werror + $<$:-Wall -Wno-unused-function -Werror> ) endif() endfunction() @@ -1019,6 +1019,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp + ${NVFUSER_ROOT}/tests/cpp/tma_test_kernels.cu ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp @@ -1029,6 +1030,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp ) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") + set_property(TARGET test_multidevice PROPERTY CUDA_STANDARD ${NVFUSER_CUDA_STANDARD}) list(APPEND TEST_BINARIES test_multidevice) set(MULTIDEVICE_TUTORIAL_SRCS) diff --git a/tests/cpp/test_multidevice_ipc.cpp b/tests/cpp/test_multidevice_ipc.cpp index 90a615602d7..5e122cb85e7 100644 --- a/tests/cpp/test_multidevice_ipc.cpp +++ b/tests/cpp/test_multidevice_ipc.cpp @@ -15,6 +15,7 @@ #include "multidevice/utils.h" #include "ops/all_ops.h" #include "tests/cpp/multidevice.h" +#include "tests/cpp/tma_test_kernels.h" namespace nvfuser { @@ -1069,4 +1070,441 @@ TEST_F(IpcTest, VmmMultiRankContiguousMappingTest) { } } +// ============================================================================= +// TMA (Tensor Memory Accelerator) tests +// +// These tests exercise the Hopper TMA 1D bulk copy (cp.async.bulk) for +// different memory sources: local device memory, VMM-mapped peer memory, +// and NVLS multicast unicast pointers. +// ============================================================================= + +TEST_F(IpcTest, TmaLocalCopy) { + const int64_t local_rank = communicator_->local_rank(); + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); + + constexpr int kNumElems = 256; + constexpr int kSizeBytes = kNumElems * sizeof(uint32_t); + static_assert(kSizeBytes % 16 == 0); + + void* d_src; + void* d_dst; + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_src, kSizeBytes)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_dst, kSizeBytes)); + + std::vector host_src(kNumElems); + for (int i = 0; i < kNumElems; i++) { + host_src[i] = + static_cast(communicator_->deviceId() * 1000 + i * 7 + 42); + } + NVFUSER_CUDA_RT_SAFE_CALL( + cudaMemcpy(d_src, host_src.data(), kSizeBytes, cudaMemcpyHostToDevice)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_dst, 0, kSizeBytes)); + + launchTmaCopy1D(d_dst, d_src, kSizeBytes, 0); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + std::vector result(kNumElems); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaMemcpy(result.data(), d_dst, kSizeBytes, cudaMemcpyDeviceToHost)); + + for (int i = 0; i < kNumElems; i++) { + EXPECT_EQ(result[i], host_src[i]) << "Mismatch at index " << i; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_src)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_dst)); +} + +TEST_F(IpcTest, TmaInterDeviceCopy) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + + const int64_t num_devices = communicator_->size(); + const int64_t rank = communicator_->deviceId(); + const int64_t local_rank = communicator_->local_rank(); + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + int is_vmm_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_vmm_supported, + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, + local_rank)); + if (is_vmm_supported == 0) { + GTEST_SKIP() + << "Device does not support Virtual Memory Management; skipping."; + } + + int is_ipc_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_ipc_supported, + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, + local_rank)); + if (is_ipc_supported == 0) { + GTEST_SKIP() << "Device does not support IPC handles; skipping."; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); + + // VMM allocation for the source buffer on each rank + CUmemAllocationProp prop{}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = static_cast(local_rank); + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + size_t granularity = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemGetAllocationGranularity( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + constexpr size_t kNumElems = 256; + constexpr size_t kSizeBytes = kNumElems * sizeof(uint32_t); + static_assert(kSizeBytes % 16 == 0); + size_t aligned_size = + ((kSizeBytes + granularity - 1) / granularity) * granularity; + + CUmemGenericAllocationHandle mem_handle = 0; + NVFUSER_CUDA_SAFE_CALL( + cuMemCreate(&mem_handle, aligned_size, &prop, /*flags=*/0)); + + CUdeviceptr d_ptr = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( + &d_ptr, aligned_size, /*alignment=*/granularity, /*baseVA=*/0, + /*flags=*/0)); + NVFUSER_CUDA_SAFE_CALL( + cuMemMap(d_ptr, aligned_size, /*offset=*/0, mem_handle, /*flags=*/0)); + + CUmemAccessDesc access{}; + access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access.location.id = static_cast(local_rank); + access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + NVFUSER_CUDA_SAFE_CALL( + cuMemSetAccess(d_ptr, aligned_size, &access, /*count=*/1)); + + // Each rank writes its own pattern + std::vector host_data(kNumElems); + for (size_t i = 0; i < kNumElems; i++) { + host_data[i] = static_cast(rank * 10000 + i); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + reinterpret_cast(d_ptr), + host_data.data(), + kSizeBytes, + cudaMemcpyHostToDevice)); + + // Export handle and exchange via Unix sockets (same pattern as IpcP2pWithVmm) + int shared_fd; + NVFUSER_CUDA_SAFE_CALL(cuMemExportToShareableHandle( + &shared_fd, + mem_handle, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + /*flags=*/0)); + + std::string my_socket = "@nvfuser_tma_p2p_" + std::to_string(rank); + int listener_fd = nvfuser::createIpcSocket(my_socket); + + communicator_->barrier(); + + const int64_t peer_rank = (rank + 1) % num_devices; + const int64_t receiver_rank = (rank - 1 + num_devices) % num_devices; + std::string receiver_path = + "@nvfuser_tma_p2p_" + std::to_string(receiver_rank); + + nvfuser::sendFd(receiver_path, shared_fd); + int peer_fd = nvfuser::recvFd(listener_fd); + + close(listener_fd); + close(shared_fd); + + // Import and map peer's allocation + CUmemGenericAllocationHandle peer_mem_handle = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemImportFromShareableHandle( + &peer_mem_handle, + (void*)((uint64_t)peer_fd), + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + + CUdeviceptr peer_d_ptr = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( + &peer_d_ptr, aligned_size, /*alignment=*/granularity, /*baseVA=*/0, + /*flags=*/0)); + NVFUSER_CUDA_SAFE_CALL(cuMemMap( + peer_d_ptr, aligned_size, /*offset=*/0, peer_mem_handle, /*flags=*/0)); + + CUmemAccessDesc peer_access{}; + peer_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + peer_access.location.id = static_cast(local_rank); + peer_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READ; + NVFUSER_CUDA_SAFE_CALL( + cuMemSetAccess(peer_d_ptr, aligned_size, &peer_access, /*count=*/1)); + + // Allocate local output buffer and TMA-copy from peer's mapped VA + void* d_output; + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_output, kSizeBytes)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_output, 0, kSizeBytes)); + + communicator_->barrier(); + + launchTmaCopy1D( + d_output, reinterpret_cast(peer_d_ptr), kSizeBytes, 0); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + std::vector result(kNumElems); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + result.data(), d_output, kSizeBytes, cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < kNumElems; i++) { + uint32_t expected = static_cast(peer_rank * 10000 + i); + EXPECT_EQ(result[i], expected) + << "Rank " << rank << " mismatch at index " << i; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_output)); + NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(peer_d_ptr, aligned_size)); + NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(peer_d_ptr, aligned_size)); + NVFUSER_CUDA_SAFE_CALL(cuMemRelease(peer_mem_handle)); + close(peer_fd); + NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(d_ptr, aligned_size)); + NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(d_ptr, aligned_size)); + NVFUSER_CUDA_SAFE_CALL(cuMemRelease(mem_handle)); +} + +#if (CUDA_VERSION >= 13000) + +TEST_F(IpcTest, TmaMulticastRead) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + + const int64_t world_size = communicator_->size(); + const int64_t rank = communicator_->deviceId(); + const int64_t local_rank = communicator_->local_rank(); + + constexpr int64_t exporter_rank = 0; + constexpr int64_t root_rank = 1; + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + int is_vmm_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_vmm_supported, + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, + local_rank)); + if (is_vmm_supported == 0) { + GTEST_SKIP() + << "Device does not support Virtual Memory Management; skipping."; + } + + int is_ipc_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_ipc_supported, + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, + local_rank)); + if (is_ipc_supported == 0) { + GTEST_SKIP() << "Device does not support IPC handles; skipping."; + } + + int is_multicast_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_multicast_supported, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + local_rank)); + if (is_multicast_supported == 0) { + GTEST_SKIP() << "Device does not support Multicast Objects; skipping."; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); + + // Multicast buffer: 2 MB (matches existing NVLS tests) + constexpr size_t kNumElems = 524288; + constexpr size_t kSizeBytes = kNumElems * sizeof(uint32_t); + + using handle_typename = int; + auto handle_type = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + CUmulticastObjectProp mcast_prop{}; + mcast_prop.flags = 0; + mcast_prop.handleTypes = handle_type; + mcast_prop.numDevices = world_size; + mcast_prop.size = kSizeBytes; + + size_t mcast_min_granularity = 0; + NVFUSER_CUDA_SAFE_CALL(cuMulticastGetGranularity( + &mcast_min_granularity, &mcast_prop, CU_MULTICAST_GRANULARITY_MINIMUM)); + if (mcast_min_granularity > kSizeBytes) { + GTEST_SKIP() << "Multicast min granularity (" << mcast_min_granularity + << ") exceeds buffer size; skipping."; + } + + size_t mcast_granularity = 0; + NVFUSER_CUDA_SAFE_CALL(cuMulticastGetGranularity( + &mcast_granularity, &mcast_prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + if (mcast_granularity > kSizeBytes) { + GTEST_SKIP() << "Multicast recommended granularity (" << mcast_granularity + << ") exceeds buffer size; skipping."; + } + + // Create multicast object on exporter rank and share via Unix sockets + CUmemGenericAllocationHandle mcast_handle{}; + handle_typename shared_handle; + int listener_fd = -1; + + if (rank == exporter_rank) { + NVFUSER_CUDA_SAFE_CALL(cuMulticastCreate(&mcast_handle, &mcast_prop)); + NVFUSER_CUDA_SAFE_CALL(cuMemExportToShareableHandle( + &shared_handle, mcast_handle, handle_type, /*flags=*/0)); + } else { + std::string my_path = + "@nvfuser_tma_mcast_recv_" + std::to_string(rank); + listener_fd = nvfuser::createIpcSocket(my_path); + } + + communicator_->barrier(); + + if (rank != exporter_rank) { + int received_fd = nvfuser::recvFd(listener_fd); + shared_handle = received_fd; + close(listener_fd); + } else { + for (int i = 0; i < world_size; ++i) { + if (i == rank) { + continue; + } + std::string peer_path = + "@nvfuser_tma_mcast_recv_" + std::to_string(i); + nvfuser::sendFd(peer_path, shared_handle); + } + close(shared_handle); + } + + if (rank != exporter_rank) { + NVFUSER_CUDA_SAFE_CALL(cuMemImportFromShareableHandle( + &mcast_handle, (void*)((uint64_t)shared_handle), handle_type)); + close(shared_handle); + } + + CUdevice cu_dev; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGet(&cu_dev, static_cast(local_rank))); + NVFUSER_CUDA_SAFE_CALL(cuMulticastAddDevice(mcast_handle, cu_dev)); + + // Local physical allocation + CUmemAllocationProp prop{}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = static_cast(local_rank); + prop.requestedHandleTypes = handle_type; + + size_t granularity = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemGetAllocationGranularity( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + if (granularity > kSizeBytes) { + GTEST_SKIP() << "Allocation granularity (" << granularity + << ") exceeds buffer size; skipping."; + } + + CUmemGenericAllocationHandle local_buffer = 0; + NVFUSER_CUDA_SAFE_CALL( + cuMemCreate(&local_buffer, kSizeBytes, &prop, /*flags=*/0)); + + NVFUSER_CUDA_SAFE_CALL(cuMulticastBindMem( + mcast_handle, /*mcOffset=*/0, local_buffer, /*memOffset=*/0, + kSizeBytes, /*flags=*/0)); + + // MC (multicast) mapping — used for broadcast writes + CUdeviceptr mc_ptr = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( + &mc_ptr, kSizeBytes, /*alignment=*/mcast_granularity, + /*baseVA=*/0, /*flags=*/0)); + NVFUSER_CUDA_SAFE_CALL( + cuMemMap(mc_ptr, kSizeBytes, /*offset=*/0, mcast_handle, /*flags=*/0)); + CUmemAccessDesc mc_access{}; + mc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + mc_access.location.id = static_cast(local_rank); + mc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + NVFUSER_CUDA_SAFE_CALL( + cuMemSetAccess(mc_ptr, kSizeBytes, &mc_access, /*count=*/1)); + + // UC (unicast) mapping — used for local reads + CUdeviceptr uc_ptr = 0; + NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( + &uc_ptr, kSizeBytes, /*alignment=*/granularity, + /*baseVA=*/0, /*flags=*/0)); + NVFUSER_CUDA_SAFE_CALL( + cuMemMap(uc_ptr, kSizeBytes, /*offset=*/0, local_buffer, /*flags=*/0)); + CUmemAccessDesc uc_access{}; + uc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + uc_access.location.id = static_cast(local_rank); + uc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + NVFUSER_CUDA_SAFE_CALL( + cuMemSetAccess(uc_ptr, kSizeBytes, &uc_access, /*count=*/1)); + + // Root broadcasts data via the MC pointer + std::vector host_buffer(kNumElems); + if (rank == root_rank) { + for (size_t i = 0; i < kNumElems; ++i) { + host_buffer[i] = static_cast(i * 3 + 17); + } + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + reinterpret_cast(mc_ptr), + host_buffer.data(), + kSizeBytes, + cudaMemcpyHostToDevice)); + } + + communicator_->barrier(); + + // Use TMA to copy a portion from the UC pointer to a local output buffer. + // The UC pointer maps to local physical memory that received the multicast + // data, so TMA should be able to read from it. + constexpr int kTmaBytes = 4096; + static_assert(kTmaBytes % 16 == 0); + static_assert(kTmaBytes <= kSizeBytes); + constexpr int kTmaElems = kTmaBytes / sizeof(uint32_t); + + void* d_output; + NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_output, kTmaBytes)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_output, 0, kTmaBytes)); + + launchTmaCopy1D(d_output, reinterpret_cast(uc_ptr), kTmaBytes, 0); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + std::vector result(kTmaElems); + NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( + result.data(), d_output, kTmaBytes, cudaMemcpyDeviceToHost)); + + for (int i = 0; i < kTmaElems; ++i) { + uint32_t expected = static_cast(i * 3 + 17); + EXPECT_EQ(result[i], expected) + << "Rank " << rank << " mismatch at index " << i; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_output)); + NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(mc_ptr, kSizeBytes)); + NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(uc_ptr, kSizeBytes)); + NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(mc_ptr, kSizeBytes)); + NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(uc_ptr, kSizeBytes)); + NVFUSER_CUDA_SAFE_CALL(cuMemRelease(local_buffer)); + NVFUSER_CUDA_SAFE_CALL(cuMemRelease(mcast_handle)); +} + +#endif // CUDA_VERSION >= 13000 + } // namespace nvfuser diff --git a/tests/cpp/tma_test_kernels.cu b/tests/cpp/tma_test_kernels.cu new file mode 100644 index 00000000000..5ca8f6c917d --- /dev/null +++ b/tests/cpp/tma_test_kernels.cu @@ -0,0 +1,118 @@ +// clang-format off +/* +* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +* All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +*/ +// clang-format on + +#include "tests/cpp/tma_test_kernels.h" + +#include +#include + +namespace nvfuser { + +// TMA 1D bulk copy kernel: GMEM(src) -> SMEM -> GMEM(dst). +// Inspired by DeepEP's tma_load_1d / tma_store_1d pattern. +// A single elected thread issues all TMA operations while the rest of the warp +// idles. mbarrier synchronization ensures the async TMA load completes before +// the TMA store reads from shared memory. +// +// Dynamic shared memory layout (128-byte aligned): +// [0, num_bytes) : data staging buffer +// [num_bytes, num_bytes+8) : mbarrier (uint64_t, 16-byte aligned since +// num_bytes is a multiple of 16) +__global__ void __launch_bounds__(32, 1) + tma_copy_1d_kernel( + void* __restrict__ dst, + const void* __restrict__ src, + int num_bytes) { + extern __shared__ __align__(128) uint8_t smem[]; + + auto* mbar = reinterpret_cast(smem + num_bytes); + auto smem_addr = + static_cast(__cvta_generic_to_shared(smem)); + auto mbar_addr = + static_cast(__cvta_generic_to_shared(mbar)); + + if (threadIdx.x == 0) { + // Initialize mbarrier with arrival count = 1 + asm volatile( + "mbarrier.init.shared::cta.b64 [%0], %1;" + ::"r"(mbar_addr), "r"(1)); + // Ensure init is visible cluster-wide before any use + asm volatile( + "fence.mbarrier_init.release.cluster;" :::); + } + __syncwarp(); + + if (threadIdx.x == 0) { + // Announce expected number of transaction bytes on the mbarrier + asm volatile( + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" + ::"r"(mbar_addr), "r"(num_bytes)); + + // TMA Load: GMEM -> SMEM (async, completed via mbarrier) + asm volatile( + "cp.async.bulk.shared::cluster.global" + ".mbarrier::complete_tx::bytes" + " [%0], [%1], %2, [%3];\n" + ::"r"(smem_addr), + "l"(src), + "r"(num_bytes), + "r"(mbar_addr) + : "memory"); + + // Block until the mbarrier phase flips (TMA load completed). + // Phase 0 is the initial phase after mbarrier.init. + asm volatile( + "{\n" + ".reg .pred P1;\n" + "TMA_COPY_WAIT_LOAD:\n" + "mbarrier.try_wait.parity.shared::cta.b64" + " P1, [%0], %1;\n" + "@P1 bra TMA_COPY_LOAD_DONE;\n" + "bra TMA_COPY_WAIT_LOAD;\n" + "TMA_COPY_LOAD_DONE:\n" + "}" + ::"r"(mbar_addr), "r"(0)); + + // TMA Store: SMEM -> GMEM + // No fence.proxy.async needed here because both the load and store + // operate through the async proxy; the mbarrier completion already + // establishes the necessary ordering (cf. DeepEP intranode.cu). + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group" + " [%0], [%1], %2;\n" + ::"l"(dst), + "r"(smem_addr), + "r"(num_bytes) + : "memory"); + asm volatile("cp.async.bulk.commit_group;"); + asm volatile( + "cp.async.bulk.wait_group.read 0;" ::: "memory"); + + // Invalidate mbarrier before kernel exit + asm volatile( + "mbarrier.inval.shared::cta.b64 [%0];" + ::"r"(mbar_addr)); + } +} + +void launchTmaCopy1D( + void* dst, + const void* src, + int num_bytes, + cudaStream_t stream) { + assert(num_bytes > 0 && "num_bytes must be positive"); + assert( + num_bytes % 16 == 0 && + "cp.async.bulk requires size to be a multiple of 16 bytes"); + + // data buffer + mbarrier (8 bytes) + int smem_size = num_bytes + static_cast(sizeof(uint64_t)); + tma_copy_1d_kernel<<<1, 32, smem_size, stream>>>(dst, src, num_bytes); +} + +} // namespace nvfuser diff --git a/tests/cpp/tma_test_kernels.h b/tests/cpp/tma_test_kernels.h new file mode 100644 index 00000000000..8921decedae --- /dev/null +++ b/tests/cpp/tma_test_kernels.h @@ -0,0 +1,24 @@ +// clang-format off +/* +* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +* All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +*/ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +//! Copies num_bytes from src (GMEM) to dst (GMEM) via TMA 1D bulk copy: +//! GMEM(src) -> SMEM -> GMEM(dst) +//! Uses cp.async.bulk with mbarrier synchronization (SM90+ / Hopper). +//! num_bytes must be a multiple of 16 and > 0. +void launchTmaCopy1D( + void* dst, + const void* src, + int num_bytes, + cudaStream_t stream); + +} // namespace nvfuser From 64c6162bc2f9e9d453a2423f436b65f02de3c90a Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 24 Feb 2026 11:26:03 -0800 Subject: [PATCH 2/5] clean and fix the tests --- CMakeLists.txt | 10 +- tests/cpp/test_multidevice_ipc.cpp | 438 ----------------------- tests/cpp/test_multidevice_tma.cpp | 288 +++++++++++++++ tests/cpp/test_multidevice_tma_kernel.cu | 92 +++++ tests/cpp/tma_test_kernels.cu | 118 ------ tests/cpp/tma_test_kernels.h | 24 -- 6 files changed, 387 insertions(+), 583 deletions(-) create mode 100644 tests/cpp/test_multidevice_tma.cpp create mode 100644 tests/cpp/test_multidevice_tma_kernel.cu delete mode 100644 tests/cpp/tma_test_kernels.cu delete mode 100644 tests/cpp/tma_test_kernels.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4113fdc5a95..eaa5d06b7c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -919,7 +919,7 @@ function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK) if(NOT MSVC) target_compile_options(${TEST_NAME} PRIVATE - $<$:-Wall -Wno-unused-function -Werror> + -Wall -Wno-unused-function -Werror ) endif() endfunction() @@ -1019,7 +1019,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp - ${NVFUSER_ROOT}/tests/cpp/tma_test_kernels.cu + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_tma.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp @@ -1030,7 +1030,10 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp ) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") - set_property(TARGET test_multidevice PROPERTY CUDA_STANDARD ${NVFUSER_CUDA_STANDARD}) + target_include_directories(test_multidevice PRIVATE + "${CMAKE_BINARY_DIR}/include") + add_dependencies(test_multidevice + nvfuser_rt_test_multidevice_tma_kernel) list(APPEND TEST_BINARIES test_multidevice) set(MULTIDEVICE_TUTORIAL_SRCS) @@ -1266,6 +1269,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_tma_kernel.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/tests/cpp/test_multidevice_ipc.cpp b/tests/cpp/test_multidevice_ipc.cpp index 5e122cb85e7..90a615602d7 100644 --- a/tests/cpp/test_multidevice_ipc.cpp +++ b/tests/cpp/test_multidevice_ipc.cpp @@ -15,7 +15,6 @@ #include "multidevice/utils.h" #include "ops/all_ops.h" #include "tests/cpp/multidevice.h" -#include "tests/cpp/tma_test_kernels.h" namespace nvfuser { @@ -1070,441 +1069,4 @@ TEST_F(IpcTest, VmmMultiRankContiguousMappingTest) { } } -// ============================================================================= -// TMA (Tensor Memory Accelerator) tests -// -// These tests exercise the Hopper TMA 1D bulk copy (cp.async.bulk) for -// different memory sources: local device memory, VMM-mapped peer memory, -// and NVLS multicast unicast pointers. -// ============================================================================= - -TEST_F(IpcTest, TmaLocalCopy) { - const int64_t local_rank = communicator_->local_rank(); - - int major; - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( - &major, cudaDevAttrComputeCapabilityMajor, local_rank)); - if (major < 9) { - GTEST_SKIP() << "Requires Hopper (SM90+)"; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); - - constexpr int kNumElems = 256; - constexpr int kSizeBytes = kNumElems * sizeof(uint32_t); - static_assert(kSizeBytes % 16 == 0); - - void* d_src; - void* d_dst; - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_src, kSizeBytes)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_dst, kSizeBytes)); - - std::vector host_src(kNumElems); - for (int i = 0; i < kNumElems; i++) { - host_src[i] = - static_cast(communicator_->deviceId() * 1000 + i * 7 + 42); - } - NVFUSER_CUDA_RT_SAFE_CALL( - cudaMemcpy(d_src, host_src.data(), kSizeBytes, cudaMemcpyHostToDevice)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_dst, 0, kSizeBytes)); - - launchTmaCopy1D(d_dst, d_src, kSizeBytes, 0); - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); - - std::vector result(kNumElems); - NVFUSER_CUDA_RT_SAFE_CALL( - cudaMemcpy(result.data(), d_dst, kSizeBytes, cudaMemcpyDeviceToHost)); - - for (int i = 0; i < kNumElems; i++) { - EXPECT_EQ(result[i], host_src[i]) << "Mismatch at index " << i; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_src)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_dst)); -} - -TEST_F(IpcTest, TmaInterDeviceCopy) { - if (communicator_->size() == 1) { - GTEST_SKIP() << "Skipping test for single device"; - } - - const int64_t num_devices = communicator_->size(); - const int64_t rank = communicator_->deviceId(); - const int64_t local_rank = communicator_->local_rank(); - - int major; - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( - &major, cudaDevAttrComputeCapabilityMajor, local_rank)); - if (major < 9) { - GTEST_SKIP() << "Requires Hopper (SM90+)"; - } - - int is_vmm_supported; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( - &is_vmm_supported, - CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, - local_rank)); - if (is_vmm_supported == 0) { - GTEST_SKIP() - << "Device does not support Virtual Memory Management; skipping."; - } - - int is_ipc_supported; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( - &is_ipc_supported, - CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, - local_rank)); - if (is_ipc_supported == 0) { - GTEST_SKIP() << "Device does not support IPC handles; skipping."; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); - - // VMM allocation for the source buffer on each rank - CUmemAllocationProp prop{}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = static_cast(local_rank); - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - size_t granularity = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemGetAllocationGranularity( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - - constexpr size_t kNumElems = 256; - constexpr size_t kSizeBytes = kNumElems * sizeof(uint32_t); - static_assert(kSizeBytes % 16 == 0); - size_t aligned_size = - ((kSizeBytes + granularity - 1) / granularity) * granularity; - - CUmemGenericAllocationHandle mem_handle = 0; - NVFUSER_CUDA_SAFE_CALL( - cuMemCreate(&mem_handle, aligned_size, &prop, /*flags=*/0)); - - CUdeviceptr d_ptr = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( - &d_ptr, aligned_size, /*alignment=*/granularity, /*baseVA=*/0, - /*flags=*/0)); - NVFUSER_CUDA_SAFE_CALL( - cuMemMap(d_ptr, aligned_size, /*offset=*/0, mem_handle, /*flags=*/0)); - - CUmemAccessDesc access{}; - access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access.location.id = static_cast(local_rank); - access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - NVFUSER_CUDA_SAFE_CALL( - cuMemSetAccess(d_ptr, aligned_size, &access, /*count=*/1)); - - // Each rank writes its own pattern - std::vector host_data(kNumElems); - for (size_t i = 0; i < kNumElems; i++) { - host_data[i] = static_cast(rank * 10000 + i); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - reinterpret_cast(d_ptr), - host_data.data(), - kSizeBytes, - cudaMemcpyHostToDevice)); - - // Export handle and exchange via Unix sockets (same pattern as IpcP2pWithVmm) - int shared_fd; - NVFUSER_CUDA_SAFE_CALL(cuMemExportToShareableHandle( - &shared_fd, - mem_handle, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, - /*flags=*/0)); - - std::string my_socket = "@nvfuser_tma_p2p_" + std::to_string(rank); - int listener_fd = nvfuser::createIpcSocket(my_socket); - - communicator_->barrier(); - - const int64_t peer_rank = (rank + 1) % num_devices; - const int64_t receiver_rank = (rank - 1 + num_devices) % num_devices; - std::string receiver_path = - "@nvfuser_tma_p2p_" + std::to_string(receiver_rank); - - nvfuser::sendFd(receiver_path, shared_fd); - int peer_fd = nvfuser::recvFd(listener_fd); - - close(listener_fd); - close(shared_fd); - - // Import and map peer's allocation - CUmemGenericAllocationHandle peer_mem_handle = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemImportFromShareableHandle( - &peer_mem_handle, - (void*)((uint64_t)peer_fd), - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - - CUdeviceptr peer_d_ptr = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( - &peer_d_ptr, aligned_size, /*alignment=*/granularity, /*baseVA=*/0, - /*flags=*/0)); - NVFUSER_CUDA_SAFE_CALL(cuMemMap( - peer_d_ptr, aligned_size, /*offset=*/0, peer_mem_handle, /*flags=*/0)); - - CUmemAccessDesc peer_access{}; - peer_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - peer_access.location.id = static_cast(local_rank); - peer_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READ; - NVFUSER_CUDA_SAFE_CALL( - cuMemSetAccess(peer_d_ptr, aligned_size, &peer_access, /*count=*/1)); - - // Allocate local output buffer and TMA-copy from peer's mapped VA - void* d_output; - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_output, kSizeBytes)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_output, 0, kSizeBytes)); - - communicator_->barrier(); - - launchTmaCopy1D( - d_output, reinterpret_cast(peer_d_ptr), kSizeBytes, 0); - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); - - std::vector result(kNumElems); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - result.data(), d_output, kSizeBytes, cudaMemcpyDeviceToHost)); - - for (size_t i = 0; i < kNumElems; i++) { - uint32_t expected = static_cast(peer_rank * 10000 + i); - EXPECT_EQ(result[i], expected) - << "Rank " << rank << " mismatch at index " << i; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_output)); - NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(peer_d_ptr, aligned_size)); - NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(peer_d_ptr, aligned_size)); - NVFUSER_CUDA_SAFE_CALL(cuMemRelease(peer_mem_handle)); - close(peer_fd); - NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(d_ptr, aligned_size)); - NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(d_ptr, aligned_size)); - NVFUSER_CUDA_SAFE_CALL(cuMemRelease(mem_handle)); -} - -#if (CUDA_VERSION >= 13000) - -TEST_F(IpcTest, TmaMulticastRead) { - if (communicator_->size() == 1) { - GTEST_SKIP() << "Skipping test for single device"; - } - - const int64_t world_size = communicator_->size(); - const int64_t rank = communicator_->deviceId(); - const int64_t local_rank = communicator_->local_rank(); - - constexpr int64_t exporter_rank = 0; - constexpr int64_t root_rank = 1; - - int major; - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( - &major, cudaDevAttrComputeCapabilityMajor, local_rank)); - if (major < 9) { - GTEST_SKIP() << "Requires Hopper (SM90+)"; - } - - int is_vmm_supported; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( - &is_vmm_supported, - CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, - local_rank)); - if (is_vmm_supported == 0) { - GTEST_SKIP() - << "Device does not support Virtual Memory Management; skipping."; - } - - int is_ipc_supported; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( - &is_ipc_supported, - CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, - local_rank)); - if (is_ipc_supported == 0) { - GTEST_SKIP() << "Device does not support IPC handles; skipping."; - } - - int is_multicast_supported; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( - &is_multicast_supported, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, - local_rank)); - if (is_multicast_supported == 0) { - GTEST_SKIP() << "Device does not support Multicast Objects; skipping."; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); - - // Multicast buffer: 2 MB (matches existing NVLS tests) - constexpr size_t kNumElems = 524288; - constexpr size_t kSizeBytes = kNumElems * sizeof(uint32_t); - - using handle_typename = int; - auto handle_type = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - CUmulticastObjectProp mcast_prop{}; - mcast_prop.flags = 0; - mcast_prop.handleTypes = handle_type; - mcast_prop.numDevices = world_size; - mcast_prop.size = kSizeBytes; - - size_t mcast_min_granularity = 0; - NVFUSER_CUDA_SAFE_CALL(cuMulticastGetGranularity( - &mcast_min_granularity, &mcast_prop, CU_MULTICAST_GRANULARITY_MINIMUM)); - if (mcast_min_granularity > kSizeBytes) { - GTEST_SKIP() << "Multicast min granularity (" << mcast_min_granularity - << ") exceeds buffer size; skipping."; - } - - size_t mcast_granularity = 0; - NVFUSER_CUDA_SAFE_CALL(cuMulticastGetGranularity( - &mcast_granularity, &mcast_prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); - if (mcast_granularity > kSizeBytes) { - GTEST_SKIP() << "Multicast recommended granularity (" << mcast_granularity - << ") exceeds buffer size; skipping."; - } - - // Create multicast object on exporter rank and share via Unix sockets - CUmemGenericAllocationHandle mcast_handle{}; - handle_typename shared_handle; - int listener_fd = -1; - - if (rank == exporter_rank) { - NVFUSER_CUDA_SAFE_CALL(cuMulticastCreate(&mcast_handle, &mcast_prop)); - NVFUSER_CUDA_SAFE_CALL(cuMemExportToShareableHandle( - &shared_handle, mcast_handle, handle_type, /*flags=*/0)); - } else { - std::string my_path = - "@nvfuser_tma_mcast_recv_" + std::to_string(rank); - listener_fd = nvfuser::createIpcSocket(my_path); - } - - communicator_->barrier(); - - if (rank != exporter_rank) { - int received_fd = nvfuser::recvFd(listener_fd); - shared_handle = received_fd; - close(listener_fd); - } else { - for (int i = 0; i < world_size; ++i) { - if (i == rank) { - continue; - } - std::string peer_path = - "@nvfuser_tma_mcast_recv_" + std::to_string(i); - nvfuser::sendFd(peer_path, shared_handle); - } - close(shared_handle); - } - - if (rank != exporter_rank) { - NVFUSER_CUDA_SAFE_CALL(cuMemImportFromShareableHandle( - &mcast_handle, (void*)((uint64_t)shared_handle), handle_type)); - close(shared_handle); - } - - CUdevice cu_dev; - NVFUSER_CUDA_SAFE_CALL(cuDeviceGet(&cu_dev, static_cast(local_rank))); - NVFUSER_CUDA_SAFE_CALL(cuMulticastAddDevice(mcast_handle, cu_dev)); - - // Local physical allocation - CUmemAllocationProp prop{}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = static_cast(local_rank); - prop.requestedHandleTypes = handle_type; - - size_t granularity = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemGetAllocationGranularity( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - if (granularity > kSizeBytes) { - GTEST_SKIP() << "Allocation granularity (" << granularity - << ") exceeds buffer size; skipping."; - } - - CUmemGenericAllocationHandle local_buffer = 0; - NVFUSER_CUDA_SAFE_CALL( - cuMemCreate(&local_buffer, kSizeBytes, &prop, /*flags=*/0)); - - NVFUSER_CUDA_SAFE_CALL(cuMulticastBindMem( - mcast_handle, /*mcOffset=*/0, local_buffer, /*memOffset=*/0, - kSizeBytes, /*flags=*/0)); - - // MC (multicast) mapping — used for broadcast writes - CUdeviceptr mc_ptr = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( - &mc_ptr, kSizeBytes, /*alignment=*/mcast_granularity, - /*baseVA=*/0, /*flags=*/0)); - NVFUSER_CUDA_SAFE_CALL( - cuMemMap(mc_ptr, kSizeBytes, /*offset=*/0, mcast_handle, /*flags=*/0)); - CUmemAccessDesc mc_access{}; - mc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - mc_access.location.id = static_cast(local_rank); - mc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - NVFUSER_CUDA_SAFE_CALL( - cuMemSetAccess(mc_ptr, kSizeBytes, &mc_access, /*count=*/1)); - - // UC (unicast) mapping — used for local reads - CUdeviceptr uc_ptr = 0; - NVFUSER_CUDA_SAFE_CALL(cuMemAddressReserve( - &uc_ptr, kSizeBytes, /*alignment=*/granularity, - /*baseVA=*/0, /*flags=*/0)); - NVFUSER_CUDA_SAFE_CALL( - cuMemMap(uc_ptr, kSizeBytes, /*offset=*/0, local_buffer, /*flags=*/0)); - CUmemAccessDesc uc_access{}; - uc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - uc_access.location.id = static_cast(local_rank); - uc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - NVFUSER_CUDA_SAFE_CALL( - cuMemSetAccess(uc_ptr, kSizeBytes, &uc_access, /*count=*/1)); - - // Root broadcasts data via the MC pointer - std::vector host_buffer(kNumElems); - if (rank == root_rank) { - for (size_t i = 0; i < kNumElems; ++i) { - host_buffer[i] = static_cast(i * 3 + 17); - } - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - reinterpret_cast(mc_ptr), - host_buffer.data(), - kSizeBytes, - cudaMemcpyHostToDevice)); - } - - communicator_->barrier(); - - // Use TMA to copy a portion from the UC pointer to a local output buffer. - // The UC pointer maps to local physical memory that received the multicast - // data, so TMA should be able to read from it. - constexpr int kTmaBytes = 4096; - static_assert(kTmaBytes % 16 == 0); - static_assert(kTmaBytes <= kSizeBytes); - constexpr int kTmaElems = kTmaBytes / sizeof(uint32_t); - - void* d_output; - NVFUSER_CUDA_RT_SAFE_CALL(cudaMalloc(&d_output, kTmaBytes)); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(d_output, 0, kTmaBytes)); - - launchTmaCopy1D(d_output, reinterpret_cast(uc_ptr), kTmaBytes, 0); - NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); - - std::vector result(kTmaElems); - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy( - result.data(), d_output, kTmaBytes, cudaMemcpyDeviceToHost)); - - for (int i = 0; i < kTmaElems; ++i) { - uint32_t expected = static_cast(i * 3 + 17); - EXPECT_EQ(result[i], expected) - << "Rank " << rank << " mismatch at index " << i; - } - - NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_output)); - NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(mc_ptr, kSizeBytes)); - NVFUSER_CUDA_SAFE_CALL(cuMemUnmap(uc_ptr, kSizeBytes)); - NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(mc_ptr, kSizeBytes)); - NVFUSER_CUDA_SAFE_CALL(cuMemAddressFree(uc_ptr, kSizeBytes)); - NVFUSER_CUDA_SAFE_CALL(cuMemRelease(local_buffer)); - NVFUSER_CUDA_SAFE_CALL(cuMemRelease(mcast_handle)); -} - -#endif // CUDA_VERSION >= 13000 - } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_tma.cpp b/tests/cpp/test_multidevice_tma.cpp new file mode 100644 index 00000000000..4a12b311b5d --- /dev/null +++ b/tests/cpp/test_multidevice_tma.cpp @@ -0,0 +1,288 @@ +// clang-format off +/* +* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +* All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +*/ +// clang-format on +// +// Unit tests for Hopper TMA (Tensor Memory Accelerator) 1D bulk copy +// (cp.async.bulk) across different memory sources: +// 1. Local device memory (cudaMalloc) +// 2. VMM-mapped peer device memory (inter-device P2P) +// 3. NVLS multicast unicast pointers +// +// The kernel source lives in test_multidevice_tma_kernel.cu and is +// stringified at build time. It is compiled at runtime via NVRTC, +// same pattern as csrc/multidevice/cuda_p2p.cpp. + +#include +#include + +#include +#include + +#include "cuda_utils.h" +#include "driver_api.h" +#include "exceptions.h" +#include "multidevice/symmetric_tensor.h" +#include "multidevice/utils.h" +#include "nvfuser_resources/test_multidevice_tma_kernel.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { + +// ============================================================================ +// NVRTC helper: compile kernel source at runtime, cache the result. +// ============================================================================ + +namespace { + +CUfunction compileAndGetKernel( + CUmodule& module, + CUfunction& function, + const char* source, + const char* source_name, + const char* kernel_name) { + if (function != nullptr) { + return function; + } + + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, source, source_name, 0, nullptr, nullptr)); + + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaGetDeviceProperties(&prop, device)); + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(prop.major) + std::to_string(prop.minor); + std::vector opts = { + arch_arg.c_str(), "--std=c++17"}; + + nvrtcResult res = + nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL( + nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL( + nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR( + false, + "NVRTC compilation of '", + source_name, + "' failed:\n", + log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + NVFUSER_CUDA_SAFE_CALL( + cuModuleLoadData(&module, ptx.data())); + NVFUSER_CUDA_SAFE_CALL( + cuModuleGetFunction(&function, module, kernel_name)); + + return function; +} + +//! Return the NVRTC-compiled tma_copy_1d CUfunction (cached after +//! first call). The kernel uses cp.async.bulk to perform +//! GMEM(src) -> SMEM -> GMEM(dst) +//! and requires dynamic shared memory of num_bytes + 8 (mbarrier). +CUfunction getTmaCopy1dKernel() { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + return compileAndGetKernel( + module, + kernel, + nvfuser_resources::test_multidevice_tma_kernel_cu, + "test_multidevice_tma_kernel.cu", + "tma_copy_1d"); +} + +//! Launch the TMA 1D bulk copy kernel: GMEM(src) -> SMEM -> GMEM(dst). +//! num_bytes must be > 0 and a multiple of 16. +void launchTmaCopy1D( + void* dst, + const void* src, + int num_bytes, + CUstream stream = nullptr) { + NVF_CHECK(num_bytes > 0 && num_bytes % 16 == 0); + CUfunction tma_kernel = getTmaCopy1dKernel(); + int smem_size = num_bytes + static_cast(sizeof(uint64_t)); + void* args[] = {&dst, &src, &num_bytes}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + tma_kernel, + 1, 1, 1, + 32, 1, 1, + smem_size, + stream, + args, + nullptr)); +} + +} // anonymous namespace + +// ============================================================================ +// Tests +// ============================================================================ + +using TmaTest = MultiDeviceTest; + +// Verify TMA 1D bulk copy on local device memory. +// The kernel uses cp.async.bulk (GMEM->SMEM) + cp.async.bulk (SMEM->GMEM) +// with mbarrier synchronization between the two phases. +TEST_F(TmaTest, TmaLocalCopy) { + const int64_t local_rank = communicator_->local_rank(); + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank)); + + constexpr int kNumElems = 256; + constexpr int kSizeBytes = kNumElems * sizeof(uint32_t); + static_assert(kSizeBytes % 16 == 0); + + auto options = at::TensorOptions() + .dtype(at::kInt) + .device(at::kCUDA, local_rank); + at::Tensor src = at::arange(kNumElems, options); + at::Tensor dst = at::zeros({kNumElems}, options); + + launchTmaCopy1D(dst.data_ptr(), src.data_ptr(), kSizeBytes); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + EXPECT_TRUE(dst.equal(src)); +} + +// Verify TMA 1D bulk copy reading from a VMM-mapped peer device +// buffer. SymmetricTensor handles the VMM allocation and IPC handle +// exchange; the test focuses on the TMA transfer itself. +TEST_F(TmaTest, TmaInterDeviceCopy) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t local_rank = communicator_->local_rank(); + const int64_t world_size = communicator_->size(); + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + constexpr int kNumElems = 256; + constexpr int kSizeBytes = kNumElems * sizeof(int32_t); + static_assert(kSizeBytes % 16 == 0); + + at::Tensor local = SymmetricTensor::allocate( + {kNumElems}, at::kInt, communicator_->device()); + local.fill_(static_cast(rank * 10000)); + SymmetricTensor sym(local); + sym.setupRemoteHandles("tma_p2p"); + + const int64_t peer_rank = (rank + 1) % world_size; + at::Tensor peer = sym.remoteTensor(peer_rank); + + at::Tensor output = at::zeros( + {kNumElems}, + at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank)); + + launchTmaCopy1D(output.data_ptr(), peer.data_ptr(), kSizeBytes); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + + at::Tensor expected = at::full( + {kNumElems}, + static_cast(peer_rank * 10000), + at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank)); + EXPECT_TRUE(output.equal(expected)) + << "Rank " << rank << " TMA read from peer " << peer_rank + << " returned wrong data"; +} + +#if (CUDA_VERSION >= 13000) + +// Verify TMA 1D bulk copy writing TO an NVLS multicast pointer. +// Root uses TMA to write data to the MC pointer, which broadcasts +// via NVLS hardware. All ranks then verify the data arrived by +// reading from their local UC view with a normal copy. +TEST_F(TmaTest, TmaMulticastWrite) { + if (communicator_->size() == 1) { + GTEST_SKIP() << "Skipping test for single device"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t local_rank = communicator_->local_rank(); + + int major; + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, local_rank)); + if (major < 9) { + GTEST_SKIP() << "Requires Hopper (SM90+)"; + } + + int is_multicast_supported; + NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute( + &is_multicast_supported, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + local_rank)); + if (is_multicast_supported == 0) { + GTEST_SKIP() + << "Device does not support Multicast Objects; skipping."; + } + + constexpr int64_t kNumElems = 524288; // 2 MB / sizeof(int32_t) + constexpr int64_t root = 0; + + // cp.async.bulk transfer size is limited by shared memory, + // so we broadcast a 4 KB slice via TMA. + constexpr int kTmaBytes = 4096; + static_assert(kTmaBytes % 16 == 0); + constexpr int kTmaElems = kTmaBytes / sizeof(int32_t); + + at::Tensor local = SymmetricTensor::allocate( + {kNumElems}, at::kInt, communicator_->device()); + local.zero_(); + SymmetricTensor sym(local); + sym.setupMulticast(root, "tma_mcast"); + + auto opts = + at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank); + + // Root: TMA-write source data to MC pointer (NVLS broadcasts it) + if (rank == root) { + at::Tensor src = at::arange(kTmaElems, opts); + launchTmaCopy1D(sym.multicastPtr(), src.data_ptr(), kTmaBytes); + NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize()); + } + + communicator_->barrier(); + + // All ranks: verify data arrived via normal read of local UC tensor + at::Tensor readback = sym.localTensor().slice(0, 0, kTmaElems).clone(); + at::Tensor expected = at::arange(kTmaElems, opts); + EXPECT_TRUE(readback.equal(expected)) + << "Rank " << rank + << " did not receive multicast data written by TMA"; +} + +#endif // CUDA_VERSION >= 13000 + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_tma_kernel.cu b/tests/cpp/test_multidevice_tma_kernel.cu new file mode 100644 index 00000000000..d923d626fa7 --- /dev/null +++ b/tests/cpp/test_multidevice_tma_kernel.cu @@ -0,0 +1,92 @@ +// clang-format off +/* +* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +* All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +*/ +// clang-format on +// +// TMA 1D bulk copy kernel (SM90+ / Hopper). +// +// This file is the implements the TMA kernel. Like other kernels, the build +// system stringifies it into nvfuser_resources/test_multidevice_tma_kernel.h +// (a const char*), which test_multidevice_tma.cpp compiles at runtime +// via NVRTC. The file is never compiled statically by nvcc. +// +// A single elected thread (thread 0) performs the full round-trip: +// 1. mbarrier.init (arrival count = 1) +// 2. mbarrier.arrive.expect_tx (announce expected bytes) +// 3. cp.async.bulk GMEM -> SMEM (TMA load, completed via mbarrier) +// 4. mbarrier.try_wait.parity (block until load completes) +// 5. cp.async.bulk SMEM -> GMEM (TMA store) +// 6. cp.async.bulk.commit_group + wait_group.read 0 +// +// Dynamic shared memory layout (128-byte aligned): +// [0, num_bytes) : staging buffer +// [num_bytes, num_bytes+8) : mbarrier (uint64_t) + +extern "C" __global__ void __launch_bounds__(32, 1) + tma_copy_1d( + void* __restrict__ dst, + const void* __restrict__ src, + int num_bytes) { + extern __shared__ __align__(128) unsigned char smem[]; + + unsigned long long* mbar = + reinterpret_cast(smem + num_bytes); + unsigned int smem_addr = + static_cast(__cvta_generic_to_shared(smem)); + unsigned int mbar_addr = + static_cast(__cvta_generic_to_shared(mbar)); + + if (threadIdx.x == 0) { + asm volatile( + "mbarrier.init.shared::cta.b64 [%0], %1;" + ::"r"(mbar_addr), "r"(1)); + asm volatile( + "fence.mbarrier_init.release.cluster;" :::); + } + __syncwarp(); + + if (threadIdx.x == 0) { + // Announce expected transaction bytes on the mbarrier + asm volatile( + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" + ::"r"(mbar_addr), "r"(num_bytes)); + + // TMA Load: GMEM -> SMEM (async, completed via mbarrier) + asm volatile( + "cp.async.bulk.shared::cluster.global" + ".mbarrier::complete_tx::bytes" + " [%0], [%1], %2, [%3];\n" + ::"r"(smem_addr), "l"(src), "r"(num_bytes), "r"(mbar_addr) + : "memory"); + + // Block until the mbarrier phase flips (TMA load completed) + asm volatile( + "{\n" + ".reg .pred P1;\n" + "TMA_COPY_WAIT_LOAD:\n" + "mbarrier.try_wait.parity.shared::cta.b64" + " P1, [%0], %1;\n" + "@P1 bra TMA_COPY_LOAD_DONE;\n" + "bra TMA_COPY_WAIT_LOAD;\n" + "TMA_COPY_LOAD_DONE:\n" + "}" + ::"r"(mbar_addr), "r"(0)); + + // TMA Store: SMEM -> GMEM + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group" + " [%0], [%1], %2;\n" + ::"l"(dst), "r"(smem_addr), "r"(num_bytes) + : "memory"); + asm volatile("cp.async.bulk.commit_group;"); + asm volatile( + "cp.async.bulk.wait_group.read 0;" ::: "memory"); + + asm volatile( + "mbarrier.inval.shared::cta.b64 [%0];" + ::"r"(mbar_addr)); + } +} diff --git a/tests/cpp/tma_test_kernels.cu b/tests/cpp/tma_test_kernels.cu deleted file mode 100644 index 5ca8f6c917d..00000000000 --- a/tests/cpp/tma_test_kernels.cu +++ /dev/null @@ -1,118 +0,0 @@ -// clang-format off -/* -* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -* All rights reserved. -* SPDX-License-Identifier: BSD-3-Clause -*/ -// clang-format on - -#include "tests/cpp/tma_test_kernels.h" - -#include -#include - -namespace nvfuser { - -// TMA 1D bulk copy kernel: GMEM(src) -> SMEM -> GMEM(dst). -// Inspired by DeepEP's tma_load_1d / tma_store_1d pattern. -// A single elected thread issues all TMA operations while the rest of the warp -// idles. mbarrier synchronization ensures the async TMA load completes before -// the TMA store reads from shared memory. -// -// Dynamic shared memory layout (128-byte aligned): -// [0, num_bytes) : data staging buffer -// [num_bytes, num_bytes+8) : mbarrier (uint64_t, 16-byte aligned since -// num_bytes is a multiple of 16) -__global__ void __launch_bounds__(32, 1) - tma_copy_1d_kernel( - void* __restrict__ dst, - const void* __restrict__ src, - int num_bytes) { - extern __shared__ __align__(128) uint8_t smem[]; - - auto* mbar = reinterpret_cast(smem + num_bytes); - auto smem_addr = - static_cast(__cvta_generic_to_shared(smem)); - auto mbar_addr = - static_cast(__cvta_generic_to_shared(mbar)); - - if (threadIdx.x == 0) { - // Initialize mbarrier with arrival count = 1 - asm volatile( - "mbarrier.init.shared::cta.b64 [%0], %1;" - ::"r"(mbar_addr), "r"(1)); - // Ensure init is visible cluster-wide before any use - asm volatile( - "fence.mbarrier_init.release.cluster;" :::); - } - __syncwarp(); - - if (threadIdx.x == 0) { - // Announce expected number of transaction bytes on the mbarrier - asm volatile( - "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" - ::"r"(mbar_addr), "r"(num_bytes)); - - // TMA Load: GMEM -> SMEM (async, completed via mbarrier) - asm volatile( - "cp.async.bulk.shared::cluster.global" - ".mbarrier::complete_tx::bytes" - " [%0], [%1], %2, [%3];\n" - ::"r"(smem_addr), - "l"(src), - "r"(num_bytes), - "r"(mbar_addr) - : "memory"); - - // Block until the mbarrier phase flips (TMA load completed). - // Phase 0 is the initial phase after mbarrier.init. - asm volatile( - "{\n" - ".reg .pred P1;\n" - "TMA_COPY_WAIT_LOAD:\n" - "mbarrier.try_wait.parity.shared::cta.b64" - " P1, [%0], %1;\n" - "@P1 bra TMA_COPY_LOAD_DONE;\n" - "bra TMA_COPY_WAIT_LOAD;\n" - "TMA_COPY_LOAD_DONE:\n" - "}" - ::"r"(mbar_addr), "r"(0)); - - // TMA Store: SMEM -> GMEM - // No fence.proxy.async needed here because both the load and store - // operate through the async proxy; the mbarrier completion already - // establishes the necessary ordering (cf. DeepEP intranode.cu). - asm volatile( - "cp.async.bulk.global.shared::cta.bulk_group" - " [%0], [%1], %2;\n" - ::"l"(dst), - "r"(smem_addr), - "r"(num_bytes) - : "memory"); - asm volatile("cp.async.bulk.commit_group;"); - asm volatile( - "cp.async.bulk.wait_group.read 0;" ::: "memory"); - - // Invalidate mbarrier before kernel exit - asm volatile( - "mbarrier.inval.shared::cta.b64 [%0];" - ::"r"(mbar_addr)); - } -} - -void launchTmaCopy1D( - void* dst, - const void* src, - int num_bytes, - cudaStream_t stream) { - assert(num_bytes > 0 && "num_bytes must be positive"); - assert( - num_bytes % 16 == 0 && - "cp.async.bulk requires size to be a multiple of 16 bytes"); - - // data buffer + mbarrier (8 bytes) - int smem_size = num_bytes + static_cast(sizeof(uint64_t)); - tma_copy_1d_kernel<<<1, 32, smem_size, stream>>>(dst, src, num_bytes); -} - -} // namespace nvfuser diff --git a/tests/cpp/tma_test_kernels.h b/tests/cpp/tma_test_kernels.h deleted file mode 100644 index 8921decedae..00000000000 --- a/tests/cpp/tma_test_kernels.h +++ /dev/null @@ -1,24 +0,0 @@ -// clang-format off -/* -* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -* All rights reserved. -* SPDX-License-Identifier: BSD-3-Clause -*/ -// clang-format on -#pragma once - -#include - -namespace nvfuser { - -//! Copies num_bytes from src (GMEM) to dst (GMEM) via TMA 1D bulk copy: -//! GMEM(src) -> SMEM -> GMEM(dst) -//! Uses cp.async.bulk with mbarrier synchronization (SM90+ / Hopper). -//! num_bytes must be a multiple of 16 and > 0. -void launchTmaCopy1D( - void* dst, - const void* src, - int num_bytes, - cudaStream_t stream); - -} // namespace nvfuser From a852aad80f51549413f8e1036a6b0bc93ab577b2 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 24 Feb 2026 11:28:49 -0800 Subject: [PATCH 3/5] add comment --- tests/cpp/test_multidevice_tma_kernel.cu | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/cpp/test_multidevice_tma_kernel.cu b/tests/cpp/test_multidevice_tma_kernel.cu index d923d626fa7..125cd952171 100644 --- a/tests/cpp/test_multidevice_tma_kernel.cu +++ b/tests/cpp/test_multidevice_tma_kernel.cu @@ -8,17 +8,23 @@ // // TMA 1D bulk copy kernel (SM90+ / Hopper). // -// This file is the implements the TMA kernel. Like other kernels, the build -// system stringifies it into nvfuser_resources/test_multidevice_tma_kernel.h -// (a const char*), which test_multidevice_tma.cpp compiles at runtime -// via NVRTC. The file is never compiled statically by nvcc. +// This file implements the TMA kernel. The build system stringifies it +// into nvfuser_resources/test_multidevice_tma_kernel.h (a const char*), +// which test_multidevice_tma.cpp compiles at runtime via NVRTC. The +// file is never compiled statically by nvcc. // -// A single elected thread (thread 0) performs the full round-trip: +// TMA (cp.async.bulk) is a GMEM<->SMEM transfer engine — there is no +// GMEM-to-GMEM variant. Shared memory staging is inherent to the +// hardware, so the kernel performs a two-phase copy: +// +// GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst) +// +// A single elected thread (thread 0) drives both phases: // 1. mbarrier.init (arrival count = 1) // 2. mbarrier.arrive.expect_tx (announce expected bytes) -// 3. cp.async.bulk GMEM -> SMEM (TMA load, completed via mbarrier) +// 3. cp.async.bulk.shared::cluster.global (TMA load, async) // 4. mbarrier.try_wait.parity (block until load completes) -// 5. cp.async.bulk SMEM -> GMEM (TMA store) +// 5. cp.async.bulk.global.shared::cta (TMA store) // 6. cp.async.bulk.commit_group + wait_group.read 0 // // Dynamic shared memory layout (128-byte aligned): From c32c13087bc4dced1ab985c2c97d8c4d9d6aea6b Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 24 Feb 2026 11:34:54 -0800 Subject: [PATCH 4/5] rename kernel file --- CMakeLists.txt | 5 ++--- .../multidevice/tma_copy.cu | 13 +++++++++---- tests/cpp/test_multidevice_tma.cpp | 8 ++++---- 3 files changed, 15 insertions(+), 11 deletions(-) rename tests/cpp/test_multidevice_tma_kernel.cu => csrc/multidevice/tma_copy.cu (86%) diff --git a/CMakeLists.txt b/CMakeLists.txt index eaa5d06b7c2..644f8bb06ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1032,8 +1032,7 @@ if(BUILD_TEST) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") target_include_directories(test_multidevice PRIVATE "${CMAKE_BINARY_DIR}/include") - add_dependencies(test_multidevice - nvfuser_rt_test_multidevice_tma_kernel) + add_dependencies(test_multidevice nvfuser_rt_tma_copy) list(APPEND TEST_BINARIES test_multidevice) set(MULTIDEVICE_TUTORIAL_SRCS) @@ -1269,7 +1268,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu - ${NVFUSER_ROOT}/tests/cpp/test_multidevice_tma_kernel.cu + ${NVFUSER_SRCS_DIR}/multidevice/tma_copy.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/tests/cpp/test_multidevice_tma_kernel.cu b/csrc/multidevice/tma_copy.cu similarity index 86% rename from tests/cpp/test_multidevice_tma_kernel.cu rename to csrc/multidevice/tma_copy.cu index 125cd952171..382fcc1f2a7 100644 --- a/tests/cpp/test_multidevice_tma_kernel.cu +++ b/csrc/multidevice/tma_copy.cu @@ -8,10 +8,15 @@ // // TMA 1D bulk copy kernel (SM90+ / Hopper). // -// This file implements the TMA kernel. The build system stringifies it -// into nvfuser_resources/test_multidevice_tma_kernel.h (a const char*), -// which test_multidevice_tma.cpp compiles at runtime via NVRTC. The -// file is never compiled statically by nvcc. +// This file implements a TMA-based data copy kernel. The build system +// stringifies it into nvfuser_resources/tma_copy.h (a const char*), +// which is compiled at runtime via NVRTC. The file is never compiled +// statically by nvcc. +// +// Currently used by tests (test_multidevice_tma.cpp). In a future PR +// this kernel will be integrated as a P2P and multicast transport +// alongside the existing SM-based and copy-engine transports in +// csrc/multidevice/cuda_p2p.cpp. // // TMA (cp.async.bulk) is a GMEM<->SMEM transfer engine — there is no // GMEM-to-GMEM variant. Shared memory staging is inherent to the diff --git a/tests/cpp/test_multidevice_tma.cpp b/tests/cpp/test_multidevice_tma.cpp index 4a12b311b5d..1aa5f0af2a5 100644 --- a/tests/cpp/test_multidevice_tma.cpp +++ b/tests/cpp/test_multidevice_tma.cpp @@ -12,7 +12,7 @@ // 2. VMM-mapped peer device memory (inter-device P2P) // 3. NVLS multicast unicast pointers // -// The kernel source lives in test_multidevice_tma_kernel.cu and is +// The kernel source lives in csrc/multidevice/tma_copy.cu and is // stringified at build time. It is compiled at runtime via NVRTC, // same pattern as csrc/multidevice/cuda_p2p.cpp. @@ -27,7 +27,7 @@ #include "exceptions.h" #include "multidevice/symmetric_tensor.h" #include "multidevice/utils.h" -#include "nvfuser_resources/test_multidevice_tma_kernel.h" +#include "nvfuser_resources/tma_copy.h" #include "tests/cpp/multidevice.h" namespace nvfuser { @@ -104,8 +104,8 @@ CUfunction getTmaCopy1dKernel() { return compileAndGetKernel( module, kernel, - nvfuser_resources::test_multidevice_tma_kernel_cu, - "test_multidevice_tma_kernel.cu", + nvfuser_resources::tma_copy_cu, + "tma_copy.cu", "tma_copy_1d"); } From ae0c760718f5ad7436e4055714a221fcd50bb68a Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 25 Feb 2026 02:09:47 -0800 Subject: [PATCH 5/5] lint --- csrc/multidevice/tma_copy.cu | 42 +++++++++++------------ tests/cpp/test_multidevice_tma.cpp | 55 +++++++++++------------------- 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/csrc/multidevice/tma_copy.cu b/csrc/multidevice/tma_copy.cu index 382fcc1f2a7..23176283e4e 100644 --- a/csrc/multidevice/tma_copy.cu +++ b/csrc/multidevice/tma_copy.cu @@ -36,11 +36,10 @@ // [0, num_bytes) : staging buffer // [num_bytes, num_bytes+8) : mbarrier (uint64_t) -extern "C" __global__ void __launch_bounds__(32, 1) - tma_copy_1d( - void* __restrict__ dst, - const void* __restrict__ src, - int num_bytes) { +extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d( + void* __restrict__ dst, + const void* __restrict__ src, + int num_bytes) { extern __shared__ __align__(128) unsigned char smem[]; unsigned long long* mbar = @@ -52,25 +51,26 @@ extern "C" __global__ void __launch_bounds__(32, 1) if (threadIdx.x == 0) { asm volatile( - "mbarrier.init.shared::cta.b64 [%0], %1;" - ::"r"(mbar_addr), "r"(1)); - asm volatile( - "fence.mbarrier_init.release.cluster;" :::); + "mbarrier.init.shared::cta.b64 [%0], %1;" ::"r"(mbar_addr), "r"(1)); + asm volatile("fence.mbarrier_init.release.cluster;" :::); } __syncwarp(); if (threadIdx.x == 0) { // Announce expected transaction bytes on the mbarrier asm volatile( - "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" - ::"r"(mbar_addr), "r"(num_bytes)); + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"( + mbar_addr), + "r"(num_bytes)); // TMA Load: GMEM -> SMEM (async, completed via mbarrier) asm volatile( "cp.async.bulk.shared::cluster.global" ".mbarrier::complete_tx::bytes" - " [%0], [%1], %2, [%3];\n" - ::"r"(smem_addr), "l"(src), "r"(num_bytes), "r"(mbar_addr) + " [%0], [%1], %2, [%3];\n" ::"r"(smem_addr), + "l"(src), + "r"(num_bytes), + "r"(mbar_addr) : "memory"); // Block until the mbarrier phase flips (TMA load completed) @@ -83,21 +83,19 @@ extern "C" __global__ void __launch_bounds__(32, 1) "@P1 bra TMA_COPY_LOAD_DONE;\n" "bra TMA_COPY_WAIT_LOAD;\n" "TMA_COPY_LOAD_DONE:\n" - "}" - ::"r"(mbar_addr), "r"(0)); + "}" ::"r"(mbar_addr), + "r"(0)); // TMA Store: SMEM -> GMEM asm volatile( "cp.async.bulk.global.shared::cta.bulk_group" - " [%0], [%1], %2;\n" - ::"l"(dst), "r"(smem_addr), "r"(num_bytes) + " [%0], [%1], %2;\n" ::"l"(dst), + "r"(smem_addr), + "r"(num_bytes) : "memory"); asm volatile("cp.async.bulk.commit_group;"); - asm volatile( - "cp.async.bulk.wait_group.read 0;" ::: "memory"); + asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory"); - asm volatile( - "mbarrier.inval.shared::cta.b64 [%0];" - ::"r"(mbar_addr)); + asm volatile("mbarrier.inval.shared::cta.b64 [%0];" ::"r"(mbar_addr)); } } diff --git a/tests/cpp/test_multidevice_tma.cpp b/tests/cpp/test_multidevice_tma.cpp index 1aa5f0af2a5..0a5363eae3f 100644 --- a/tests/cpp/test_multidevice_tma.cpp +++ b/tests/cpp/test_multidevice_tma.cpp @@ -49,29 +49,24 @@ CUfunction compileAndGetKernel( } nvrtcProgram prog; - NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( - &prog, source, source_name, 0, nullptr, nullptr)); + NVFUSER_NVRTC_SAFE_CALL( + nvrtcCreateProgram(&prog, source, source_name, 0, nullptr, nullptr)); int device = 0; NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); cudaDeviceProp prop; - NVFUSER_CUDA_RT_SAFE_CALL( - cudaGetDeviceProperties(&prop, device)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); std::string arch_arg = "--gpu-architecture=compute_" + std::to_string(prop.major) + std::to_string(prop.minor); - std::vector opts = { - arch_arg.c_str(), "--std=c++17"}; + std::vector opts = {arch_arg.c_str(), "--std=c++17"}; - nvrtcResult res = - nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); if (res != NVRTC_SUCCESS) { size_t logSize; - NVFUSER_NVRTC_SAFE_CALL( - nvrtcGetProgramLogSize(prog, &logSize)); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); std::vector log(logSize); - NVFUSER_NVRTC_SAFE_CALL( - nvrtcGetProgramLog(prog, log.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); NVF_ERROR( false, "NVRTC compilation of '", @@ -86,10 +81,8 @@ CUfunction compileAndGetKernel( NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); - NVFUSER_CUDA_SAFE_CALL( - cuModuleLoadData(&module, ptx.data())); - NVFUSER_CUDA_SAFE_CALL( - cuModuleGetFunction(&function, module, kernel_name)); + NVFUSER_CUDA_SAFE_CALL(cuModuleLoadData(&module, ptx.data())); + NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(&function, module, kernel_name)); return function; } @@ -121,13 +114,7 @@ void launchTmaCopy1D( int smem_size = num_bytes + static_cast(sizeof(uint64_t)); void* args[] = {&dst, &src, &num_bytes}; NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( - tma_kernel, - 1, 1, 1, - 32, 1, 1, - smem_size, - stream, - args, - nullptr)); + tma_kernel, 1, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr)); } } // anonymous namespace @@ -157,9 +144,8 @@ TEST_F(TmaTest, TmaLocalCopy) { constexpr int kSizeBytes = kNumElems * sizeof(uint32_t); static_assert(kSizeBytes % 16 == 0); - auto options = at::TensorOptions() - .dtype(at::kInt) - .device(at::kCUDA, local_rank); + auto options = + at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank); at::Tensor src = at::arange(kNumElems, options); at::Tensor dst = at::zeros({kNumElems}, options); @@ -192,8 +178,8 @@ TEST_F(TmaTest, TmaInterDeviceCopy) { constexpr int kSizeBytes = kNumElems * sizeof(int32_t); static_assert(kSizeBytes % 16 == 0); - at::Tensor local = SymmetricTensor::allocate( - {kNumElems}, at::kInt, communicator_->device()); + at::Tensor local = + SymmetricTensor::allocate({kNumElems}, at::kInt, communicator_->device()); local.fill_(static_cast(rank * 10000)); SymmetricTensor sym(local); sym.setupRemoteHandles("tma_p2p"); @@ -244,8 +230,7 @@ TEST_F(TmaTest, TmaMulticastWrite) { CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, local_rank)); if (is_multicast_supported == 0) { - GTEST_SKIP() - << "Device does not support Multicast Objects; skipping."; + GTEST_SKIP() << "Device does not support Multicast Objects; skipping."; } constexpr int64_t kNumElems = 524288; // 2 MB / sizeof(int32_t) @@ -257,14 +242,13 @@ TEST_F(TmaTest, TmaMulticastWrite) { static_assert(kTmaBytes % 16 == 0); constexpr int kTmaElems = kTmaBytes / sizeof(int32_t); - at::Tensor local = SymmetricTensor::allocate( - {kNumElems}, at::kInt, communicator_->device()); + at::Tensor local = + SymmetricTensor::allocate({kNumElems}, at::kInt, communicator_->device()); local.zero_(); SymmetricTensor sym(local); sym.setupMulticast(root, "tma_mcast"); - auto opts = - at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank); + auto opts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank); // Root: TMA-write source data to MC pointer (NVLS broadcasts it) if (rank == root) { @@ -279,8 +263,7 @@ TEST_F(TmaTest, TmaMulticastWrite) { at::Tensor readback = sym.localTensor().slice(0, 0, kTmaElems).clone(); at::Tensor expected = at::arange(kTmaElems, opts); EXPECT_TRUE(readback.equal(expected)) - << "Rank " << rank - << " did not receive multicast data written by TMA"; + << "Rank " << rank << " did not receive multicast data written by TMA"; } #endif // CUDA_VERSION >= 13000