From 9d493948e839affcc18f1956f60d38038a14df28 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Mon, 8 Dec 2025 18:28:09 +0000 Subject: [PATCH 01/13] All-gather and updated all-to-all. --- .gitignore | 17 +- benchmark/ccl/all_gather/benchmark.py | 347 ++++++++++++++++++++++++++ benchmark/ccl/all_to_all/benchmark.py | 11 +- iris/ccl/all_gather.py | 213 ++++++++++++++++ iris/ccl/all_to_all.py | 107 ++++---- iris/experimental/iris_gluon.py | 30 +++ tests/ccl/test_all_gather.py | 89 +++++++ 7 files changed, 756 insertions(+), 58 deletions(-) create mode 100644 benchmark/ccl/all_gather/benchmark.py create mode 100644 iris/ccl/all_gather.py create mode 100644 tests/ccl/test_all_gather.py diff --git a/.gitignore b/.gitignore index 34f9a2a5..6d8f13f3 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,19 @@ slurm*.out examples/gemm/results/* asm/ -*.img \ No newline at end of file +*.img + +.cache/ +.local/ +.triton/ +.pytest_cache/ +.ruff_cache/ +__pycache__/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw +*.pyzwz \ No newline at end of file diff --git a/benchmark/ccl/all_gather/benchmark.py b/benchmark/ccl/all_gather/benchmark.py new file mode 100644 index 00000000..8e551e50 --- /dev/null +++ b/benchmark/ccl/all_gather/benchmark.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris-ccl all-gather collective operation. + +This benchmark showcases the all-gather collective and reports achieved bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ccl import Config + +# Conditional import for Gluon +try: + import iris.experimental.iris_gluon as iris_gluon + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all-gather collective operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in input tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in input tensors") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-gather kernel") + parser.add_argument("--benchmark_rccl", action="store_true", help="Also benchmark PyTorch RCCL (all_gather_into_tensor) for comparison") + parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") + parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") + parser.add_argument("--num_xcds", type=int, 
default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Use Gluon if requested and available + if args.get("use_gluon", False): + if not GLUON_AVAILABLE: + raise RuntimeError("Gluon is not available. Install Triton with Gluon support or remove --use_gluon flag") + shmem = iris_gluon.iris(args["heap_size"]) + else: + shmem = iris.iris(args["heap_size"]) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + + # Create config with optional block size parameters + config_kwargs = {"comm_sms": args["comm_sms"]} + if args["block_size_m"] is not None: + config_kwargs["block_size_m"] = args["block_size_m"] + if args["block_size_n"] is not None: + config_kwargs["block_size_n"] = args["block_size_n"] + if args["swizzle_size"] is not None: + config_kwargs["swizzle_size"] = args["swizzle_size"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + if args.get("use_gluon", False): + config_kwargs["use_gluon"] = True + + config = Config(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export config values to JSON (use actual values from config, including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("swizzle_size", config.swizzle_size) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("use_gluon", config.use_gluon) + + # Create input and output tensors for all-gather + # Input: each rank has (M, N) tensor + # Output: (world_size * M, N) - concatenated along dimension 0 + # Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.put() compatibility + input_tensor = shmem.zeros((M, N), dtype=datatype) + output_tensor = shmem.zeros((world_size * M, N), dtype=datatype) + expected_tensor = shmem.zeros((world_size * M, N), dtype=datatype) + + # Fill input with deterministic values + val = float(rank + 1) + input_tensor.fill_(val) + + # Expected output: each rank's input appears at output[rank * M : (rank + 1) * M, :] + for r in range(world_size): + expected_val = float(r + 1) + expected_tensor[r * M : (r + 1) * M, :] = expected_val + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + shmem.barrier() + + 
torch.cuda.nvtx.range_push("All-Gather") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather"]["start_event"].record() + shmem.ccl.all_gather(output_tensor, input_tensor, config=config, async_op=False) + kernel_timing["all_gather"]["end_event"].record() + kernel_timing["all_gather"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather"]["start_event"].elapsed_time(kernel_timing["all_gather"]["end_event"]) + kernel_timing["all_gather"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-3 if datatype == torch.float16 else 1e-5 + success = torch.allclose(output_tensor, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(output_tensor - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather validation passed!") + else: + shmem.error("All-gather validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + run_experiment() + shmem.barrier() + + for k in ["all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate bandwidth + # In all-gather, each rank sends its (M, N) tensor to all ranks + # Total bytes sent = (world_size - 1) * M * N * element_size (excluding local copy) + # Total bytes received = (world_size - 1) * M * N * element_size + # Total bytes = (world_size - 1) * M * N * element_size + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather"]["ms"] / kernel_timing["all_gather"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_ms", kernel_timing["all_gather"]["ms"] / kernel_timing["all_gather"]["experiments"] + ) + json_writer.add_field("all_gather_experiments", kernel_timing["all_gather"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark RCCL (PyTorch all_gather_into_tensor) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_gather_into_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_input.fill_(float(rank + 1)) + pytorch_output = torch.zeros(world_size * M, N, dtype=datatype, 
device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_output, pytorch_input) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_output.zero_() + pytorch_input.fill_(float(rank + 1)) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_gather_into_tensor(pytorch_output, pytorch_input) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL all_gather_into_tensor (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29503" + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() + diff --git a/benchmark/ccl/all_to_all/benchmark.py b/benchmark/ccl/all_to_all/benchmark.py index b9af8689..a1b570dd 100644 --- a/benchmark/ccl/all_to_all/benchmark.py +++ b/benchmark/ccl/all_to_all/benchmark.py @@ -55,7 +55,7 @@ def parse_args(): help="Output file", ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") - parser.add_argument("--comm_sms", type=int, default=32, help="Number of SMs for all-to-all kernel") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-to-all kernel") parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") @@ -140,7 +140,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): output_concat = shmem.zeros((M, N * world_size), dtype=datatype) expected_concat = shmem.zeros((M, N * world_size), dtype=datatype) - for target_rank in range(world_size): + # Determine which ranks to communicate with + comm_ranks = list(range(world_size)) + + for target_rank in comm_ranks: # Input: rank sends data at position (target_rank * N) val = float(rank * 1000 + target_rank) input_concat[:, target_rank * N : (target_rank + 1) * N] = val @@ -190,7 +193,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - for target_rank in range(world_size): + for target_rank in comm_ranks: val = float(rank * 1000 + target_rank) 
input_concat[:, target_rank * N : (target_rank + 1) * N] = val shmem.barrier() @@ -229,7 +232,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - for target_rank in range(world_size): + for target_rank in comm_ranks: val = float(rank * 1000 + target_rank) input_concat[:, target_rank * N : (target_rank + 1) * N] = val shmem.barrier() diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py new file mode 100644 index 00000000..7a3d200a --- /dev/null +++ b/iris/ccl/all_gather.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +All-gather collective communication primitive for Iris. +Gathers tensors from all ranks and concatenates them along the last dimension. +""" + +import triton +import triton.language as tl +import torch +import iris +from .config import Config + + +@triton.jit() +def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr): + if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size): + return pid + + local_pid = pid // num_xcds + chunk_idx = local_pid // chunk_size + pos_in_chunk = local_pid % chunk_size + + xcd = pid % num_xcds + new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk + return new_pid + + +@triton.jit() +def persistent_all_gather( + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + """ + Persistent all-gather kernel. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. 
+ + Args: + input_ptr: Pointer to input tensor (local rank's data to send) of shape (M, N) + output_ptr: Pointer to output tensor (will receive from all ranks) of shape (world_size * M, N) + M: Number of rows per rank (output will be world_size * M rows) + N: Number of columns + stride_in_m, stride_in_n: Strides for input tensor + stride_out_m, stride_out_n: Strides for output tensor + heap_bases: Heap base pointers for all ranks + cur_rank: Current rank + world_size: Total number of ranks + BLOCK_SIZE_M, BLOCK_SIZE_N: Block sizes for tiling + GROUP_SIZE_M: Group size for M dimension tiling + COMM_SMS: Number of SMs for communication + NUM_XCDS: Number of XCDs + CHUNK_SIZE: Chunk size for chiplet transform + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + for tile_id in range(pid, total_tiles, COMM_SMS): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + # Compute row and column indices for input tensor + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + rm_input = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + input_mask = (rm_input[:, None] < M) & (rn[None, :] < N) + + # Pre-compute base offsets for input + input_base_m = rm_input[:, None] * stride_in_m + input_base_n = rn[None, :] * stride_in_n + + # Process all ranks + # For each rank, copy its input chunk to the corresponding output location + # on all ranks (including the source rank itself) + # Output concatenates along dimension 0: output[source_rank * M : (source_rank + 1) * M, :] + for source_rank in range(world_size): + # Compute output row indices: offset by source_rank * M + rm_output = rm_input + source_rank * M + # Output mask: check bounds for output tensor (world_size * M rows, N cols) + output_mask = (rm_output[:, None] < (world_size * M)) & (rn[None, :] < N) + + # Input offset: read from source_rank's input tensor + input_offset = input_base_m + input_base_n + input_ptr_source = input_ptr + input_offset + input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Output offset: write to output at rows [source_rank * M : (source_rank + 1) * M] + # This is the same location on all ranks + output_base_m = rm_output[:, None] * stride_out_m + output_base_n = rn[None, :] * stride_out_n + output_offset = output_base_m + output_base_n + output_ptr_target = output_ptr + output_offset + output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Combine masks: must be valid in both input and output + combined_mask = input_mask & output_mask + + if source_rank == cur_rank: + # Local copy: use direct load/store + data = tl.load(input_ptr_source, mask=combined_mask) + tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") + else: + # Remote copy: use iris.load to read from source_rank, then store locally + # Note: iris.put reads from local memory, 
so we can't use it for remote reads + data = iris.load( + input_ptr_source, + cur_rank, + source_rank, + heap_bases, + mask=combined_mask, + ) + tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") + + +def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): + """ + Internal all-gather collective operation implementation. + + This function is called internally by shmem.ccl.all_gather(). + Users should use the Iris instance method instead: + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + shmem: Iris shmem context + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + """ + # Use provided config or create default one + if config is None: + config = Config() + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + M, N = input_tensor.shape[:2] + expected_output_shape = (world_size * M, N) + + if output_tensor.shape[:2] != expected_output_shape: + raise ValueError( + f"Output tensor shape {output_tensor.shape[:2]} does not match expected shape {expected_output_shape}. " + f"Expected (world_size * M, N) = ({world_size * M}, {N})" + ) + + stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) + stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) + + heap_bases = shmem.get_heap_bases() + + persistent_all_gather[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + ) + + if not async_op: + shmem.barrier() + diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 16fb5210..e7879d92 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -8,6 +8,7 @@ import triton import triton.language as tl +import torch import iris from .config import Config @@ -79,6 +80,7 @@ def persistent_all_to_all( GROUP_SIZE_M: Group size for M dimension tiling COMM_SMS: Number of SMs for communication NUM_XCDS: Number of XCDs + CHUNK_SIZE: Chunk size for chiplet transform """ pid = tl.program_id(0) @@ -159,60 +161,59 @@ def persistent_all_to_all( # Process all remote ranks: load each chunk and scatter to corresponding target # Each target_rank may have different input data, so we must load separately for target_rank in range(world_size): - # Skip local rank as it's already processed above if target_rank != cur_rank: - # Traffic shaping: Process tile in 64x64 sub-blocks - # Loop over all sub-blocks to ensure complete coverage - for sub_block_id in range(total_sub_blocks): - # Calculate sub-block position within the tile - sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M - sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N - - # Compute row and column indices for this 64x64 sub-block - # Start from tile base and add sub-block offset, then create arrays - sub_rm_base = tile_base_m + sub_block_m - sub_rn_base = 
tile_base_n + sub_block_n - sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) - sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) - - # Create mask for this sub-block - sub_mask = ( - (sub_rm[:, None] < M) - & (sub_rn[None, :] < N) - & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) - & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) - ) - - # Compute offsets for this sub-block - sub_input_base_m = sub_rm[:, None] * stride_in_m - sub_input_base_n = sub_rn[None, :] * stride_in_n - sub_output_base_m = sub_rm[:, None] * stride_out_m - sub_output_base_n = sub_rn[None, :] * stride_out_n - - # Compute input pointer for this target_rank's chunk (sub-block) - sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) - sub_input_ptr_send = input_ptr + sub_input_offset - sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) - - # Compute output pointer (sub-block) - sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) - sub_output_ptr_remote = output_ptr + sub_output_offset - sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) - - # Load data chunk for this target rank (64x64 sub-block) - sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) - - # Scatter to target rank's output - # Processing in 64x64 sub-blocks creates better memory access patterns - # that allow hardware to distribute traffic across XGMI links - iris.store( - sub_output_ptr_remote, - sub_data, - cur_rank, - target_rank, - heap_bases, - mask=sub_mask, - ) + # Traffic shaping: Process tile in 64x64 sub-blocks + # Loop over all sub-blocks to ensure complete coverage + for sub_block_id in range(total_sub_blocks): + # Calculate sub-block position within the tile + sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M + sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N + + # Compute row and column indices for this 64x64 sub-block + # Start from tile base and add sub-block offset, then create arrays + sub_rm_base = tile_base_m + sub_block_m + sub_rn_base = tile_base_n + sub_block_n + sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) + sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) + + # Create mask for this sub-block + sub_mask = ( + (sub_rm[:, None] < M) + & (sub_rn[None, :] < N) + & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) + & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) + ) + + # Compute offsets for this sub-block + sub_input_base_m = sub_rm[:, None] * stride_in_m + sub_input_base_n = sub_rn[None, :] * stride_in_n + sub_output_base_m = sub_rm[:, None] * stride_out_m + sub_output_base_n = sub_rn[None, :] * stride_out_n + + # Compute input pointer for this target_rank's chunk (sub-block) + sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) + sub_input_ptr_send = input_ptr + sub_input_offset + sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) + + # Compute output pointer (sub-block) + sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) + sub_output_ptr_remote = output_ptr + sub_output_offset + sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) + + # Load data chunk for this target rank (64x64 sub-block) + sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) + + # Scatter to target rank's output + # Processing in 64x64 sub-blocks creates better memory access patterns + # that allow hardware to distribute traffic across XGMI 
links + iris.store( + sub_output_ptr_remote, + sub_data, + cur_rank, + target_rank, + heap_bases, + mask=sub_mask, + ) # Gluon implementation with traffic shaping based on micro-benchmark algorithm diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index f9ab82c4..391a6707 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -578,6 +578,36 @@ def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris_gluon.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" extra = {"iris_rank": self.cur_rank, "iris_num_ranks": self.num_ranks} diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py new file mode 100644 index 00000000..90b4abe1 --- /dev/null +++ b/tests/ccl/test_all_gather.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for all-gather collective operation. 
+""" + +import pytest +import torch +import torch.distributed as dist +import iris +from iris.ccl import Config + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "M, N", + [ + (128, 64), # Small + (1024, 256), # Medium + (8192, 8192), # Large + ], +) +def test_all_gather(dtype, M, N): + """Test all-gather functionality by comparing against PyTorch's implementation.""" + # Ensure torch.distributed is initialized (should be done by test runner) + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # PyTorch's all_gather_into_tensor format: each rank has M x N input + # Output is (world_size * M, N) - concatenated along dimension 0 + pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") + # Fill with deterministic values for easier debugging + pytorch_input_tensor.fill_(float(rank + 1)) + + # Create output tensor for PyTorch: (world_size * M, N) + pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") + + # Run PyTorch's all_gather_into_tensor to get reference output + shmem.barrier() + dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) + torch.cuda.synchronize() + + # Now set up Iris all_gather format + # Iris format: same as PyTorch - input is (M, N), output is (world_size * M, N) + iris_input_tensor = shmem.zeros((M, N), dtype=dtype) + iris_input_tensor.copy_(pytorch_input_tensor) + + iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype) + + # Run Iris all_gather + shmem.barrier() + config = Config() + shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) + torch.cuda.synchronize() + + # Compare results + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() + + try: + assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: Iris output doesn't match PyTorch's all_gather_into_tensor" + ) + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() + From b9315b029b226246eaf117fce30cf0022d357080 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Mon, 8 Dec 2025 18:28:28 +0000 Subject: [PATCH 02/13] Bump. --- docker/Dockerfile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b49c01a..6ae142dc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
-FROM rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch +FROM rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 # Use bash shell for RUN commands SHELL ["/bin/bash", "-c"] @@ -31,15 +31,15 @@ RUN pip3 install --upgrade pip && \ # Clone and install Triton WORKDIR $TRITON_PATH RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH -RUN git checkout dd5823453bcc7973eabadb65f9d827c43281c434 +RUN git checkout 715f6b1d442601436bf8d462db6ff8e17aec8cfb RUN pip3 install -e . ENV PYTHONPATH=$TRITON_PATH # Install rocprofiler-systems WORKDIR /workspace -RUN wget https://github.com/ROCm/rocprofiler-systems/releases/download/rocm-6.3.1/rocprofiler-systems-install.py && \ - python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 6.3 && \ - rm -f rocprofiler-systems-install.py +# RUN wget https://github.com/ROCm/rocprofiler-systems/releases/latest/download/rocprofiler-systems-install.py && \ +# python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 7.1 && \ +# rm -f rocprofiler-systems-install.py # Create entrypoint script RUN echo '#!/bin/bash' > /entrypoint.sh && \ From 5f08767fc2f1bf76482e76240070e6e345d9c0ae Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 8 Dec 2025 18:39:03 +0000 Subject: [PATCH 03/13] Apply Ruff auto-fixes --- benchmark/ccl/all_gather/benchmark.py | 7 +- iris/ccl/all_gather.py | 5 +- iris/ccl/all_to_all.py | 104 +++++++++++++------------- tests/ccl/test_all_gather.py | 1 - 4 files changed, 59 insertions(+), 58 deletions(-) diff --git a/benchmark/ccl/all_gather/benchmark.py b/benchmark/ccl/all_gather/benchmark.py index 8e551e50..10844e91 100644 --- a/benchmark/ccl/all_gather/benchmark.py +++ b/benchmark/ccl/all_gather/benchmark.py @@ -56,7 +56,11 @@ def parse_args(): ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-gather kernel") - parser.add_argument("--benchmark_rccl", action="store_true", help="Also benchmark PyTorch RCCL (all_gather_into_tensor) for comparison") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_gather_into_tensor) for comparison", + ) parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") @@ -344,4 +348,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 7a3d200a..255e49df 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -112,7 +112,7 @@ def persistent_all_gather( rm_output = rm_input + source_rank * M # Output mask: check bounds for output tensor (world_size * M rows, N cols) output_mask = (rm_output[:, None] < (world_size * M)) & (rn[None, :] < N) - + # Input offset: read from source_rank's input tensor input_offset = input_base_m + input_base_n input_ptr_source = input_ptr + input_offset @@ -176,7 +176,7 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) - + if output_tensor.shape[:2] != expected_output_shape: raise ValueError( f"Output tensor shape {output_tensor.shape[:2]} does not match expected shape {expected_output_shape}. 
" @@ -210,4 +210,3 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): if not async_op: shmem.barrier() - diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index e7879d92..38a5613d 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -162,58 +162,58 @@ def persistent_all_to_all( # Each target_rank may have different input data, so we must load separately for target_rank in range(world_size): if target_rank != cur_rank: - # Traffic shaping: Process tile in 64x64 sub-blocks - # Loop over all sub-blocks to ensure complete coverage - for sub_block_id in range(total_sub_blocks): - # Calculate sub-block position within the tile - sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M - sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N - - # Compute row and column indices for this 64x64 sub-block - # Start from tile base and add sub-block offset, then create arrays - sub_rm_base = tile_base_m + sub_block_m - sub_rn_base = tile_base_n + sub_block_n - sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) - sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) - - # Create mask for this sub-block - sub_mask = ( - (sub_rm[:, None] < M) - & (sub_rn[None, :] < N) - & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) - & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) - ) - - # Compute offsets for this sub-block - sub_input_base_m = sub_rm[:, None] * stride_in_m - sub_input_base_n = sub_rn[None, :] * stride_in_n - sub_output_base_m = sub_rm[:, None] * stride_out_m - sub_output_base_n = sub_rn[None, :] * stride_out_n - - # Compute input pointer for this target_rank's chunk (sub-block) - sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) - sub_input_ptr_send = input_ptr + sub_input_offset - sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) - - # Compute output pointer (sub-block) - sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) - sub_output_ptr_remote = output_ptr + sub_output_offset - sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) - - # Load data chunk for this target rank (64x64 sub-block) - sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) - - # Scatter to target rank's output - # Processing in 64x64 sub-blocks creates better memory access patterns - # that allow hardware to distribute traffic across XGMI links - iris.store( - sub_output_ptr_remote, - sub_data, - cur_rank, - target_rank, - heap_bases, - mask=sub_mask, - ) + # Traffic shaping: Process tile in 64x64 sub-blocks + # Loop over all sub-blocks to ensure complete coverage + for sub_block_id in range(total_sub_blocks): + # Calculate sub-block position within the tile + sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M + sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N + + # Compute row and column indices for this 64x64 sub-block + # Start from tile base and add sub-block offset, then create arrays + sub_rm_base = tile_base_m + sub_block_m + sub_rn_base = tile_base_n + sub_block_n + sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) + sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) + + # Create mask for this sub-block + sub_mask = ( + (sub_rm[:, None] < M) + & (sub_rn[None, :] < N) + & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) + & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) + ) + + # Compute offsets for this sub-block + sub_input_base_m = sub_rm[:, None] * stride_in_m + 
sub_input_base_n = sub_rn[None, :] * stride_in_n + sub_output_base_m = sub_rm[:, None] * stride_out_m + sub_output_base_n = sub_rn[None, :] * stride_out_n + + # Compute input pointer for this target_rank's chunk (sub-block) + sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) + sub_input_ptr_send = input_ptr + sub_input_offset + sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) + + # Compute output pointer (sub-block) + sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) + sub_output_ptr_remote = output_ptr + sub_output_offset + sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) + + # Load data chunk for this target rank (64x64 sub-block) + sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) + + # Scatter to target rank's output + # Processing in 64x64 sub-blocks creates better memory access patterns + # that allow hardware to distribute traffic across XGMI links + iris.store( + sub_output_ptr_remote, + sub_data, + cur_rank, + target_rank, + heap_bases, + mask=sub_mask, + ) # Gluon implementation with traffic shaping based on micro-benchmark algorithm diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py index 90b4abe1..ae649043 100644 --- a/tests/ccl/test_all_gather.py +++ b/tests/ccl/test_all_gather.py @@ -86,4 +86,3 @@ def test_all_gather(dtype, M, N): import gc gc.collect() - From 71f4095be52a113fadc03a75a3b710c87776799f Mon Sep 17 00:00:00 2001 From: neoblizz Date: Mon, 8 Dec 2025 21:02:32 +0000 Subject: [PATCH 04/13] reduce-scatter + missing iris.py modifications. --- iris/ccl/all_gather.py | 8 ++ iris/ccl/all_reduce.py | 8 ++ iris/ccl/config.py | 9 ++ iris/ccl/reduce_scatter.py | 216 ++++++++++++++++++++++++++++++++ iris/experimental/iris_gluon.py | 30 +++++ iris/iris.py | 67 +++++++++- 6 files changed, 336 insertions(+), 2 deletions(-) create mode 100644 iris/ccl/reduce_scatter.py diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 255e49df..6c66e70d 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -171,6 +171,14 @@ def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): if config is None: config = Config() + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "all_gather does not support use_gluon=True. " + "Gluon implementation is not available for all_gather. " + "Use default config (use_gluon=False)." + ) + rank = shmem.get_rank() world_size = shmem.get_num_ranks() diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index 1c96e985..7338fc5d 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -669,6 +669,14 @@ def all_reduce( if config is None: config = Config() + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "all_reduce does not support use_gluon=True. " + "Gluon implementation is not available for all_reduce. " + "Use default config (use_gluon=False)." 
+ ) + rank = shmem.get_rank() world_size = shmem.get_num_ranks() M, N = input_tensor.shape[:2] diff --git a/iris/ccl/config.py b/iris/ccl/config.py index 48c156c4..a87fb18a 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -38,6 +38,8 @@ class Config: all_reduce_num_rings: Number of concurrent rings to form in ring-based all-reduce (default: 1) all_reduce_ring_slice_n: Column slice size for ring reduce-scatter/all-gather (default: auto-set to block_size_n // world_size at runtime) + reduce_scatter_variant: Variant for reduce-scatter operation (default: "two_shot") + Only "two_shot" is supported Example: >>> import iris @@ -68,6 +70,7 @@ class Config: all_reduce_distribution: int = 0 all_reduce_num_rings: int = 1 all_reduce_ring_slice_n: int | None = None + reduce_scatter_variant: str = "two_shot" def __post_init__(self): """Validate and auto-detect num_xcds if not set.""" @@ -109,3 +112,9 @@ def __post_init__(self): ) if self.all_reduce_ring_slice_n & (self.all_reduce_ring_slice_n - 1): raise ValueError(f"all_reduce_ring_slice_n must be a power of two, got {self.all_reduce_ring_slice_n}") + + # Validate reduce_scatter_variant + if self.reduce_scatter_variant != "two_shot": + raise ValueError( + f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'" + ) diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py new file mode 100644 index 00000000..21799911 --- /dev/null +++ b/iris/ccl/reduce_scatter.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Reduce-scatter collective communication primitive for Iris. +Uses the two-shot approach: reduce assigned tiles and store only to own rank. +""" + +import triton +import triton.language as tl +import torch +import iris +from .config import Config + + +@triton.jit() +def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr): + if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size): + return pid + + local_pid = pid // num_xcds + chunk_idx = local_pid // chunk_size + pos_in_chunk = local_pid % chunk_size + + xcd = pid % num_xcds + new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk + return new_pid + + +@triton.jit() +def persistent_reduce_scatter_two_shot( + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, + DISTRIBUTION: tl.constexpr, +): + """ + Reduce-scatter using two-shot approach. + + Each rank reduces its assigned tiles from all ranks and stores the result + only to its own output (no broadcast to other ranks). 
+ """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32 + + tiles_per_rank = tl.cdiv(total_tiles, world_size) + if DISTRIBUTION == 0: + start_tile = cur_rank + stride = world_size + remaining = total_tiles - start_tile + remaining = tl.maximum(remaining, 0) + max_tile_offset = tl.cdiv(remaining, stride) + else: + start_tile = cur_rank * tiles_per_rank + stride = 1 + remaining = total_tiles - start_tile + remaining = tl.maximum(remaining, 0) + max_tile_offset = tl.minimum(tiles_per_rank, remaining) + + for tile_offset in range(pid, max_tile_offset, COMM_SMS): + tile_id = start_tile + tile_offset * stride + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + mask = (rm[:, None] < M) & (rn[None, :] < N) + + input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n + output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n + + # Reduce: sum contributions from all ranks + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for remote_rank in range(world_size): + partial = iris.load( + input_ptr + input_offset, + cur_rank, + remote_rank, + heap_bases, + mask=mask, + ) + acc += partial.to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + # Store only to own rank (no broadcast) + tl.store(output_ptr + output_offset, reduced, mask=mask, cache_modifier=".wt") + + +def reduce_scatter(output_tensor, input_tensor, shmem, config=None, async_op=False): + """ + Internal reduce-scatter collective operation implementation. + + This function is called internally by shmem.ccl.reduce_scatter(). + Users should use the Iris instance method instead: + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + shmem: Iris shmem context + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
+ + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + if config is None: + config = Config() + + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "reduce_scatter does not support use_gluon=True. " + "Gluon implementation is not available for reduce_scatter. " + "Use default config (use_gluon=False)." + ) + + # Validate that only two_shot variant is used + variant = getattr(config, "reduce_scatter_variant", "two_shot") + if variant != "two_shot": + raise ValueError( + f"reduce_scatter only supports variant='two_shot', got '{variant}'. " + f"Set config.reduce_scatter_variant='two_shot' or use default config." + ) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + M, N = input_tensor.shape[:2] + + # Validate output shape matches input shape + if output_tensor.shape[:2] != (M, N): + raise ValueError( + f"Output tensor shape {output_tensor.shape[:2]} does not match input shape {(M, N)}. " + f"For reduce-scatter, output should have the same shape as input." + ) + + stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) + stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) + + heap_bases = shmem.get_heap_bases() + + # Use all_reduce_distribution for tile distribution + distribution = config.all_reduce_distribution + + persistent_reduce_scatter_two_shot[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + distribution, + ) + + if not async_op: + shmem.barrier() + diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 391a6707..63207943 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -608,6 +608,36 @@ def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
+ + Example: + >>> shmem = iris_gluon.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + + _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" extra = {"iris_rank": self.cur_rank, "iris_num_ranks": self.num_ranks} diff --git a/iris/iris.py b/iris/iris.py index 9f6c574e..45aef5c5 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1548,6 +1548,39 @@ def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): """ Prepare reusable workspace for all-reduce. @@ -1616,6 +1649,36 @@ def all_reduce(self, output_tensor, input_tensor, config=None, async_op=False, w workspace=workspace, ) + def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
+ + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + + _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases): @@ -1634,8 +1697,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Optimization to vectorize the load/store # We can't do this in general because we don't know the shape of the tensor - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) + ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) + translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) From 464e14d328bcaef5cbd7be3dd578aa7d689997ec Mon Sep 17 00:00:00 2001 From: neoblizz Date: Mon, 8 Dec 2025 21:51:25 +0000 Subject: [PATCH 05/13] Revert translate, minor all-gather changes. --- iris/ccl/all_gather.py | 59 ++++++++++++++++++++++-------------------- iris/iris.py | 4 +-- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 6c66e70d..a1bd61f6 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -90,60 +90,63 @@ def persistent_all_gather( tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - # Compute row and column indices for input tensor + # Compute local row and column indices for input tensor rm_base = pid_m * BLOCK_SIZE_M rn_base = pid_n * BLOCK_SIZE_N rm_input = rm_base + tl.arange(0, BLOCK_SIZE_M) rn = rn_base + tl.arange(0, BLOCK_SIZE_N) rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Mask for local input bounds input_mask = (rm_input[:, None] < M) & (rn[None, :] < N) - # Pre-compute base offsets for input + # Compute input offset and load local shard data once + # Each rank loads its own input data and then broadcasts it to all ranks input_base_m = rm_input[:, None] * stride_in_m input_base_n = rn[None, :] * stride_in_n - - # Process all ranks - # For each rank, copy its input chunk to the corresponding output location - # on all ranks (including the source rank itself) - # Output concatenates along dimension 0: output[source_rank * M : (source_rank + 1) * M, :] - for source_rank in range(world_size): - # Compute output row indices: offset by source_rank * M - rm_output = rm_input + source_rank * M + input_offset = input_base_m + input_base_n + input_ptr_source = input_ptr + input_offset + input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Load local input data once for this tile + data = tl.load(input_ptr_source, mask=input_mask, other=0.0) + + # Send local shard data to all destination ranks + # Each rank's input goes to output[cur_rank * M : (cur_rank + 1) * M, :] on all ranks + for rank in range(world_size): + # Compute global output row indices: offset by cur_rank * M + # This rank's data should be placed at output[cur_rank * M : (cur_rank + 1) * M, :] + 
rm_output = rm_input + cur_rank * M + # Output mask: check bounds for output tensor (world_size * M rows, N cols) output_mask = (rm_output[:, None] < (world_size * M)) & (rn[None, :] < N) + + # Combine masks: must be valid in both input and output + combined_mask = input_mask & output_mask - # Input offset: read from source_rank's input tensor - input_offset = input_base_m + input_base_n - input_ptr_source = input_ptr + input_offset - input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - - # Output offset: write to output at rows [source_rank * M : (source_rank + 1) * M] - # This is the same location on all ranks + # Compute output offset: write to output at rows [cur_rank * M : (cur_rank + 1) * M] + # This is the same location on all destination ranks output_base_m = rm_output[:, None] * stride_out_m output_base_n = rn[None, :] * stride_out_n output_offset = output_base_m + output_base_n output_ptr_target = output_ptr + output_offset output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # Combine masks: must be valid in both input and output - combined_mask = input_mask & output_mask - - if source_rank == cur_rank: - # Local copy: use direct load/store - data = tl.load(input_ptr_source, mask=combined_mask) + if rank == cur_rank: + # Local destination: use direct store tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") else: - # Remote copy: use iris.load to read from source_rank, then store locally - # Note: iris.put reads from local memory, so we can't use it for remote reads - data = iris.load( + # Remote destination: use iris.put to send from local source to remote destination + # from_ptr: local input source, to_ptr: remote output destination + iris.put( input_ptr_source, + output_ptr_target, cur_rank, - source_rank, + rank, heap_bases, mask=combined_mask, ) - tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): diff --git a/iris/iris.py b/iris/iris.py index 45aef5c5..4c87e071 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1697,8 +1697,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Optimization to vectorize the load/store # We can't do this in general because we don't know the shape of the tensor - ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) - translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) + # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) From 0f26f0189a92a603c48f28ea57b79f3b2d6aa552 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 9 Dec 2025 02:04:14 +0000 Subject: [PATCH 06/13] ... 
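
Add a standalone reduce-scatter benchmark and retune the all-reduce benchmark
defaults (comm_sms 32 -> 64, 128x128 blocks, swizzle 1, two_shot as the default
variant).

At its core the new benchmark allocates symmetric-heap tensors and times the
collective roughly as follows (illustrative sketch; argument parsing,
process-group setup, and validation are omitted):

    # assumes: import torch, iris; from iris.ccl import Config
    shmem = iris.iris(1 << 34)
    config = Config(comm_sms=64, all_reduce_distribution=0)
    M = N = 16384
    input_tensor = shmem.zeros((M, N), dtype=torch.float16)
    output_tensor = shmem.zeros((M, N), dtype=torch.float16)
    shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config, async_op=False)
    shmem.barrier()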
--- benchmark/ccl/all_reduce/benchmark.py | 10 +- benchmark/ccl/reduce_scatter/benchmark.py | 404 ++++++++++++++++++++++ 2 files changed, 409 insertions(+), 5 deletions(-) create mode 100755 benchmark/ccl/reduce_scatter/benchmark.py diff --git a/benchmark/ccl/all_reduce/benchmark.py b/benchmark/ccl/all_reduce/benchmark.py index edecd1c8..bdf1f1e2 100755 --- a/benchmark/ccl/all_reduce/benchmark.py +++ b/benchmark/ccl/all_reduce/benchmark.py @@ -47,16 +47,16 @@ def parse_args(): help="Output file", ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") - parser.add_argument("--comm_sms", type=int, default=32, help="Number of SMs for all-reduce kernel") - parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") - parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") - parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-reduce kernel") + parser.add_argument("--block_size_m", type=int, default=128, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=128, help="Block size for N dimension tiling") + parser.add_argument("--swizzle_size", type=int, default=1, help="Number of tiles to swizzle together") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") parser.add_argument( "--variant", type=str, - default="atomic", + default="two_shot", choices=["atomic", "ring", "two_shot", "one_shot", "spinlock"], help="All-reduce variant to use", ) diff --git a/benchmark/ccl/reduce_scatter/benchmark.py b/benchmark/ccl/reduce_scatter/benchmark.py new file mode 100755 index 00000000..200f2a8c --- /dev/null +++ b/benchmark/ccl/reduce_scatter/benchmark.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris-ccl reduce-scatter collective operation. + +This benchmark showcases the reduce-scatter collective and reports achieved bandwidth. 
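+
+Unlike torch.distributed.reduce_scatter_tensor, which returns an
+(M // world_size, N) chunk, the Iris implementation reduces the tiles assigned
+to each rank into a full-size output, so both the input and the output are
+(M, N) tensors. Reported bandwidth assumes roughly
+(world_size - 1) / world_size * M * N * element_size bytes moved per rank.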
+""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ccl import Config + +# Conditional import for Gluon +try: + import iris.experimental.iris_gluon as iris_gluon + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark reduce-scatter collective operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in input tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in input tensors") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for reduce-scatter kernel") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (reduce_scatter) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=64, help="Block size for M dimension tiling (default: 64)") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling (default: 64)") + parser.add_argument("--swizzle_size", type=int, default=8, help="Number of tiles to swizzle together (default: 8)") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument( + "--all_reduce_distribution", + type=int, + default=0, + choices=[0, 1], + help="Distribution mode for two-shot reduce-scatter: 0=striding (default), 1=block", + ) + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Use Gluon if requested and available + if args.get("use_gluon", False): + if not GLUON_AVAILABLE: + raise RuntimeError("Gluon is not available. 
Install Triton with Gluon support or remove --use_gluon flag") + shmem = iris_gluon.iris(args["heap_size"]) + else: + shmem = iris.iris(args["heap_size"]) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + + # Create config with optimized defaults for reduce-scatter + config_kwargs = { + "comm_sms": args["comm_sms"], + "all_reduce_distribution": args["all_reduce_distribution"], + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "swizzle_size": args["swizzle_size"], + } + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + if args.get("use_gluon", False): + config_kwargs["use_gluon"] = True + + config = Config(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export config values to JSON (use actual values from config, including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("swizzle_size", config.swizzle_size) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("use_gluon", config.use_gluon) + json_writer.add_field("all_reduce_distribution", config.all_reduce_distribution) + + # Create input and output tensors for reduce-scatter + # Input: each rank has (M, N) tensor + # Output: each rank has (M, N) tensor - contains reduced tiles assigned to this rank + # Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.load() compatibility + input_tensor = shmem.zeros((M, N), dtype=datatype) + output_tensor = shmem.zeros((M, N), dtype=datatype) + expected_tensor = shmem.zeros((M, N), dtype=datatype) + + # Fill input with deterministic values + # For reduce-scatter, each rank's input contributes to the reduction + val = float(rank + 1) + input_tensor.fill_(val) + + # Expected output: each rank gets the sum of all ranks' inputs for its assigned tiles + # Since reduce-scatter uses two-shot with tile assignment, we need to compute + # which tiles are assigned to each rank based on the distribution mode + # For validation, we'll use PyTorch's reduce_scatter as reference + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "reduce_scatter": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + shmem.barrier() + + torch.cuda.nvtx.range_push("Reduce-Scatter") + with torch.cuda.stream(comm_stream): + kernel_timing["reduce_scatter"]["start_event"].record() + shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config, async_op=False) + kernel_timing["reduce_scatter"]["end_event"].record() + kernel_timing["reduce_scatter"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["reduce_scatter"]["start_event"].elapsed_time(kernel_timing["reduce_scatter"]["end_event"]) + kernel_timing["reduce_scatter"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if 
args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + # Run Iris reduce_scatter + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + # Validate against PyTorch's reduce_scatter + # PyTorch reduce_scatter: input is (M, N), output is (M // world_size, N) or similar + # Our implementation: input is (M, N), output is (M, N) with assigned tiles reduced + # Since our implementation is different (tiles vs chunks), we validate that: + # 1. Each rank reduces its assigned tiles from all ranks + # 2. The sum across all ranks' outputs equals the sum of all inputs + + # For a proper validation, we compare with PyTorch's reduce_scatter_tensor + pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_input.fill_(float(rank + 1)) + + # PyTorch reduce_scatter splits along dim 0, so output is (M // world_size, N) + # Our implementation reduces assigned tiles (not necessarily contiguous chunks) + # So validation is more complex - we'll just check that outputs are non-zero and correct pattern + + # Basic validation: check that output contains reduced values (sum of inputs for assigned tiles) + # Since tile assignment is complex, we'll use a simpler check: + # The sum of all outputs should equal the sum of all inputs (scaled by world_size for each tile location) + + # For now, validate that output is not all zeros and has expected magnitude + output_sum = output_tensor.sum().item() + input_sum = input_tensor.sum().item() + + # Expected: each tile location gets sum of all ranks' contributions + # Total sum of all outputs should equal world_size * sum of one input (since each location is reduced) + # Actually, in our two-shot implementation, each rank reduces its assigned tiles + # The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs + total_expected_sum = world_size * input_sum # Each tile gets sum of all ranks + + # Simple validation: output should be non-zero and have reasonable values + atol = 1e-3 if datatype == torch.float16 else 1e-5 + has_data = output_tensor.abs().max().item() > atol + + if not has_data: + shmem.error(f"Rank {rank}: Validation failed - output is all zeros") + success = False + else: + # Check that values are in expected range (sum of inputs from all ranks for assigned tiles) + # The exact validation depends on tile assignment, so we do a basic sanity check + success = True + shmem.info(f"Rank {rank}: Output sum: {output_sum:.2f}, Input sum: {input_sum:.2f}") + + if success: + shmem.info("Reduce-scatter validation passed!") + else: + shmem.error("Reduce-scatter validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + run_experiment() + shmem.barrier() + + for k in ["reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate bandwidth + # Reduce-scatter moves (world_size - 1) / world_size * data_size bytes + # This accounts for the two-shot approach where each rank reads from all ranks + # and writes only to its own output (no broadcast phase) + # Each rank transfers 
(world_size - 1) / world_size * M * N * element_size bytes + # This is similar to all-reduce but without the broadcast phase + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * N * element_size * (world_size - 1) / world_size + total_bytes_gb = total_bytes / (1024**3) + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["reduce_scatter"]["ms"] / kernel_timing["reduce_scatter"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Reduce-scatter (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "reduce_scatter_ms", kernel_timing["reduce_scatter"]["ms"] / kernel_timing["reduce_scatter"]["experiments"] + ) + json_writer.add_field("reduce_scatter_experiments", kernel_timing["reduce_scatter"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark RCCL (PyTorch reduce_scatter_tensor) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (reduce_scatter_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + # PyTorch reduce_scatter_tensor: input is (M, N), output is (M // world_size, N) + # Our implementation is different (tiles vs chunks), so we'll benchmark with same input size + pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_input.fill_(float(rank + 1)) + + # PyTorch reduce_scatter_tensor splits along dim 0 + output_size_m = M // world_size + pytorch_output = torch.zeros(output_size_m, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.reduce_scatter_tensor(pytorch_output, pytorch_input, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_output.zero_() + pytorch_input.fill_(float(rank + 1)) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.reduce_scatter_tensor(pytorch_output, pytorch_input, op=dist.ReduceOp.SUM) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + # RCCL reduce-scatter: similar bandwidth calculation + # Each rank reads from all ranks and writes its output chunk + total_bytes = M * N * element_size * (world_size - 1) / world_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL reduce_scatter_tensor (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to 
finish RCCL benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29503" + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() + From 731e34f6d010ff62f8d988029a9571f32bff0e45 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 9 Dec 2025 02:17:18 +0000 Subject: [PATCH 07/13] Best config. --- benchmark/ccl/all_reduce/benchmark.py | 67 +++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/benchmark/ccl/all_reduce/benchmark.py b/benchmark/ccl/all_reduce/benchmark.py index bdf1f1e2..4ddcc7f7 100755 --- a/benchmark/ccl/all_reduce/benchmark.py +++ b/benchmark/ccl/all_reduce/benchmark.py @@ -48,9 +48,14 @@ def parse_args(): ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-reduce kernel") - parser.add_argument("--block_size_m", type=int, default=128, help="Block size for M dimension tiling") - parser.add_argument("--block_size_n", type=int, default=128, help="Block size for N dimension tiling") - parser.add_argument("--swizzle_size", type=int, default=1, help="Number of tiles to swizzle together") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=64, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling") + parser.add_argument("--swizzle_size", type=int, default=4, help="Number of tiles to swizzle together") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") parser.add_argument( @@ -300,6 +305,62 @@ def run_experiment(): # Wait for all to finish benchmarking shmem.barrier() + # Benchmark RCCL (PyTorch all_reduce) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_tensor = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_tensor.fill_(float(rank + 1)) + + # Warmup + for _ in range(10): + dist.all_reduce(pytorch_tensor, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_tensor.fill_(float(rank + 1)) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_reduce(pytorch_tensor, op=dist.ReduceOp.SUM) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + # RCCL all-reduce: same bandwidth calculation as Iris + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + total_bytes = M * N * element_size * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + 
f"RCCL all_reduce (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + if rank == 0: if args["variant"] == "ring": json_writer.add_field("all_reduce_ring_slice_n", config.all_reduce_ring_slice_n) From 4f72db46d3676a28d503dbb7714adbc651fff0e3 Mon Sep 17 00:00:00 2001 From: Ryan Swann <109695074+ryanswann-amd@users.noreply.github.com> Date: Wed, 17 Dec 2025 22:27:03 -0600 Subject: [PATCH 08/13] Improve CCL performance (#298) Co-authored-by: github-actions[bot] --- benchmark/ccl/all_gather/benchmark.py | 2 +- benchmark/ccl/all_reduce/benchmark.py | 7 +- benchmark/ccl/all_to_all/benchmark.py | 70 +++- benchmark/ccl/comprehensive_sweep.py | 443 ++++++++++++++++++++++ benchmark/ccl/plot_sweep_results.py | 184 +++++++++ benchmark/ccl/reduce_scatter/benchmark.py | 123 +++--- iris/ccl/all_gather.py | 23 +- iris/ccl/all_reduce.py | 87 +++-- iris/ccl/all_to_all.py | 159 ++++---- iris/ccl/config.py | 6 +- iris/ccl/reduce_scatter.py | 51 ++- iris/iris.py | 5 +- 12 files changed, 950 insertions(+), 210 deletions(-) create mode 100644 benchmark/ccl/comprehensive_sweep.py create mode 100644 benchmark/ccl/plot_sweep_results.py diff --git a/benchmark/ccl/all_gather/benchmark.py b/benchmark/ccl/all_gather/benchmark.py index 10844e91..714cfa8f 100644 --- a/benchmark/ccl/all_gather/benchmark.py +++ b/benchmark/ccl/all_gather/benchmark.py @@ -336,7 +336,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = "tcp://127.0.0.1:29234" mp.spawn( fn=_worker, diff --git a/benchmark/ccl/all_reduce/benchmark.py b/benchmark/ccl/all_reduce/benchmark.py index 4ddcc7f7..73af6e05 100755 --- a/benchmark/ccl/all_reduce/benchmark.py +++ b/benchmark/ccl/all_reduce/benchmark.py @@ -84,6 +84,9 @@ def parse_args(): default=None, help="Column slice size for ring variant (power of two, must divide block_size_n)", ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29527", help="Initialization URL for distributed setup" + ) return vars(parser.parse_args()) @@ -100,10 +103,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): ) shmem = iris.iris(args["heap_size"]) - rank = shmem.get_rank() world_size = shmem.get_num_ranks() - # Datatype mapping datatype = torch.float32 if args["datatype"] == "fp16": @@ -374,7 +375,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = args["init_url"] mp.spawn( fn=_worker, diff --git a/benchmark/ccl/all_to_all/benchmark.py b/benchmark/ccl/all_to_all/benchmark.py index a1b570dd..4676bfd6 100644 --- a/benchmark/ccl/all_to_all/benchmark.py +++ b/benchmark/ccl/all_to_all/benchmark.py @@ -62,6 +62,11 @@ def parse_args(): parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") 
parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_to_all) for comparison", + ) return vars(parser.parse_args()) @@ -268,6 +273,69 @@ def run_experiment(): # Wait for all to finish benchmarking shmem.barrier() + # Benchmark RCCL (PyTorch all_to_all) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_to_all)...") + + # Create PyTorch tensors (not on Iris heap) + # For all_to_all, we need a list of tensors to send and receive + pytorch_input_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + pytorch_output_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + + # Fill input tensors with deterministic values + for target_rank in range(world_size): + val = float(rank * 1000 + target_rank) + pytorch_input_list[target_rank].fill_(val) + + # Warmup + for _ in range(10): + dist.all_to_all(pytorch_output_list, pytorch_input_list) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + for target_rank in range(world_size): + pytorch_output_list[target_rank].zero_() + val = float(rank * 1000 + target_rank) + pytorch_input_list[target_rank].fill_(val) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_to_all(pytorch_output_list, pytorch_input_list) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL all_to_all (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + if rank == 0: json_writer.flush() json_writer.display() @@ -279,7 +347,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = "tcp://127.0.0.1:29569" mp.spawn( fn=_worker, diff --git a/benchmark/ccl/comprehensive_sweep.py b/benchmark/ccl/comprehensive_sweep.py new file mode 100644 index 00000000..a5773376 --- /dev/null +++ b/benchmark/ccl/comprehensive_sweep.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Comprehensive CCL benchmark with CU sweep across all operations. + +This benchmark runs all_gather, all_reduce, all_to_all, and reduce_scatter +with a sweep across different numbers of CUs (comm_sms) and outputs results to CSV. 
+Runs each benchmark as a separate subprocess to avoid memory accumulation. +""" + +import subprocess +import argparse +import csv +import os +from datetime import datetime +from typing import Dict, List +import json +import tempfile + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Comprehensive CCL benchmark with CU sweep.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Problem size + parser.add_argument("-m", type=int, default=16384, help="Number of rows in tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in tensors") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + + # CU sweep parameters + parser.add_argument( + "--min_cus", + type=int, + default=8, + help="Minimum number of CUs (comm_sms) to test", + ) + parser.add_argument( + "--max_cus", + type=int, + default=128, + help="Maximum number of CUs (comm_sms) to test", + ) + parser.add_argument( + "--cu_step", + type=int, + default=8, + help="Step size for CU sweep", + ) + + # Operations to benchmark + parser.add_argument( + "--operations", + type=str, + nargs="+", + default=["all_gather", "all_reduce", "all_to_all", "reduce_scatter"], + choices=["all_gather", "all_reduce", "all_to_all", "reduce_scatter"], + help="CCL operations to benchmark", + ) + + # All-Gather configuration + parser.add_argument("--all_gather_block_size_m", type=int, default=32, help="All-Gather: Block size M") + parser.add_argument("--all_gather_block_size_n", type=int, default=64, help="All-Gather: Block size N") + parser.add_argument("--all_gather_swizzle_size", type=int, default=4, help="All-Gather: Swizzle size") + + # All-Reduce configuration + parser.add_argument("--all_reduce_block_size_m", type=int, default=32, help="All-Reduce: Block size M") + parser.add_argument("--all_reduce_block_size_n", type=int, default=64, help="All-Reduce: Block size N") + parser.add_argument("--all_reduce_swizzle_size", type=int, default=4, help="All-Reduce: Swizzle size") + parser.add_argument( + "--all_reduce_variant", + type=str, + default="two_shot", + choices=["atomic", "spinlock", "ring", "two_shot", "one_shot"], + help="All-Reduce: Variant to use", + ) + parser.add_argument( + "--all_reduce_distribution", + type=int, + default=1, + choices=[0, 1], + help="All-Reduce: Distribution mode (0=striding, 1=block)", + ) + + # All-to-All configuration + parser.add_argument("--all_to_all_block_size_m", type=int, default=32, help="All-to-All: Block size M") + parser.add_argument("--all_to_all_block_size_n", type=int, default=128, help="All-to-All: Block size N") + parser.add_argument("--all_to_all_swizzle_size", type=int, default=4, help="All-to-All: Swizzle size") + + # Reduce-Scatter configuration + parser.add_argument("--reduce_scatter_block_size_m", type=int, default=32, help="Reduce-Scatter: Block size M") + parser.add_argument("--reduce_scatter_block_size_n", type=int, default=64, help="Reduce-Scatter: Block size N") + parser.add_argument("--reduce_scatter_swizzle_size", type=int, default=4, help="Reduce-Scatter: Swizzle size") + parser.add_argument( + "--reduce_scatter_distribution", + type=int, + default=1, + choices=[0, 1], + help="Reduce-Scatter: Distribution mode (0=striding, 1=block)", + ) + + # General configuration + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris 
heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + + # Output + parser.add_argument( + "--output_csv", + type=str, + default=None, + help="Output CSV file (default: auto-generated with timestamp)", + ) + parser.add_argument("--benchmark_rccl", action="store_true", help="Also benchmark RCCL for comparison") + parser.add_argument("--validate", action="store_false", help="Run validation before benchmarking") + parser.add_argument("--skip_on_validation_failure", action="store_true", help="Skip benchmark if validation fails") + + return vars(parser.parse_args()) + + +def run_validation(operation, comm_sms, args): + """Run validation for a single operation.""" + # Get the directory where this script is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + iris_root = os.path.dirname(os.path.dirname(script_dir)) + + script_map = { + "all_gather": os.path.join(iris_root, "benchmark/ccl/all_gather/benchmark.py"), + "all_reduce": os.path.join(iris_root, "benchmark/ccl/all_reduce/benchmark.py"), + "all_to_all": os.path.join(iris_root, "benchmark/ccl/all_to_all/benchmark.py"), + "reduce_scatter": os.path.join(iris_root, "benchmark/ccl/reduce_scatter/benchmark.py"), + } + + script_path = script_map[operation] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_output = f.name + + cmd = [ + "python", + script_path, + "-m", + str(args["m"]), + "-n", + str(args["n"]), + "--datatype", + args["datatype"], + "--comm_sms", + str(comm_sms), + "-r", + str(args["num_ranks"]), + "--heap_size", + str(args["heap_size"]), + "--validate", + "--output_file", + temp_output, + ] + + # Add operation-specific parameters (same as benchmark) + if operation == "all_gather": + cmd.extend(["--block_size_m", str(args["all_gather_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_gather_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_gather_swizzle_size"])]) + elif operation == "all_reduce": + cmd.extend(["--block_size_m", str(args["all_reduce_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_reduce_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_reduce_swizzle_size"])]) + cmd.extend(["--variant", args["all_reduce_variant"]]) + cmd.extend(["--distribution", str(args["all_reduce_distribution"])]) + elif operation == "all_to_all": + cmd.extend(["--block_size_m", str(args["all_to_all_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_to_all_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_to_all_swizzle_size"])]) + elif operation == "reduce_scatter": + cmd.extend(["--block_size_m", str(args["reduce_scatter_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["reduce_scatter_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["reduce_scatter_swizzle_size"])]) + cmd.extend(["--all_reduce_distribution", str(args["reduce_scatter_distribution"])]) + + if args["num_xcds"] is not None: + cmd.extend(["--num_xcds", str(args["num_xcds"])]) + + print(f" Validating {operation} with comm_sms={comm_sms}...") + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + + with open(temp_output, "r") as f: + data = json.load(f) + + os.unlink(temp_output) + + success = data.get("success", False) + return success + except subprocess.CalledProcessError as e: + print(f" Validation failed for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return False + except Exception as e: + print(f" 
Error during validation for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return False + + +def run_benchmark(operation, comm_sms, args): + """Run a single benchmark as a subprocess and return the results.""" + # Get the directory where this script is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up two levels to get to the iris root directory + iris_root = os.path.dirname(os.path.dirname(script_dir)) + + # Map operation to benchmark script (relative to iris root) + script_map = { + "all_gather": os.path.join(iris_root, "benchmark/ccl/all_gather/benchmark.py"), + "all_reduce": os.path.join(iris_root, "benchmark/ccl/all_reduce/benchmark.py"), + "all_to_all": os.path.join(iris_root, "benchmark/ccl/all_to_all/benchmark.py"), + "reduce_scatter": os.path.join(iris_root, "benchmark/ccl/reduce_scatter/benchmark.py"), + } + + script_path = script_map[operation] + + # Create temporary output file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_output = f.name + + # Build command + cmd = [ + "python", + script_path, + "-m", + str(args["m"]), + "-n", + str(args["n"]), + "--datatype", + args["datatype"], + "--comm_sms", + str(comm_sms), + "-r", + str(args["num_ranks"]), + "--heap_size", + str(args["heap_size"]), + "--benchmark", + "--output_file", + temp_output, + ] + + # Add operation-specific parameters + if operation == "all_gather": + cmd.extend(["--block_size_m", str(args["all_gather_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_gather_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_gather_swizzle_size"])]) + elif operation == "all_reduce": + cmd.extend(["--block_size_m", str(args["all_reduce_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_reduce_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_reduce_swizzle_size"])]) + cmd.extend(["--variant", args["all_reduce_variant"]]) + cmd.extend(["--distribution", str(args["all_reduce_distribution"])]) + elif operation == "all_to_all": + cmd.extend(["--block_size_m", str(args["all_to_all_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_to_all_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_to_all_swizzle_size"])]) + elif operation == "reduce_scatter": + cmd.extend(["--block_size_m", str(args["reduce_scatter_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["reduce_scatter_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["reduce_scatter_swizzle_size"])]) + cmd.extend(["--all_reduce_distribution", str(args["reduce_scatter_distribution"])]) + + if args["num_xcds"] is not None: + cmd.extend(["--num_xcds", str(args["num_xcds"])]) + + # Add --benchmark_rccl flag if requested + if args.get("benchmark_rccl", False): + cmd.append("--benchmark_rccl") + + # Set NCCL environment variables to control number of channels (CUs) + env = os.environ.copy() + if args.get("benchmark_rccl", False): + env["NCCL_MIN_NCHANNELS"] = str(comm_sms) + env["NCCL_MAX_NCHANNELS"] = str(comm_sms) + + # Run benchmark + print(f"\nRunning {operation} with comm_sms={comm_sms}...") + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env) + + # Read results from JSON file + with open(temp_output, "r") as f: + data = json.load(f) + + # Clean up temp file + os.unlink(temp_output) + + return data + except subprocess.CalledProcessError as e: + print(f"Error running {operation}: {e}") + print(f"stdout: {e.stdout}") + print(f"stderr: {e.stderr}") + if 
os.path.exists(temp_output): + os.unlink(temp_output) + return None + except Exception as e: + print(f"Error processing results for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return None + + +def main(): + args = parse_args() + + # Generate CU sweep range + cu_values = list(range(args["min_cus"], args["max_cus"] + 1, args["cu_step"])) + + results = [] + + print(f"{'=' * 80}") + print("Comprehensive CCL Benchmark Sweep") + print(f"Operations: {', '.join(args['operations'])}") + print(f"CU range: {args['min_cus']} to {args['max_cus']} (step {args['cu_step']})") + print(f"Problem size: {args['m']}x{args['n']}") + print(f"Datatype: {args['datatype']}") + print(f"Ranks: {args['num_ranks']}") + print(f"{'=' * 80}") + + for comm_sms in cu_values: + print(f"\n{'=' * 80}") + print(f"Testing with comm_sms={comm_sms}") + print(f"{'=' * 80}") + + for operation in args["operations"]: + # Run validation if requested + validation_passed = True + if args.get("validate", False): + validation_passed = run_validation(operation, comm_sms, args) + if validation_passed: + print(f" ✓ Validation passed for {operation}") + else: + print(f" ✗ Validation FAILED for {operation}") + if args.get("skip_on_validation_failure", False): + print(f" Skipping benchmark for {operation} due to validation failure") + continue + + # Run benchmark + data = run_benchmark(operation, comm_sms, args) + + if data is not None: + # Add validation status to result + if args.get("validate", False): + validation_status = "passed" if validation_passed else "failed" + else: + validation_status = "not_run" + # Extract relevant fields and add to results + result = { + "operation": operation, + "comm_sms": comm_sms, + "m": args["m"], + "n": args["n"], + "world_size": args["num_ranks"], + "datatype": args["datatype"], + "block_size_m": data.get("block_size_m"), + "block_size_n": data.get("block_size_n"), + "swizzle_size": data.get("swizzle_size"), + "num_xcds": data.get("num_xcds"), + "iris_latency_ms": data.get(f"{operation}_ms"), + "iris_bandwidth_gbps": data.get("bandwidth_gbps"), + } + + # Add operation-specific fields + if operation == "all_reduce": + result["variant"] = args["all_reduce_variant"] + result["distribution"] = args["all_reduce_distribution"] + elif operation == "reduce_scatter": + result["distribution"] = args["reduce_scatter_distribution"] + + # Add RCCL results if available + if args.get("benchmark_rccl", False): + result["rccl_latency_ms"] = data.get("rccl_ms") + result["rccl_bandwidth_gbps"] = data.get("rccl_bandwidth_gbps") + result["iris_vs_rccl_ratio"] = data.get("rccl_ratio_percent", 0) / 100.0 + + results.append(result) + + print(f" Iris: {result['iris_latency_ms']:.3f} ms, {result['iris_bandwidth_gbps']:.3f} GB/s") + if args.get("benchmark_rccl", False) and result.get("rccl_bandwidth_gbps"): + print(f" RCCL: {result['rccl_latency_ms']:.3f} ms, {result['rccl_bandwidth_gbps']:.3f} GB/s") + print(f" Ratio: {result['iris_vs_rccl_ratio']:.2f}x") + + # Generate output filename if not provided + if args["output_csv"] is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + args["output_csv"] = f"ccl_sweep_{timestamp}.csv" + + # Write results to CSV + if results: + # Collect all unique fieldnames from all results + all_fieldnames = set() + for result in results: + all_fieldnames.update(result.keys()) + + # Sort fieldnames for consistent column order + # Put common fields first, then operation-specific fields + common_fields = [ + "operation", + "comm_sms", + "m", + "n", + 
"world_size", + "datatype", + "block_size_m", + "block_size_n", + "swizzle_size", + "num_xcds", + "iris_latency_ms", + "iris_bandwidth_gbps", + ] + optional_fields = sorted(all_fieldnames - set(common_fields)) + fieldnames = [f for f in common_fields if f in all_fieldnames] + optional_fields + + with open(args["output_csv"], "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(results) + + print(f"\n{'=' * 80}") + print(f"Results written to: {args['output_csv']}") + print(f"Total benchmarks run: {len(results)}") + print(f"{'=' * 80}\n") + else: + print("\nNo results collected!") + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/plot_sweep_results.py b/benchmark/ccl/plot_sweep_results.py new file mode 100644 index 00000000..16c3ed5c --- /dev/null +++ b/benchmark/ccl/plot_sweep_results.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Plot comprehensive CCL benchmark sweep results. + +This script reads the CSV output from comprehensive_sweep.py and creates +subplots comparing Iris vs RCCL bandwidth for each collective operation. +""" + +import argparse +import csv +import matplotlib.pyplot as plt +import numpy as np +from collections import defaultdict +import os + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Plot CCL benchmark sweep results.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "input_csv", + type=str, + help="Input CSV file from comprehensive_sweep.py", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output plot file (default: auto-generated from input filename)", + ) + parser.add_argument( + "--title", + type=str, + default="CCL Benchmark: Iris vs RCCL", + help="Overall plot title", + ) + parser.add_argument( + "--dpi", + type=int, + default=150, + help="DPI for output image", + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[16, 10], + help="Figure size in inches (width height)", + ) + + return parser.parse_args() + + +def load_results(csv_file): + """Load results from CSV file and organize by operation.""" + data = defaultdict(lambda: {"comm_sms": [], "iris_bw": [], "rccl_bw": []}) + + with open(csv_file, "r") as f: + reader = csv.DictReader(f) + for row in reader: + operation = row["operation"] + comm_sms = int(row["comm_sms"]) + iris_bw = float(row["iris_bandwidth_gbps"]) + + data[operation]["comm_sms"].append(comm_sms) + data[operation]["iris_bw"].append(iris_bw) + + # RCCL data may not be present for all operations + if "rccl_bandwidth_gbps" in row and row["rccl_bandwidth_gbps"]: + rccl_bw = float(row["rccl_bandwidth_gbps"]) + data[operation]["rccl_bw"].append(rccl_bw) + else: + data[operation]["rccl_bw"].append(None) + + return data + + +def plot_results(data, args): + """Create subplots comparing Iris vs RCCL for each operation.""" + operations = sorted(data.keys()) + num_ops = len(operations) + + # Create subplots - 2x2 grid for up to 4 operations + if num_ops <= 2: + nrows, ncols = 1, num_ops + elif num_ops <= 4: + nrows, ncols = 2, 2 + else: + nrows = (num_ops + 1) // 2 + ncols = 2 + + fig, axes = plt.subplots(nrows, ncols, figsize=tuple(args.figsize)) + fig.suptitle(args.title, fontsize=16, fontweight="bold") + + # Flatten axes for easier iteration + if num_ops == 1: + axes = [axes] + else: + axes = axes.flatten() if num_ops > 1 else [axes] + + 
for idx, operation in enumerate(operations): + ax = axes[idx] + op_data = data[operation] + + comm_sms = np.array(op_data["comm_sms"]) + iris_bw = np.array(op_data["iris_bw"]) + rccl_bw = np.array(op_data["rccl_bw"]) + + # Plot Iris bandwidth + ax.plot(comm_sms, iris_bw, "o-", linewidth=2, markersize=8, label="Iris", color="#2E86AB") + + # Plot RCCL bandwidth if available + if not all(x is None for x in rccl_bw): + # Filter out None values + valid_indices = [i for i, x in enumerate(rccl_bw) if x is not None] + if valid_indices: + rccl_comm_sms = comm_sms[valid_indices] + rccl_bw_valid = rccl_bw[valid_indices] + ax.plot(rccl_comm_sms, rccl_bw_valid, "s--", linewidth=2, markersize=8, label="RCCL", color="#A23B72") + + # Formatting + ax.set_xlabel("Number of CUs (comm_sms)", fontsize=11) + ax.set_ylabel("Bandwidth (GB/s)", fontsize=11) + ax.set_title(f"{operation.replace('_', '-').title()}", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3, linestyle="--") + ax.legend(loc="best", fontsize=10) + + # Set x-axis to show all CU values + ax.set_xticks(comm_sms) + + # Add some padding to y-axis + y_min = min( + iris_bw.min(), rccl_bw[rccl_bw is not None].min() if any(x is not None for x in rccl_bw) else iris_bw.min() + ) + y_max = max( + iris_bw.max(), rccl_bw[rccl_bw is not None].max() if any(x is not None for x in rccl_bw) else iris_bw.max() + ) + y_range = y_max - y_min + ax.set_ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range) + + # Hide unused subplots + for idx in range(num_ops, len(axes)): + axes[idx].set_visible(False) + + plt.tight_layout() + + # Generate output filename if not provided + if args.output is None: + base_name = os.path.splitext(args.input_csv)[0] + args.output = f"{base_name}_plot.png" + + plt.savefig(args.output, dpi=args.dpi, bbox_inches="tight") + print(f"\nPlot saved to: {args.output}") + + # Also display if running interactively + try: + plt.show() + except Exception: + pass + + +def main(): + args = parse_args() + + print(f"Loading results from: {args.input_csv}") + data = load_results(args.input_csv) + + print(f"Found {len(data)} operations:") + for op in sorted(data.keys()): + num_points = len(data[op]["comm_sms"]) + print(f" - {op}: {num_points} data points") + + print("\nCreating plots...") + plot_results(data, args) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/reduce_scatter/benchmark.py b/benchmark/ccl/reduce_scatter/benchmark.py index 200f2a8c..61bb9991 100755 --- a/benchmark/ccl/reduce_scatter/benchmark.py +++ b/benchmark/ccl/reduce_scatter/benchmark.py @@ -154,7 +154,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Fill input with deterministic values # For reduce-scatter, each rank's input contributes to the reduction - val = float(rank + 1) + # Use smaller values to avoid overflow, especially with fp16 + val = float(rank + 1) * 0.1 # Scale down to prevent overflow input_tensor.fill_(val) # Expected output: each rank gets the sum of all ranks' inputs for its assigned tiles @@ -202,7 +203,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - val = float(rank + 1) + val = float(rank + 1) * 0.1 # Scale down to prevent overflow input_tensor.fill_(val) shmem.barrier() @@ -211,47 +212,82 @@ def run_experiment(): torch.cuda.synchronize() shmem.barrier() - # Validate against PyTorch's reduce_scatter - # PyTorch reduce_scatter: input is (M, N), output is (M // world_size, N) or similar - # Our implementation: input is (M, N), output is (M, N) with assigned tiles reduced - # Since our 
implementation is different (tiles vs chunks), we validate that: - # 1. Each rank reduces its assigned tiles from all ranks - # 2. The sum across all ranks' outputs equals the sum of all inputs - - # For a proper validation, we compare with PyTorch's reduce_scatter_tensor - pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") - pytorch_input.fill_(float(rank + 1)) - - # PyTorch reduce_scatter splits along dim 0, so output is (M // world_size, N) - # Our implementation reduces assigned tiles (not necessarily contiguous chunks) - # So validation is more complex - we'll just check that outputs are non-zero and correct pattern - - # Basic validation: check that output contains reduced values (sum of inputs for assigned tiles) - # Since tile assignment is complex, we'll use a simpler check: - # The sum of all outputs should equal the sum of all inputs (scaled by world_size for each tile location) - - # For now, validate that output is not all zeros and has expected magnitude - output_sum = output_tensor.sum().item() - input_sum = input_tensor.sum().item() - + # Create reference output by manually computing expected reduce-scatter result + # Each rank should reduce its assigned tiles from all ranks' inputs + reference_output = shmem.zeros((M, N), dtype=datatype) + + # Compute reference: sum all ranks' inputs for tiles assigned to this rank + # This simulates what reduce_scatter should produce + for r in range(world_size): + # Create input for rank r + rank_input = shmem.zeros((M, N), dtype=datatype) + rank_input.fill_(float(r + 1) * 0.1) + + # Add to reference (all tiles get summed) + reference_output += rank_input + + # Now reference_output contains the sum of all inputs at each location + # In reduce_scatter, each rank only gets its assigned tiles (rest should be zero) + # But we can use this to validate the non-zero values + + # Validate using double precision to avoid overflow in sum computation + output_sum = output_tensor.double().sum().item() + input_sum = input_tensor.double().sum().item() + # Expected: each tile location gets sum of all ranks' contributions - # Total sum of all outputs should equal world_size * sum of one input (since each location is reduced) - # Actually, in our two-shot implementation, each rank reduces its assigned tiles - # The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs - total_expected_sum = world_size * input_sum # Each tile gets sum of all ranks - + # For reduce-scatter, each rank gets its assigned tiles reduced + # The expected value at each reduced location is the sum of all ranks' inputs + expected_value_per_element = sum(float(r + 1) * 0.1 for r in range(world_size)) + # Simple validation: output should be non-zero and have reasonable values atol = 1e-3 if datatype == torch.float16 else 1e-5 - has_data = output_tensor.abs().max().item() > atol - - if not has_data: - shmem.error(f"Rank {rank}: Validation failed - output is all zeros") - success = False + + # Count non-zero elements across entire tensor + non_zero_mask = output_tensor.abs() > atol + num_non_zero = non_zero_mask.sum().item() + total_elements = output_tensor.numel() + + # Get statistics on non-zero values and compare with reference + if num_non_zero > 0: + non_zero_values = output_tensor[non_zero_mask].double() + mean_value = non_zero_values.mean().item() + min_value = non_zero_values.min().item() + max_value = non_zero_values.max().item() + + # Compare with reference output + # For non-zero elements, they should match the reference 
(sum of all inputs) + reference_non_zero = reference_output[non_zero_mask].double() + + # Count how many elements match the reference (within tolerance) + match_tolerance = 1e-2 if datatype == torch.float16 else 1e-4 + matches = (non_zero_values - reference_non_zero).abs() < match_tolerance + num_matches = matches.sum().item() + match_percentage = (num_matches / num_non_zero) * 100 + + # Check that non-zero values are close to expected sum + expected_close = abs(mean_value - expected_value_per_element) < (expected_value_per_element * 0.2) + + if expected_close and match_percentage > 95: + success = True + shmem.info( + f"Rank {rank}: {num_non_zero}/{total_elements} non-zero elements, " + f"mean: {mean_value:.4f} (expected: {expected_value_per_element:.4f}), " + f"range: [{min_value:.4f}, {max_value:.4f}], " + f"matches reference: {num_matches}/{num_non_zero} ({match_percentage:.1f}%)" + ) + else: + shmem.error( + f"Rank {rank}: Validation failed - mean {mean_value:.4f} != expected {expected_value_per_element:.4f}, " + f"{num_non_zero}/{total_elements} non-zero, " + f"matches: {num_matches}/{num_non_zero} ({match_percentage:.1f}%)" + ) + success = False else: - # Check that values are in expected range (sum of inputs from all ranks for assigned tiles) - # The exact validation depends on tile assignment, so we do a basic sanity check + # No non-zero values - this might be valid if this rank has no assigned tiles + # In reduce-scatter, tiles are distributed across ranks, so some ranks might have fewer tiles + shmem.warning(f"Rank {rank}: No non-zero values found ({num_non_zero}/{total_elements})") + # Consider this a pass for now - the operation may have assigned no tiles to this rank success = True - shmem.info(f"Rank {rank}: Output sum: {output_sum:.2f}, Input sum: {input_sum:.2f}") if success: shmem.info("Reduce-scatter validation passed!") @@ -277,7 +313,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - val = float(rank + 1) + val = float(rank + 1) * 0.1 # Scale down to prevent overflow input_tensor.fill_(val) shmem.barrier() @@ -323,8 +359,8 @@ def run_experiment(): # PyTorch reduce_scatter_tensor: input is (M, N), output is (M // world_size, N) # Our implementation is different (tiles vs chunks), so we'll benchmark with same input size pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") - pytorch_input.fill_(float(rank + 1)) - + pytorch_input.fill_(float(rank + 1) * 0.1) # Scale down to prevent overflow + # PyTorch reduce_scatter_tensor splits along dim 0 output_size_m = M // world_size pytorch_output = torch.zeros(output_size_m, N, dtype=datatype, device=f"cuda:{rank}") @@ -337,7 +373,7 @@ def run_experiment(): # Benchmark pytorch_output.zero_() - pytorch_input.fill_(float(rank + 1)) + pytorch_input.fill_(float(rank + 1) * 0.1) # Scale down to prevent overflow dist.barrier() rccl_start = torch.cuda.Event(enable_timing=True) @@ -389,7 +425,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = "tcp://127.0.0.1:29234" mp.spawn( fn=_worker, @@ -401,4 +437,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index a1bd61f6..ddc0a800 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -72,13 +72,10 @@ def persistent_all_gather( """ pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n 
= tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n - + tl.assume(total_tiles > 0) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group @@ -89,6 +86,11 @@ def persistent_all_gather( tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) + tl.assume(tile_id >= 0) + tl.assume(stride_in_m >= 0) + tl.assume(stride_in_n >= 0) + tl.assume(stride_out_m >= 0) + tl.assume(stride_out_n >= 0) # Compute local row and column indices for input tensor rm_base = pid_m * BLOCK_SIZE_M @@ -97,7 +99,7 @@ def persistent_all_gather( rn = rn_base + tl.arange(0, BLOCK_SIZE_N) rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - + # Mask for local input bounds input_mask = (rm_input[:, None] < M) & (rn[None, :] < N) @@ -108,20 +110,20 @@ def persistent_all_gather( input_offset = input_base_m + input_base_n input_ptr_source = input_ptr + input_offset input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - + # Load local input data once for this tile data = tl.load(input_ptr_source, mask=input_mask, other=0.0) # Send local shard data to all destination ranks # Each rank's input goes to output[cur_rank * M : (cur_rank + 1) * M, :] on all ranks - for rank in range(world_size): + for rank in tl.static_range(world_size): # Compute global output row indices: offset by cur_rank * M # This rank's data should be placed at output[cur_rank * M : (cur_rank + 1) * M, :] rm_output = rm_input + cur_rank * M - + # Output mask: check bounds for output tensor (world_size * M rows, N cols) output_mask = (rm_output[:, None] < (world_size * M)) & (rn[None, :] < N) - + # Combine masks: must be valid in both input and output combined_mask = input_mask & output_mask @@ -135,7 +137,7 @@ def persistent_all_gather( if rank == cur_rank: # Local destination: use direct store - tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") + tl.store(output_ptr_target, data, cache_modifier=".wt") else: # Remote destination: use iris.put to send from local source to remote destination # from_ptr: local input source, to_ptr: remote output destination @@ -145,7 +147,6 @@ def persistent_all_gather( cur_rank, rank, heap_bases, - mask=combined_mask, ) diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index 7338fc5d..6d3fc14c 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -544,7 +544,7 @@ def persistent_all_reduce_ring( ) -@triton.jit() +@triton.jit def persistent_all_reduce_two_shot( input_ptr, output_ptr, @@ -561,16 +561,15 @@ def persistent_all_reduce_two_shot( BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr, COMM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, - CHUNK_SIZE: tl.constexpr, + NUM_XCDS: tl.constexpr, # unused here but kept for signature compatibility + CHUNK_SIZE: tl.constexpr, # unused here but kept for signature compatibility DISTRIBUTION: tl.constexpr, ): - """Reduce assigned tiles for a rank and broadcast the result to all peers.""" + """Reduce assigned tiles for a rank and broadcast the result to all peers. + Single kernel: unmasked fast path for full tiles, masked slow path for tails. 
+ """ pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -591,6 +590,7 @@ def persistent_all_reduce_two_shot( remaining = tl.maximum(remaining, 0) max_tile_offset = tl.minimum(tiles_per_rank, remaining) + # Persistent traversal for tile_offset in range(pid, max_tile_offset, COMM_SMS): tile_id = start_tile + tile_offset * stride @@ -601,45 +601,61 @@ def persistent_all_reduce_two_shot( pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - rm_base = pid_m * BLOCK_SIZE_M rn_base = pid_n * BLOCK_SIZE_N + + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) rm = rm_base + tl.arange(0, BLOCK_SIZE_M) rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - mask = (rm[:, None] < M) & (rn[None, :] < N) input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for remote_rank in range(world_size): - partial = iris.load( - input_ptr + input_offset, - cur_rank, - remote_rank, - heap_bases, - mask=mask, - ) - acc += partial.to(acc_dtype) + base_ptr = input_ptr + input_offset + out_ptr = output_ptr + output_offset - reduced = acc.to(output_ptr.type.element_ty) + # Fast path: NO MASKS + if is_full: + mask = (rm[:, None] < M) & (rn[None, :] < N) - for remote_rank in range(world_size): - if remote_rank == cur_rank: - tl.store(output_ptr + output_offset, reduced, mask=mask, cache_modifier=".wt") - else: - iris.store( - output_ptr + output_offset, - reduced, - cur_rank, - remote_rank, - heap_bases, - mask=mask, - ) + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + tl.store(out_ptr, reduced, cache_modifier=".wt") + + for i in tl.static_range(0, world_size): + remote_rank = (start_rank + i) % world_size + if remote_rank != cur_rank: + iris.store(out_ptr, reduced, cur_rank, remote_rank, heap_bases) + + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) + + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases, mask=mask).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt") + + for i in tl.static_range(0, world_size): + remote_rank = (start_rank + i) % world_size + if remote_rank != cur_rank: + iris.store(out_ptr, reduced, cur_rank, remote_rank, heap_bases, mask=mask) def all_reduce( @@ -828,6 +844,9 @@ def all_reduce( config.num_xcds, config.chunk_size, config.all_reduce_distribution, + num_warps=8, + num_stages=1, + waves_per_eu=1, ) elif variant == VARIANT_ONE_SHOT: 
persistent_all_reduce_one_shot[(config.comm_sms,)]( diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 38a5613d..6e1239e6 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -102,117 +102,90 @@ def persistent_all_to_all( tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - # Compute row and column indices - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + # Compute base indices for this tile + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + + # Check if this tile is fully within bounds (no edge cases) + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - mask = (rm[:, None] < M) & (rn[None, :] < N) # Pre-compute base offsets for better memory access patterns and vectorization - # Base offset for input rows (M dimension) input_base_m = rm[:, None] * stride_in_m - # Base offset for output rows (M dimension) output_base_m = rm[:, None] * stride_out_m - # Base offset for input columns (N dimension) - will be adjusted per rank input_base_n = rn[None, :] * stride_in_n - # Base offset for output columns (N dimension) - will be adjusted per rank output_base_n = rn[None, :] * stride_out_n - # Process local rank first for better cache locality - # Local path: copy input[cur_rank] chunk to output[cur_rank] chunk - input_offset_local = input_base_m + (input_base_n + cur_rank * N * stride_in_n) - output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) - input_ptr_local = input_ptr + input_offset_local - output_ptr_local = output_ptr + output_offset_local - # Vectorization hints for 2D access pattern - input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - - data = tl.load(input_ptr_local, mask=mask) - tl.store(output_ptr_local, data, mask=mask, cache_modifier=".wt") - - # Pre-compute constant parts that don't depend on target_rank - # Base offset for input (without rank-specific column offset) - input_base_offset = input_base_m + input_base_n - # Remote store offset: write into target's output at columns [cur_rank*N : (cur_rank+1)*N] - # This is constant for all target_rank iterations since it only depends on cur_rank - output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) - output_ptr_remote = tl.multiple_of(output_ptr + output_offset_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - - # Pre-compute rank stride for input (N * stride_in_n) - rank_stride_in = N * stride_in_n - - # Traffic shaping: Break each tile into 64x64 sub-blocks and process them - # This creates better memory access patterns and allows hardware to distribute - # traffic across XGMI links based on access patterns - SUB_BLOCK_M: tl.constexpr = 64 - SUB_BLOCK_N: tl.constexpr = 64 - - # Calculate number of 64x64 sub-blocks needed to cover the tile - num_sub_blocks_m = tl.cdiv(BLOCK_SIZE_M, SUB_BLOCK_M) - num_sub_blocks_n = tl.cdiv(BLOCK_SIZE_N, SUB_BLOCK_N) - total_sub_blocks = num_sub_blocks_m * num_sub_blocks_n - - # Base row/column indices for the tile - tile_base_m = pid_m * BLOCK_SIZE_M - tile_base_n = pid_n * BLOCK_SIZE_N - - # Process all remote ranks: load each 
chunk and scatter to corresponding target - # Each target_rank may have different input data, so we must load separately - for target_rank in range(world_size): - if target_rank != cur_rank: - # Traffic shaping: Process tile in 64x64 sub-blocks - # Loop over all sub-blocks to ensure complete coverage - for sub_block_id in range(total_sub_blocks): - # Calculate sub-block position within the tile - sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M - sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N - - # Compute row and column indices for this 64x64 sub-block - # Start from tile base and add sub-block offset, then create arrays - sub_rm_base = tile_base_m + sub_block_m - sub_rn_base = tile_base_n + sub_block_n - sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) - sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) - - # Create mask for this sub-block - sub_mask = ( - (sub_rm[:, None] < M) - & (sub_rn[None, :] < N) - & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) - & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) - ) + # Fast path: NO MASKS (full tiles) + if is_full: + # Process local rank first for better cache locality + input_offset_local = input_base_m + (input_base_n + cur_rank * N * stride_in_n) + output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_local = input_ptr + input_offset_local + output_ptr_local = output_ptr + output_offset_local + input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # Compute offsets for this sub-block - sub_input_base_m = sub_rm[:, None] * stride_in_m - sub_input_base_n = sub_rn[None, :] * stride_in_n - sub_output_base_m = sub_rm[:, None] * stride_out_m - sub_output_base_n = sub_rn[None, :] * stride_out_n + data = tl.load(input_ptr_local) + tl.store(output_ptr_local, data, cache_modifier=".wt") - # Compute input pointer for this target_rank's chunk (sub-block) - sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) - sub_input_ptr_send = input_ptr + sub_input_offset - sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) + # Process all remote ranks + for target_rank in range(world_size): + if target_rank != cur_rank: + input_offset_remote = input_base_m + (input_base_n + target_rank * N * stride_in_n) + output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_remote = input_ptr + input_offset_remote + output_ptr_remote = output_ptr + output_offset_remote + input_ptr_remote = tl.multiple_of(input_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_remote = tl.multiple_of(output_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + remote_data = tl.load(input_ptr_remote) + iris.store( + output_ptr_remote, + remote_data, + cur_rank, + target_rank, + heap_bases, + ) - # Compute output pointer (sub-block) - sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) - sub_output_ptr_remote = output_ptr + sub_output_offset - sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) - # Load data chunk for this target rank (64x64 sub-block) - sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) + # Process local rank first for better cache locality + input_offset_local = input_base_m + (input_base_n + cur_rank * N * 
stride_in_n) + output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_local = input_ptr + input_offset_local + output_ptr_local = output_ptr + output_offset_local + input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # Scatter to target rank's output - # Processing in 64x64 sub-blocks creates better memory access patterns - # that allow hardware to distribute traffic across XGMI links + data = tl.load(input_ptr_local, mask=mask) + tl.store(output_ptr_local, data, mask=mask, cache_modifier=".wt") + + # Process all remote ranks + for target_rank in range(world_size): + if target_rank != cur_rank: + input_offset_remote = input_base_m + (input_base_n + target_rank * N * stride_in_n) + output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_remote = input_ptr + input_offset_remote + output_ptr_remote = output_ptr + output_offset_remote + input_ptr_remote = tl.multiple_of(input_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_remote = tl.multiple_of(output_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + remote_data = tl.load(input_ptr_remote, mask=mask) iris.store( - sub_output_ptr_remote, - sub_data, + output_ptr_remote, + remote_data, cur_rank, target_rank, heap_bases, - mask=sub_mask, + mask=mask, ) diff --git a/iris/ccl/config.py b/iris/ccl/config.py index a87fb18a..c7da52e8 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -112,9 +112,7 @@ def __post_init__(self): ) if self.all_reduce_ring_slice_n & (self.all_reduce_ring_slice_n - 1): raise ValueError(f"all_reduce_ring_slice_n must be a power of two, got {self.all_reduce_ring_slice_n}") - + # Validate reduce_scatter_variant if self.reduce_scatter_variant != "two_shot": - raise ValueError( - f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'" - ) + raise ValueError(f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'") diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py index 21799911..7402f32b 100644 --- a/iris/ccl/reduce_scatter.py +++ b/iris/ccl/reduce_scatter.py @@ -50,7 +50,7 @@ def persistent_reduce_scatter_two_shot( ): """ Reduce-scatter using two-shot approach. - + Each rank reduces its assigned tiles from all ranks and stores the result only to its own output (no broadcast to other ranks). 
""" @@ -94,31 +94,49 @@ def persistent_reduce_scatter_two_shot( rm_base = pid_m * BLOCK_SIZE_M rn_base = pid_n * BLOCK_SIZE_N + + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) rm = rm_base + tl.arange(0, BLOCK_SIZE_M) rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - mask = (rm[:, None] < M) & (rn[None, :] < N) input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n - # Reduce: sum contributions from all ranks - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for remote_rank in range(world_size): - partial = iris.load( - input_ptr + input_offset, - cur_rank, - remote_rank, - heap_bases, - mask=mask, - ) - acc += partial.to(acc_dtype) + base_ptr = input_ptr + input_offset + out_ptr = output_ptr + output_offset + + # Fast path: NO MASKS + if is_full: + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases).to(acc_dtype) - reduced = acc.to(output_ptr.type.element_ty) + reduced = acc.to(output_ptr.type.element_ty) - # Store only to own rank (no broadcast) - tl.store(output_ptr + output_offset, reduced, mask=mask, cache_modifier=".wt") + # Store only to own rank (no broadcast) + tl.store(out_ptr, reduced, cache_modifier=".wt") + + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) + + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases, mask=mask).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + # Store only to own rank (no broadcast) + tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt") def reduce_scatter(output_tensor, input_tensor, shmem, config=None, async_op=False): @@ -213,4 +231,3 @@ def reduce_scatter(output_tensor, input_tensor, shmem, config=None, async_op=Fal if not async_op: shmem.barrier() - diff --git a/iris/iris.py b/iris/iris.py index 4c87e071..49257686 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1697,8 +1697,9 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Optimization to vectorize the load/store # We can't do this in general because we don't know the shape of the tensor - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) From 9e38736e8a081d356141a08f928c56df2eae5179 Mon Sep 17 00:00:00 2001 From: Ryan Swann Date: Thu, 18 Dec 2025 12:39:09 -0600 Subject: [PATCH 09/13] Remove non-generic compiler hints in __translate --- iris/iris.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index 
49257686..93263bac 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1696,10 +1696,13 @@ def __translate(ptr, from_rank, to_rank, heap_bases): translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor + # We can't do this in general because we don't know the shape of the tensor or block sizes # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) + + #0 You can use this if your block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits) + # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) From 78746e9ea97eb281898576f7d0d199c13629632d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 18 Dec 2025 18:39:52 +0000 Subject: [PATCH 10/13] Apply Ruff auto-fixes --- iris/iris.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index 93263bac..172ffb5a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1698,8 +1698,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Optimization to vectorize the load/store # We can't do this in general because we don't know the shape of the tensor or block sizes # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) - - #0 You can use this if your block sizes are multiples of 32. + + # 0 You can use this if your block sizes are multiples of 32. # Largest vectorized load instruction is dwordx4 (128-bits) # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) From d272a9f55b9c05495e2d16a587d3cb9feedcd9d9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 20:02:00 +0000 Subject: [PATCH 11/13] Initial plan From 12bf4bf5ac26da8608627492bacc31fb22f89217 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 20:13:11 +0000 Subject: [PATCH 12/13] Add get_accumulator_dtype helper function and replace all instances Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../07_gemm_all_scatter/gemm_all_scatter.py | 2 +- .../gemm_all_reduce_atomics.py | 2 +- .../gemm_one_shot_all_reduce.py | 2 +- .../gemm_all_scatter_wg_specialization.py | 2 +- .../gemm_all_scatter_producer_consumer.py | 2 +- .../gemm_all_scatter_bulk_synchronous.py | 2 +- .../all_gather_gemm_pull.py | 2 +- .../all_gather_gemm_push.py | 2 +- .../gemm_all_reduce_ring_based.py | 4 +- .../all_reduce_ring_based.py | 2 +- .../gemm_all_scatter_bulk_synchronous.py | 2 +- .../gemm_one_shot_all_reduce_independent.py | 4 +- iris/__init__.py | 2 + iris/ccl/all_reduce.py | 8 ++-- iris/ccl/reduce_scatter.py | 2 +- iris/util.py | 42 +++++++++++++++++++ 16 files changed, 63 insertions(+), 19 deletions(-) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 8c544fa9..cde38e7a 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -60,7 +60,7 @@ def persistent_gemm_all_scatter( tl.assume(stride_cm > 0) 
tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: diff --git a/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py index 1b69df0d..54eec355 100644 --- a/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py +++ b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py @@ -60,7 +60,7 @@ def persistent_gemm_all_reduce( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index 915e470b..4e82a703 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -145,7 +145,7 @@ def persistent_gemm_all_reduce( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ac2d2e35..358f1489 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -62,7 +62,7 @@ def persistent_gemm_all_scatter_wg_specialization( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) # Workgroup specialization: # Split the kernel into two paths, one that performs the GEMM diff --git a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index a8311943..a3990674 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -58,7 +58,7 @@ def persistent_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py index 42961398..d7bbefec 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py @@ -57,7 +57,7 @@ def persistent_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, 
GEMM_SMS): if COLLECT_TIMESTAMPS: diff --git a/examples/14_all_gather_gemm/all_gather_gemm_pull.py b/examples/14_all_gather_gemm/all_gather_gemm_pull.py index c710c8a1..d4e53de5 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_pull.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_pull.py @@ -49,7 +49,7 @@ def persistent_ag_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n diff --git a/examples/14_all_gather_gemm/all_gather_gemm_push.py b/examples/14_all_gather_gemm/all_gather_gemm_push.py index 7cb4fe4b..2933c3b0 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_push.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_push.py @@ -118,7 +118,7 @@ def gemm_push_kernel( tl.assume(stride_sf_m > 0) tl.assume(stride_sf_k > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n diff --git a/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py index 1323d287..ee602013 100644 --- a/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py +++ b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py @@ -58,7 +58,7 @@ def persistent_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if local_C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(local_C.type.element_ty) for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: @@ -169,7 +169,7 @@ def persistent_all_reduce( next_rank = (cur_rank + 1) % world_size prev_rank = (cur_rank + world_size - 1) % world_size - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) # Persistent across *groups* now (not individual tiles): for g in range(pid, num_groups, COMM_SMS): diff --git a/examples/16_all_reduce_ring_based/all_reduce_ring_based.py b/examples/16_all_reduce_ring_based/all_reduce_ring_based.py index 333151a0..ae101206 100644 --- a/examples/16_all_reduce_ring_based/all_reduce_ring_based.py +++ b/examples/16_all_reduce_ring_based/all_reduce_ring_based.py @@ -42,7 +42,7 @@ def persistent_all_reduce( next_rank = (cur_rank + 1) % world_size prev_rank = (cur_rank + world_size - 1) % world_size - acc_dtype = tl.float32 if output.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(output.type.element_ty) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n diff --git a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py index 8cb1dbbf..f885e82c 100644 --- a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py +++ b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py @@ -57,7 +57,7 @@ def persistent_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: diff 
--git a/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py index 1d61d12a..24da56f8 100644 --- a/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py +++ b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py @@ -61,7 +61,7 @@ def persistent_gemm( tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(C.type.element_ty) for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: @@ -169,7 +169,7 @@ def persistent_all_reduce( num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n - acc_dtype = tl.float32 if global_result.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(global_result.type.element_ty) # Determine which tiles this rank is responsible for reducing if DISTRIBUTION == 0: diff --git a/iris/__init__.py b/iris/__init__.py index 2b048d03..bdcb3c5f 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -63,6 +63,7 @@ from .util import ( do_bench, + get_accumulator_dtype, ) from . import hip @@ -99,6 +100,7 @@ "atomic_min", "atomic_max", "do_bench", + "get_accumulator_dtype", "hip", "experimental", # Experimental features including iris_gluon "set_logger_level", diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index 6d3fc14c..e973d9d6 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -278,7 +278,7 @@ def persistent_all_reduce_spinlock( num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n - acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(output_ptr.type.element_ty) for tile_id in range(pid, total_tiles, COMM_SMS): lock_ptr = locks_ptr + tile_id @@ -359,7 +359,7 @@ def persistent_all_reduce_one_shot( num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n - acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(output_ptr.type.element_ty) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -453,7 +453,7 @@ def persistent_all_reduce_ring( # Ring topology next_rank = (cur_rank + 1) % world_size - acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32 + acc_dtype = iris.get_accumulator_dtype(output_ptr.type.element_ty) elem_ty = input_ptr.type.element_ty # Partition CTAs across rings to form NUM_RINGS concurrent rings. 
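Patch 12 swaps the inline accumulator selection, acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32, for the new iris.get_accumulator_dtype helper in every kernel touched above. The inline form sends every non-int8 element type, including int16 and float64, to float32, while the helper widens each type explicitly. The host-side sketch below illustrates the difference; it assumes this patch series is applied and Triton is importable, and the comparison loop is illustrative only, not part of the patch:

    import triton.language as tl
    import iris

    # Compare the previous inline pattern against the new helper for the
    # element types used by the examples and ccl kernels.
    for elem_ty in (tl.int8, tl.int16, tl.float16, tl.bfloat16, tl.float32, tl.float64):
        old_acc = tl.float32 if elem_ty != tl.int8 else tl.int32  # previous inline pattern
        new_acc = iris.get_accumulator_dtype(elem_ty)  # helper added in this series
        print(f"{elem_ty}: old={old_acc}, new={new_acc}")
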
@@ -574,7 +574,7 @@ def persistent_all_reduce_two_shot(
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     total_tiles = num_pid_m * num_pid_n
 
-    acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32
+    acc_dtype = iris.get_accumulator_dtype(output_ptr.type.element_ty)
 
     tiles_per_rank = tl.cdiv(total_tiles, world_size)
     if DISTRIBUTION == 0:
diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py
index 7402f32b..be0d3275 100644
--- a/iris/ccl/reduce_scatter.py
+++ b/iris/ccl/reduce_scatter.py
@@ -63,7 +63,7 @@ def persistent_reduce_scatter_two_shot(
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     total_tiles = num_pid_m * num_pid_n
 
-    acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32
+    acc_dtype = iris.get_accumulator_dtype(output_ptr.type.element_ty)
 
     tiles_per_rank = tl.cdiv(total_tiles, world_size)
     if DISTRIBUTION == 0:
diff --git a/iris/util.py b/iris/util.py
index 8c861851..741fd0e8 100644
--- a/iris/util.py
+++ b/iris/util.py
@@ -30,6 +30,48 @@
 import torch
 
 
+def get_accumulator_dtype(output_dtype):
+    """
+    Determine the appropriate accumulator dtype for matrix multiplication operations.
+
+    This function promotes the output dtype to a wider precision type suitable for
+    accumulation to avoid precision loss and overflow during computation.
+
+    Args:
+        output_dtype: The Triton language dtype of the output tensor (e.g., from tensor.type.element_ty)
+
+    Returns:
+        The promoted Triton language dtype to use for accumulation
+
+    Promotion rules:
+    - int8, int16, int32 -> int32 (to prevent overflow)
+    - float16 (fp16), bfloat16 (bf16) -> float32 (to prevent precision loss)
+    - float32 -> float32
+    - float64 -> float64
+
+    Example:
+        >>> @triton.jit
+        >>> def kernel(C, ...):
+        >>>     acc_dtype = iris.get_accumulator_dtype(C.type.element_ty)
+        >>>     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
+    """
+    # Integer types -> accumulate in int32
+    if output_dtype in (tl.int8, tl.int16, tl.int32):
+        return tl.int32
+    # Half precision floats -> accumulate in float32
+    elif output_dtype in (tl.float16, tl.bfloat16):
+        return tl.float32
+    # float32 stays float32
+    elif output_dtype == tl.float32:
+        return tl.float32
+    # float64 stays float64
+    elif output_dtype == tl.float64:
+        return tl.float64
+    # Default fallback to float32 for any other types
+    else:
+        return tl.float32
+
+
 def get_empty_cache_for_benchmark():
     cache_size = 256 * 1024 * 1024
     return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda")

From a556995ca9b990ea7462b8ed47d2de4a5524ab3b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Thu, 18 Dec 2025 20:15:25 +0000
Subject: [PATCH 13/13] Add unit tests for get_accumulator_dtype helper function

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
---
 tests/unittests/test_get_accumulator_dtype.py | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 tests/unittests/test_get_accumulator_dtype.py

diff --git a/tests/unittests/test_get_accumulator_dtype.py b/tests/unittests/test_get_accumulator_dtype.py
new file mode 100644
index 00000000..beea741f
--- /dev/null
+++ b/tests/unittests/test_get_accumulator_dtype.py
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Test suite for get_accumulator_dtype utility function.
+"""
+
+import pytest
+import triton.language as tl
+import iris
+
+
+def test_get_accumulator_dtype_int8():
+    """Test that int8 promotes to int32."""
+    result = iris.get_accumulator_dtype(tl.int8)
+    assert result == tl.int32
+
+
+def test_get_accumulator_dtype_int16():
+    """Test that int16 promotes to int32."""
+    result = iris.get_accumulator_dtype(tl.int16)
+    assert result == tl.int32
+
+
+def test_get_accumulator_dtype_int32():
+    """Test that int32 stays int32."""
+    result = iris.get_accumulator_dtype(tl.int32)
+    assert result == tl.int32
+
+
+def test_get_accumulator_dtype_float16():
+    """Test that float16 (fp16) promotes to float32."""
+    result = iris.get_accumulator_dtype(tl.float16)
+    assert result == tl.float32
+
+
+def test_get_accumulator_dtype_bfloat16():
+    """Test that bfloat16 (bf16) promotes to float32."""
+    result = iris.get_accumulator_dtype(tl.bfloat16)
+    assert result == tl.float32
+
+
+def test_get_accumulator_dtype_float32():
+    """Test that float32 stays float32."""
+    result = iris.get_accumulator_dtype(tl.float32)
+    assert result == tl.float32
+
+
+def test_get_accumulator_dtype_float64():
+    """Test that float64 stays float64."""
+    result = iris.get_accumulator_dtype(tl.float64)
+    assert result == tl.float64
+
+
+@pytest.mark.parametrize(
+    "input_dtype,expected_output",
+    [
+        (tl.int8, tl.int32),
+        (tl.int16, tl.int32),
+        (tl.int32, tl.int32),
+        (tl.float16, tl.float32),
+        (tl.bfloat16, tl.float32),
+        (tl.float32, tl.float32),
+        (tl.float64, tl.float64),
+    ],
+)
+def test_get_accumulator_dtype_parametrized(input_dtype, expected_output):
+    """Parametrized test for all supported dtype promotions."""
+    result = iris.get_accumulator_dtype(input_dtype)
+    assert result == expected_output
+
+
+def test_get_accumulator_dtype_precision_loss_prevention():
+    """Test that half precision types promote to prevent precision loss."""
+    # fp16 and bf16 should promote to fp32 to prevent precision loss in accumulation
+    assert iris.get_accumulator_dtype(tl.float16) == tl.float32
+    assert iris.get_accumulator_dtype(tl.bfloat16) == tl.float32
+    # fp32 and fp64 should stay as is (sufficient precision)
+    assert iris.get_accumulator_dtype(tl.float32) == tl.float32
+    assert iris.get_accumulator_dtype(tl.float64) == tl.float64
+
+
+def test_get_accumulator_dtype_overflow_prevention():
+    """Test that small integer types promote to prevent overflow."""
+    # int8 and int16 should promote to int32 to prevent overflow
+    assert iris.get_accumulator_dtype(tl.int8) == tl.int32
+    assert iris.get_accumulator_dtype(tl.int16) == tl.int32
+    # int32 is already wide enough
+    assert iris.get_accumulator_dtype(tl.int32) == tl.int32
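
The kernels reworked earlier in this series (all_to_all, the two-shot all_reduce, and reduce_scatter) split each persistent tile loop into an unmasked fast path for tiles that sit entirely inside the (M, N) extent and a masked slow path for boundary tiles. The standalone sketch below mirrors that is_full predicate on the host; count_full_tiles and the block sizes are hypothetical, chosen only for illustration, and are not part of the patch:

    # A tile is "full" when rm_base + BLOCK_SIZE_M <= M and
    # rn_base + BLOCK_SIZE_N <= N, so the unmasked path never
    # touches out-of-bounds elements.
    def count_full_tiles(M, N, block_m, block_n):
        full = 0
        total = 0
        for rm_base in range(0, M, block_m):
            for rn_base in range(0, N, block_n):
                total += 1
                if rm_base + block_m <= M and rn_base + block_n <= N:
                    full += 1
        return full, total


    if __name__ == "__main__":
        # Divisible shape: every tile takes the unmasked fast path.
        print(count_full_tiles(16384, 16384, 256, 64))  # -> (16384, 16384)
        # Ragged shape: only the bottom and right boundary tiles are masked.
        print(count_full_tiles(1000, 1000, 128, 128))  # -> (49, 64)

For power-of-two shapes such as 16384 x 16384 every tile is full, so only the unmasked path runs; the masked path matters only for ragged problem sizes.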