From 01e7f59348aa58dcc5344589a41329602d4eb330 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 11 Sep 2024 17:32:47 +0800 Subject: [PATCH 1/8] all_gather only internode, fix pytest --- colossalai/quantization/fp8.py | 15 ++++++++++++--- colossalai/shardformer/layer/_operation.py | 4 ++-- tests/test_fp8/test_fp8_all_to_all.py | 4 ++-- ...test_fp8_gather.py => test_fp8_allgather.py} | 17 +++++------------ 4 files changed, 21 insertions(+), 19 deletions(-) rename tests/test_fp8/{test_fp8_gather.py => test_fp8_allgather.py} (81%) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 388bbde052d2..f287a9489f15 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -718,8 +718,8 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op) -def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: - +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group) input_type = input_.dtype @@ -743,8 +743,17 @@ def cast_op(): cast_op() -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: + if process_group_is_intranode(group): + return dist.all_gather(output_list, input_, group=group, async_op=async_op) + else: + return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op) + + +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def all_gather_fp8_lagacy( + output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False +) -> Optional[Handle]: world_size = dist.get_world_size(group) shape = input_.shape input_type = input_.dtype diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f970d8ccc85d..aec82356747a 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -17,10 +17,10 @@ _grad_accum_fusion_available = False from colossalai.quantization.fp8 import ( + all_gather_fp8, all_reduce_fp8, all_to_all_fp8, all_to_all_single_fp8, - gather_fp8, reduce_scatter_fp8, ) @@ -961,7 +961,7 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] if fp8_communication: - gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) + all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) else: dist.all_gather(tensor_list, input_, group=process_group) diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py index 884aab744ba0..98bbbad8550d 100644 --- a/tests/test_fp8/test_fp8_all_to_all.py +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -5,7 +5,7 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import all_to_all_fp8 +from colossalai.quantization.fp8 import _all_to_all_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format): input_tensor_list = [x.contiguous() for x 
in input_tensor_list] output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list] output_tensor_list = [torch.empty_like(x) for x in input_tensor_list] - all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) + _all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group()) assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1) diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_allgather.py similarity index 81% rename from tests/test_fp8/test_fp8_gather.py rename to tests/test_fp8/test_fp8_allgather.py index 40c2ccb9a17b..91e66e83c67b 100644 --- a/tests/test_fp8/test_fp8_gather.py +++ b/tests/test_fp8/test_fp8_allgather.py @@ -5,22 +5,13 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import gather_fp8 +from colossalai.quantization.fp8 import _all_gather_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @parameterize( "shape", - [ - (3, 7), - (2, 1), - (1, 2), - (2, 2), - (4, 2), - (5,), - (4,), - (2,), - ], + [(3, 7, 16)], ) @parameterize("dtype", [torch.bfloat16, torch.float16]) @parameterize("fp8_format", ["e4m3", "e5m2"]) @@ -30,7 +21,9 @@ def check_4gpu(shape, dtype, fp8_format, async_op): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) output_list = [torch.empty_like(x) for _ in range(world_size)] output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)] - fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op) + fp8_handle = _all_gather_fp8( + output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op + ) origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op) if async_op: fp8_handle.wait() From a6b0ecd287f230ca72673a8c809f04e54a90a474 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 11 Sep 2024 11:15:47 +0000 Subject: [PATCH 2/8] fix cuda arch <89 compile pytest error --- colossalai/quantization/fp8.py | 16 ++++++++++------ tests/test_fp8/test_fp8_all_to_all.py | 4 ++-- tests/test_fp8/test_fp8_allgather.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index f287a9489f15..57d473c1e304 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -11,6 +11,8 @@ SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0") SCALE_BYTES = 4 +cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability())) + class Handle: def __init__(self, handles=[], remain_ops=None) -> None: @@ -185,7 +187,7 @@ def all_reduce_fp8( return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def _all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> Optional[Handle]: @@ -678,7 +680,7 @@ def cast_op(): cast_op() -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def _all_to_all_fp8(output_list, input_list, 
group=None, fp8_format="e5m2", async_op=False): world_size = dist.get_world_size(group) input_type = input_list[0].dtype @@ -718,7 +720,7 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op) -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group) @@ -750,7 +752,7 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op) -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def all_gather_fp8_lagacy( output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False ) -> Optional[Handle]: @@ -778,7 +780,7 @@ def all_gather_fp8_lagacy( # out.copy_(output[i].view(shape)) -@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group) rank = dist.get_rank(group) @@ -900,7 +902,9 @@ def backward(ctx: Any, out_grad) -> Any: return x_grad.reshape(ctx.x_shape), w_grad, bias_grad -@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False) +@torch.compile( + mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False, disable=cuda_arch < 89 +) def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _LinearFp8.apply(input, weight, bias) diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py index 98bbbad8550d..884aab744ba0 100644 --- a/tests/test_fp8/test_fp8_all_to_all.py +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -5,7 +5,7 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import _all_to_all_fp8 +from colossalai.quantization.fp8 import all_to_all_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format): input_tensor_list = [x.contiguous() for x in input_tensor_list] output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list] output_tensor_list = [torch.empty_like(x) for x in input_tensor_list] - _all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) + all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group()) assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1) diff --git a/tests/test_fp8/test_fp8_allgather.py b/tests/test_fp8/test_fp8_allgather.py index 91e66e83c67b..50957045c318 100644 --- a/tests/test_fp8/test_fp8_allgather.py +++ b/tests/test_fp8/test_fp8_allgather.py @@ -5,7 +5,7 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from 
colossalai.quantization.fp8 import _all_gather_fp8 +from colossalai.quantization.fp8 import all_gather_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -21,7 +21,7 @@ def check_4gpu(shape, dtype, fp8_format, async_op): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) output_list = [torch.empty_like(x) for _ in range(world_size)] output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)] - fp8_handle = _all_gather_fp8( + fp8_handle = all_gather_fp8( output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op ) origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op) From dd809571b7b5568c2f75986ac0ecc8e194c92f89 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 11 Sep 2024 11:24:43 +0000 Subject: [PATCH 3/8] fix pytest failure --- colossalai/quantization/fp8.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 57d473c1e304..7e097360f5f0 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -902,9 +902,7 @@ def backward(ctx: Any, out_grad) -> Any: return x_grad.reshape(ctx.x_shape), w_grad, bias_grad -@torch.compile( - mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False, disable=cuda_arch < 89 -) +@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False) def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _LinearFp8.apply(input, weight, bias) From 092e0798028a3046a89d509654ba6690885f0f8c Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 13 Sep 2024 05:53:12 +0000 Subject: [PATCH 4/8] disable all_gather_into_tensor_flat_fp8 --- colossalai/quantization/fp8.py | 72 ------------------- colossalai/zero/gemini/chunk/chunk.py | 9 +-- .../low_level/bookkeeping/tensor_bucket.py | 10 +-- colossalai/zero/low_level/low_level_optim.py | 7 +- tests/test_fp8/test_fp8_allgather_flat.py | 43 ----------- 5 files changed, 11 insertions(+), 130 deletions(-) delete mode 100644 tests/test_fp8/test_fp8_allgather_flat.py diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 7e097360f5f0..3d5096c67c6c 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -608,78 +608,6 @@ def split_chunk_by_channel( return chunk.split(sizes) -def all_gather_into_tensor_flat_fp8( - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - output_shape: torch.Size, - group: dist.ProcessGroup, - fp8_format: str = "e4m3", - async_op: bool = False, -) -> Optional[Handle]: - """all gather into tensor in fp8 format - - Args: - output_tensor (torch.Tensor): output tensor, which is flattened - input_tensor (torch.Tensor): input tensor, which is flattened - group (dist.ProcessGroup): process group - fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3". 
- """ - assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened" - world_size = dist.get_world_size(group) - assert ( - output_tensor.numel() == input_tensor.numel() * world_size - ), "output tensor size should be world_size times of input tensor size" - - input_type = output_tensor.dtype - - fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 - fp8_max = torch.finfo(fp8_type).max - - if len(output_shape) == 2: - per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float) - num_channels, channel_size = output_shape - rank = dist.get_rank(group) - channel_start_idx = (input_tensor.numel() * rank) // channel_size - per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size) - for i, per_channel_split in enumerate(per_channel_splits): - idx = i + channel_start_idx - if idx < num_channels: - per_channel_max[idx] = per_channel_split.abs().max().float() - dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group) - per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) - scale = fp8_max / per_channel_max - fp8_input = input_tensor.float() - fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size) - for i, per_channel_split in enumerate(fp8_per_channel_splits): - idx = i + channel_start_idx - if idx < num_channels: - per_channel_split.mul_(scale[idx]) - fp8_input = fp8_input.to(fp8_type) - else: - per_tensor_max = input_tensor.abs().max().float() - dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group) - per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) - scale = fp8_max / per_tensor_max - fp8_input = (scale * input_tensor.float()).to(fp8_type) - scale_inv = 1.0 / scale - - buffer = torch.empty_like(output_tensor, dtype=fp8_type) - tensor_handle = dist.all_gather_into_tensor( - buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op - ) - - def cast_op(): - numel = output_shape.numel() - valid_buffer = buffer[:numel].reshape(output_shape) - valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2)) - output_tensor[:numel].copy_(valid_buffer.view(-1)) - - if async_op: - return Handle([tensor_handle], cast_op) - else: - cast_op() - - @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89) def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): world_size = dist.get_world_size(group) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index e2b7a8f56432..95ff653e605d 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -7,6 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_gather_fp8 class TensorState(Enum): @@ -523,12 +524,8 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]: alloc_storage(self.cuda_global_chunk) assert self.cuda_global_chunk.is_contiguous() if self.fp8_communication: - assert async_op == False, "fp8 all-gather does not support async_op!" 
- from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 - - work = all_gather_into_tensor_flat_fp8( - self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg - ) + cuda_global_chunk_list = torch.chunk(self.cuda_global_chunk, chunks=self.pg_size) + work = all_gather_fp8(cuda_global_chunk_list, self.cuda_shard, self.torch_pg, async_op=async_op) else: work = dist.all_gather_into_tensor( self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index d5fd2fe51662..41f9ec2cac77 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -4,7 +4,7 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 +from colossalai.quantization.fp8 import all_gather_fp8 class TensorBucket: @@ -65,12 +65,12 @@ def unflatten_and_copy(self, flat_tensor): def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() - buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) + buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))] if fp8_communication: - all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group) + all_gather_fp8(buffers, flat, group=group) else: - dist.all_gather_into_tensor(buffer, flat, group=group) - unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] + dist.all_gather(buffers, flat, group=group) + unflat_buffers = [self.unflatten(buffer) for buffer in buffers] # transpose the list of list unflat_buffers = list(map(list, zip(*unflat_buffers))) for unflat_shards, tensor in zip(unflat_buffers, self._bucket): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index ed51c2bacafc..a70ef4aa8f3c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -20,7 +20,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8 +from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 from colossalai.tensor.moe_tensor.api import is_moe_tensor from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor @@ -580,9 +580,8 @@ def step(self, closure=None): else: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if self._fp8_communication: - all_gather_into_tensor_flat_fp8( - padded_working_param, param_to_gather, pg, fp8_format="e4m3" - ) + padded_working_param_list = torch.chunk(padded_working_param, dist.get_world_size(pg)) + all_gather_fp8(padded_working_param_list, param_to_gather, pg, fp8_format="e4m3") else: dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) continue diff --git a/tests/test_fp8/test_fp8_allgather_flat.py b/tests/test_fp8/test_fp8_allgather_flat.py deleted file mode 100644 index 2d43e5bd5902..000000000000 --- a/tests/test_fp8/test_fp8_allgather_flat.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -from 
torch.distributed.distributed_c10d import _get_default_group -from torch.testing import assert_close - -from colossalai import launch -from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - - -@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)]) -@parameterize("dtype", [torch.bfloat16, torch.float16]) -@parameterize("async_op", [True, False]) -def check_4gpu(shape, dtype, async_op): - world_size = dist.get_world_size() - rank = dist.get_rank() - x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) - flat_padded_x = x.view(-1) - if flat_padded_x.size(0) % world_size != 0: - pad_size = world_size - flat_padded_x.size(0) % world_size - flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) - output = torch.empty_like(flat_padded_x) - chunk = flat_padded_x.chunk(world_size)[rank].clone() - handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op) - if async_op: - handle.wait() - assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1) - - -def run_dist(rank, world_size, port): - launch(rank=rank, world_size=world_size, port=port, host="localhost") - check_4gpu() - - -@rerun_if_address_is_in_use() -def test_all_gather_flat(): - spawn(run_dist, 4) - - -if __name__ == "__main__": - test_all_gather_flat() From 27b8dbb5246f522d6800df4babe893e4d1fd02ec Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 13 Sep 2024 06:00:29 +0000 Subject: [PATCH 5/8] fix fp8 format --- colossalai/zero/gemini/chunk/chunk.py | 4 +++- colossalai/zero/low_level/bookkeeping/tensor_bucket.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 95ff653e605d..597dba1e8ba2 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -525,7 +525,9 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]: assert self.cuda_global_chunk.is_contiguous() if self.fp8_communication: cuda_global_chunk_list = torch.chunk(self.cuda_global_chunk, chunks=self.pg_size) - work = all_gather_fp8(cuda_global_chunk_list, self.cuda_shard, self.torch_pg, async_op=async_op) + work = all_gather_fp8( + cuda_global_chunk_list, self.cuda_shard, self.torch_pg, fp8_format="e4m3", async_op=async_op + ) else: work = dist.all_gather_into_tensor( self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 41f9ec2cac77..11a3c6c04a5d 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -67,7 +67,7 @@ def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))] if fp8_communication: - all_gather_fp8(buffers, flat, group=group) + all_gather_fp8(buffers, flat, group=group, fp8_format="e4m3") else: dist.all_gather(buffers, flat, group=group) unflat_buffers = [self.unflatten(buffer) for buffer in buffers] From 802df11d53de114e397f231f726697c003c636c8 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 13 Sep 2024 06:19:23 +0000 Subject: [PATCH 6/8] fix pytest --- tests/test_fp8/test_fp8_all_to_all.py | 4 ++-- 
tests/test_fp8/test_fp8_allgather.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py index 884aab744ba0..98bbbad8550d 100644 --- a/tests/test_fp8/test_fp8_all_to_all.py +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -5,7 +5,7 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import all_to_all_fp8 +from colossalai.quantization.fp8 import _all_to_all_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format): input_tensor_list = [x.contiguous() for x in input_tensor_list] output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list] output_tensor_list = [torch.empty_like(x) for x in input_tensor_list] - all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) + _all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group()) assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1) diff --git a/tests/test_fp8/test_fp8_allgather.py b/tests/test_fp8/test_fp8_allgather.py index 50957045c318..91e66e83c67b 100644 --- a/tests/test_fp8/test_fp8_allgather.py +++ b/tests/test_fp8/test_fp8_allgather.py @@ -5,7 +5,7 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import all_gather_fp8 +from colossalai.quantization.fp8 import _all_gather_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -21,7 +21,7 @@ def check_4gpu(shape, dtype, fp8_format, async_op): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) output_list = [torch.empty_like(x) for _ in range(world_size)] output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)] - fp8_handle = all_gather_fp8( + fp8_handle = _all_gather_fp8( output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op ) origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op) From c9833f16ac479ba063b1a58a5592fcc1d6dddef1 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 13 Sep 2024 08:05:15 +0000 Subject: [PATCH 7/8] fix conversations --- colossalai/quantization/fp8.py | 6 ++++-- colossalai/zero/gemini/chunk/chunk.py | 7 +++++-- colossalai/zero/low_level/bookkeeping/tensor_bucket.py | 8 ++++---- colossalai/zero/low_level/low_level_optim.py | 8 ++++++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 3d5096c67c6c..8243a29ac825 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -10,8 +10,10 @@ SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0") SCALE_BYTES = 4 - -cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability())) +try: + cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability())) +except: + cuda_arch = 0 class Handle: diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 597dba1e8ba2..49e6e6bbb985 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -524,9 +524,12 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]: 
alloc_storage(self.cuda_global_chunk) assert self.cuda_global_chunk.is_contiguous() if self.fp8_communication: - cuda_global_chunk_list = torch.chunk(self.cuda_global_chunk, chunks=self.pg_size) work = all_gather_fp8( - cuda_global_chunk_list, self.cuda_shard, self.torch_pg, fp8_format="e4m3", async_op=async_op + self.cuda_global_chunk.chunk(self.pg_size), + self.cuda_shard, + self.torch_pg, + fp8_format="e4m3", + async_op=async_op, ) else: work = dist.all_gather_into_tensor( diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 11a3c6c04a5d..c3111b2eb46e 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -65,12 +65,12 @@ def unflatten_and_copy(self, flat_tensor): def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() - buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))] + buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) if fp8_communication: - all_gather_fp8(buffers, flat, group=group, fp8_format="e4m3") + all_gather_fp8(buffer.chunk(dist.get_world_size(group)), flat, group=group, fp8_format="e4m3") else: - dist.all_gather(buffers, flat, group=group) - unflat_buffers = [self.unflatten(buffer) for buffer in buffers] + dist.all_gather_into_tensor(buffer, flat, group=group) + unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] # transpose the list of list unflat_buffers = list(map(list, zip(*unflat_buffers))) for unflat_shards, tensor in zip(unflat_buffers, self._bucket): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a70ef4aa8f3c..0d4dd2acd0ed 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -580,8 +580,12 @@ def step(self, closure=None): else: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if self._fp8_communication: - padded_working_param_list = torch.chunk(padded_working_param, dist.get_world_size(pg)) - all_gather_fp8(padded_working_param_list, param_to_gather, pg, fp8_format="e4m3") + all_gather_fp8( + padded_working_param.chunk(dist.get_world_size(pg)), + param_to_gather, + pg, + fp8_format="e4m3", + ) else: dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) continue From b1c1fcabe01c923dd777a2f34c6f521a782306e4 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 13 Sep 2024 09:01:44 +0000 Subject: [PATCH 8/8] fix chunk tuple to list --- colossalai/zero/gemini/chunk/chunk.py | 2 +- colossalai/zero/low_level/bookkeeping/tensor_bucket.py | 2 +- colossalai/zero/low_level/low_level_optim.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 49e6e6bbb985..351ff14e0131 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -525,7 +525,7 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]: assert self.cuda_global_chunk.is_contiguous() if self.fp8_communication: work = all_gather_fp8( - self.cuda_global_chunk.chunk(self.pg_size), + list(self.cuda_global_chunk.chunk(self.pg_size)), self.cuda_shard, self.torch_pg, fp8_format="e4m3", diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py 
index c3111b2eb46e..3c95aa6babcd 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -67,7 +67,7 @@ def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) if fp8_communication: - all_gather_fp8(buffer.chunk(dist.get_world_size(group)), flat, group=group, fp8_format="e4m3") + all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3") else: dist.all_gather_into_tensor(buffer, flat, group=group) unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 0d4dd2acd0ed..c019ff19b84c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -581,7 +581,7 @@ def step(self, closure=None): if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if self._fp8_communication: all_gather_fp8( - padded_working_param.chunk(dist.get_world_size(pg)), + list(padded_working_param.chunk(dist.get_world_size(pg))), param_to_gather, pg, fp8_format="e4m3",
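
Note (not part of the patch series): the commits above converge on dispatching intranode all-gathers to plain dist.all_gather and internode ones to _all_gather_fp8, which quantizes the payload to FP8 with a scale before communicating and dequantizes after. Below is a minimal single-process sketch of that quantize/dequantize round trip; the helper name fp8_round_trip is illustrative and does not appear in the patches, the per-tensor scaling is an assumption of this sketch, and the tolerances simply mirror the rtol=0.1/atol=0.1 used in the tests.

import torch

def fp8_round_trip(x: torch.Tensor, fp8_format: str = "e4m3") -> torch.Tensor:
    # Pick the FP8 dtype the way the patches do ("e4m3" or "e5m2").
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    # Per-tensor scale (sketch assumption): map the largest magnitude onto the FP8 range.
    amax = x.abs().max().float().clamp(min=1e-12)
    scale = fp8_max / amax
    # Quantize. In the real collective, this FP8 buffer (viewed as uint8) and its scale
    # are what get exchanged between ranks; here we just cast back locally.
    x_fp8 = (x.float() * scale).to(fp8_type)
    return (x_fp8.float() / scale).to(x.dtype)

if __name__ == "__main__":
    t = torch.randn(3, 7, 16, dtype=torch.bfloat16)
    # Same loose tolerances as the distributed tests in tests/test_fp8/.
    torch.testing.assert_close(fp8_round_trip(t), t, rtol=0.1, atol=0.1)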