From 7416d55b5b0450d368bc8c462b419b8fff87a738 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Tue, 3 Sep 2024 13:50:56 +0800 Subject: [PATCH 1/4] enhance all_to_all_fp8 with internode comm control --- colossalai/quantization/fp8.py | 41 ++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c022fab158c8..93e3e1e559e7 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -1,4 +1,5 @@ from typing import Any, Optional, Tuple +import os import numpy as np import torch @@ -22,6 +23,16 @@ def wait(self): self.remain_ops() +def process_group_is_intranode(pg): + if pg is None: + from torch.distributed.distributed_c10d import _get_default_group + pg = _get_default_group() + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + group_ranks = list(dist.distributed_c10d._pg_group_ranks[pg].values()) + group_ranks_node_ids = [rank // local_world_size for rank in group_ranks] + return min(group_ranks_node_ids) == max(group_ranks_node_ids) + + def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]: r""" casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. @@ -147,7 +158,8 @@ def cat_op(): cat_op() -def all_to_all_single_fp8( +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def _all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> Optional[Handle]: r""" @@ -209,6 +221,21 @@ def cast_op(): else: cast_op() +def all_to_all_single_fp8( + output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False +) -> Optional[Handle]: + r""" + This is wrapper for _all_to_all_single_fp8. 
+ """ + if process_group_is_intranode(group): + return dist.all_to_all_single(output, input, + output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, + group=group, async_op=async_op) + else: + return _all_to_all_single_fp8(output, input, fp8_format=fp8_format, + output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, + group=group, async_op=async_op) + def cast_to_fp8_pipeline(inp: Any) -> None: """ @@ -605,10 +632,9 @@ def cast_op(): cast_op() -def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): - +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): world_size = dist.get_world_size(group) - input_type = input_list[0].dtype fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 scale_list = [] @@ -639,6 +665,13 @@ def cast_op(): cast_op() +def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): + if process_group_is_intranode(group): + return dist.all_to_all(output_list, input_list, group=group, async_op=async_op) + else: + return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op) + + def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group) From fd80530e763b3778526876831d5c908069079cfa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 05:54:29 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/quantization/fp8.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 93e3e1e559e7..360e390a50a2 100644 --- 
a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -1,5 +1,5 @@ -from typing import Any, Optional, Tuple import os +from typing import Any, Optional, Tuple import numpy as np import torch @@ -26,6 +26,7 @@ def wait(self): def process_group_is_intranode(pg): if pg is None: from torch.distributed.distributed_c10d import _get_default_group + pg = _get_default_group() local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) group_ranks = list(dist.distributed_c10d._pg_group_ranks[pg].values()) @@ -221,6 +222,7 @@ def cast_op(): else: cast_op() + def all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> Optional[Handle]: @@ -228,13 +230,24 @@ def all_to_all_single_fp8( This is wrapper for _all_to_all_single_fp8. """ if process_group_is_intranode(group): - return dist.all_to_all_single(output, input, - output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, - group=group, async_op=async_op) + return dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) else: - return _all_to_all_single_fp8(output, input, fp8_format=fp8_format, - output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, - group=group, async_op=async_op) + return _all_to_all_single_fp8( + output, + input, + fp8_format=fp8_format, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) def cast_to_fp8_pipeline(inp: Any) -> None: From 124fb25b145560027712448cd261e3ff5d1a7261 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Tue, 3 Sep 2024 18:42:27 +0800 Subject: [PATCH 3/4] disable some fp8 ops due to performance issue --- colossalai/quantization/fp8.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/colossalai/quantization/fp8.py 
b/colossalai/quantization/fp8.py index 93e3e1e559e7..8c02a666cff6 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -27,8 +27,15 @@ def process_group_is_intranode(pg): if pg is None: from torch.distributed.distributed_c10d import _get_default_group pg = _get_default_group() - local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - group_ranks = list(dist.distributed_c10d._pg_group_ranks[pg].values()) + + local_world_size = None + for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]: + if var in os.environ: + local_world_size = int(os.environ[var]) + if local_world_size is None: + local_world_size = torch.cuda.device_count() + + group_ranks = dist.get_process_group_ranks(pg) group_ranks_node_ids = [rank // local_world_size for rank in group_ranks] return min(group_ranks_node_ids) == max(group_ranks_node_ids) @@ -91,7 +98,7 @@ def cast_from_fp8( return ret.to(ret_type) -def all_reduce_fp8( +def _all_reduce_fp8( tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False ) -> Optional[Handle]: r""" @@ -158,6 +165,12 @@ def cat_op(): cat_op() +def all_reduce_fp8( + tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False +) -> Optional[Handle]: + # fall back to default op due to performance issue + return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) + @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) def _all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> Optional[Handle]: @@ -308,7 +321,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: del inp["dtype"] -def reduce_scatter_fp8( +def _reduce_scatter_fp8( output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False ) -> Optional[Handle]: r""" @@ -353,6 +366,13 @@ def cast_op(): cast_op() +def reduce_scatter_fp8( + output: torch.Tensor, 
input_list, group, fp8_format="e5m2", async_op: bool = False +) -> Optional[Handle]: + # fall back to default op due to performance issue + return dist.reduce_scatter(output, input_list, group=group, async_op=async_op) + + def fp8_compress_ddp_grad_comm_hook_async( process_group: dist.ProcessGroup, bucket: dist.GradBucket, From 987c603c18638649253d5d9094a9ce3fc7475b53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 02:16:21 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 24b2aa17f4f8..388bbde052d2 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -24,7 +24,6 @@ def wait(self): self.remain_ops() - def process_group_is_intranode(pg): if pg is None: from torch.distributed.distributed_c10d import _get_default_group @@ -185,6 +184,7 @@ def all_reduce_fp8( # fall back to default op due to performance issue return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) + @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) def _all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False