From f74225fa2f0da8890837a4382403944ed796f9a2 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 20 Aug 2024 16:22:55 +0800
Subject: [PATCH 1/4] [fp8] optimize all-gather

---
 colossalai/quantization/fp8.py | 103 +++++++++++++++++++++++++++++++--
 1 file changed, 98 insertions(+), 5 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index c022fab158c8..e949af5ac0a2 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -8,6 +8,7 @@
 from torch.distributed import ReduceOp

 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
+SCALE_BYTES = 4


 class Handle:
@@ -22,7 +23,9 @@ def wait(self):
         self.remain_ops()


-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def cast_to_fp8(
+    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -

         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale
-        ret = (scale * inp.float()).to(fp8_type)
+        if out is not None:
+            ret = torch.mul(scale, inp.float(), out=out)
+        else:
+            ret = (scale * inp.float()).to(fp8_type)
     return ret, torch.unsqueeze(scale_inv, dim=0)


 def cast_from_fp8(
-    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
 ) -> torch.Tensor:
     r"""
     Args:
@@ -74,9 +80,15 @@ def cast_from_fp8(
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

     if per_channel_scale:
-        ret = scale_inv[:, None] * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv[:, None], inp.float(), out=out)
+        else:
+            ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale_inv * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv, inp.float(), out=out)
+        else:
+            ret = scale_inv * inp.float()
     return ret.to(ret_type)


@@ -664,6 +676,87 @@ def cast_op():
         cast_op()


+def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    cur_buffer[:SCALE_BYTES].copy_(scale.unsqueeze(0).view(torch.uint8))
+    dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
+    for out, buf in zip(output_list, combined_buffers):
+        scale = buf[:SCALE_BYTES].view(scale.dtype)
+        output = buf[SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=out)
+    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
+    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
+    # output = output.float() * scales
+    # for i, out in enumerate(output_list):
+    #     out.copy_(output[i].view(shape))
+
+
fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + send_rank = (rank + 1) % world_size + recv_rank = (rank - 1) % world_size + + shape = input_.shape + input_type = input_.dtype + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device) + combined_buffers = list(combined_buffer.chunk(world_size, dim=0)) + cur_buffer = combined_buffers[dist.get_rank(group)] + ret = cur_buffer[SCALE_BYTES:].view(fp8_type) + ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret) + cur_buffer[:SCALE_BYTES].copy_(scale.unsqueeze(0).view(torch.uint8)) + + def send_recv(idx): + send_idx = (rank - idx) % world_size + recv_idx = (rank - idx - 1) % world_size + ops = dist.batch_isend_irecv( + [ + dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group), + dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group), + ] + ) + return ops + + def cast(idx): + cast_idx = (rank - idx - 1) % world_size + scale = combined_buffers[cast_idx][:SCALE_BYTES].view(torch.float) + output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type) + cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx]) + + # warmup + ops = send_recv(0) + if output_list[rank] is not input_ and output_list[rank].data_ptr() != input_.data_ptr(): + output_list[rank].copy_(input_) + for op in ops: + op.wait() + ops = [] + + # 1p-1c + for i in range(1, world_size - 1): + new_ops = send_recv(i) + for op in ops: + op.wait() + cast(i - 1) + ops = new_ops + + # cooldown + for op in ops: + op.wait() + cast(world_size - 1) + + class _LinearFp8(torch.autograd.Function): @staticmethod def forward( From e73d43e02de80377c089bff9e3a90b067a23a3cf Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 21 Aug 2024 17:07:09 +0800 Subject: [PATCH 2/4] [fp8] fix all gather fp8 ring --- colossalai/quantization/fp8.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e949af5ac0a2..bb0f27205ac3 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -676,6 +676,7 @@ def cast_op(): cast_op() +# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False) def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: world_size = dist.get_world_size(group) shape = input_.shape @@ -687,10 +688,11 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: cur_buffer = combined_buffers[dist.get_rank(group)] ret = cur_buffer[SCALE_BYTES:].view(fp8_type) ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret) - cur_buffer[:SCALE_BYTES].copy_(scale.unsqueeze(0).view(torch.uint8)) + # cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale + cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8) dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op) for out, buf in zip(output_list, combined_buffers): - scale = buf[:SCALE_BYTES].view(scale.dtype) + scale = buf[:SCALE_BYTES].clone().view(scale.dtype) output = buf[SCALE_BYTES:].view(fp8_type) cast_from_fp8(output.view(shape), scale, input_type, out=out) # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type) @@ -700,6 +702,7 @@ def all_gather_fp8(output_list, 
From e73d43e02de80377c089bff9e3a90b067a23a3cf Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 21 Aug 2024 17:07:09 +0800
Subject: [PATCH 2/4] [fp8] fix all gather fp8 ring

---
 colossalai/quantization/fp8.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index e949af5ac0a2..bb0f27205ac3 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -676,6 +676,7 @@ def cast_op():
         cast_op()


+# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
 def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
     shape = input_.shape
@@ -687,10 +688,11 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op:
     cur_buffer = combined_buffers[dist.get_rank(group)]
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
-    cur_buffer[:SCALE_BYTES].copy_(scale.unsqueeze(0).view(torch.uint8))
+    # cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+    cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
     dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
     for out, buf in zip(output_list, combined_buffers):
-        scale = buf[:SCALE_BYTES].view(scale.dtype)
+        scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
         output = buf[SCALE_BYTES:].view(fp8_type)
         cast_from_fp8(output.view(shape), scale, input_type, out=out)
     # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
@@ -700,6 +702,7 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op:
     #     out.copy_(output[i].view(shape))


+# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
 def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
     rank = dist.get_rank(group)
@@ -716,7 +719,7 @@ def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", asyn
     cur_buffer = combined_buffers[dist.get_rank(group)]
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
-    cur_buffer[:SCALE_BYTES].copy_(scale.unsqueeze(0).view(torch.uint8))
+    cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)

     def send_recv(idx):
         send_idx = (rank - idx) % world_size
@@ -731,14 +734,13 @@ def send_recv(idx):

     def cast(idx):
         cast_idx = (rank - idx - 1) % world_size
-        scale = combined_buffers[cast_idx][:SCALE_BYTES].view(torch.float)
+        scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
         output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
         cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])

     # warmup
     ops = send_recv(0)
-    if output_list[rank] is not input_ and output_list[rank].data_ptr() != input_.data_ptr():
-        output_list[rank].copy_(input_)
+    output_list[rank].copy_(input_)
     for op in ops:
         op.wait()
     ops = []
@@ -754,7 +756,7 @@ def cast(idx):
     # cooldown
     for op in ops:
         op.wait()
-    cast(world_size - 1)
+    cast(world_size - 2)


 class _LinearFp8(torch.autograd.Function):
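PATCH 2 repairs the ring variant in three places: the local chunk is copied into output_list[rank] unconditionally, the scale header is clone()d before being reinterpreted as fp32 (Tensor.view(dtype) requires the storage offset to be divisible by the new element size, which a header sliced out of the shared uint8 buffer cannot guarantee; the clone starts at offset zero), and the cooldown cast index drops from world_size - 1 to world_size - 2. The pure-Python trace below, with names local to the illustration, shows why the last change is the correct off-by-one fix: send_recv runs for indices 0 through world_size - 2 and cast trails one step behind, so cast(world_size - 2) dequantizes the final received chunk, whereas cast(world_size - 1) would resolve to the rank's own chunk and leave that last chunk untouched:

    world_size, rank = 4, 0

    # send_recv(i), i = 0 .. world_size - 2: (send_idx, recv_idx) per step
    transfers = [((rank - i) % world_size, (rank - i - 1) % world_size)
                 for i in range(world_size - 1)]
    print(transfers)  # [(0, 3), (3, 2), (2, 1)] -- every remote chunk received once

    # cast(i) targets chunk (rank - i - 1) % world_size; the schedule runs
    # cast(0) .. cast(world_size - 2), one step behind the matching recv
    cast_targets = [(rank - i - 1) % world_size for i in range(world_size - 1)]
    print(cast_targets)  # [3, 2, 1] -- all remote chunks, each exactly once

    # the old cooldown index resolves to the local chunk, which needs no cast:
    print((rank - (world_size - 1) - 1) % world_size)  # 0 == rank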
From 95ae810e20582194cb9040058394beacedecd827 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 27 Aug 2024 18:02:21 +0800
Subject: [PATCH 3/4] [fp8] enable compile

---
 colossalai/quantization/fp8.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index bb0f27205ac3..a9ec4577a416 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -676,7 +676,7 @@ def cast_op():
         cast_op()


-# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
 def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
     shape = input_.shape
@@ -688,8 +688,8 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op:
     cur_buffer = combined_buffers[dist.get_rank(group)]
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
-    # cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
-    cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
     dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
     for out, buf in zip(output_list, combined_buffers):
         scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
@@ -719,7 +719,8 @@ def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", asyn
     cur_buffer = combined_buffers[dist.get_rank(group)]
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
-    cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale

     def send_recv(idx):
         send_idx = (rank - idx) % world_size

From c97777b5f465310d53734f37cd28a645085634bf Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 2 Sep 2024 17:44:03 +0800
Subject: [PATCH 4/4] [fp8] fix all gather fp8 ring

---
 colossalai/quantization/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index a9ec4577a416..6a0bd14d1071 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -702,7 +702,7 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op:
     #     out.copy_(output[i].view(shape))


-# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
 def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
     rank = dist.get_rank(group)
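With PATCH 3 and PATCH 4, both collectives run under @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False); the scale write is also switched from a uint8 copy to a store through a float view of the header, presumably because that form traces more cleanly. A minimal usage sketch follows, assuming an already-initialized NCCL environment (e.g. launched via torchrun), a GPU build with fp8 support, and a PyTorch new enough for these compile modes (the module's SUPPORT_TORCH_COMPILE flag checks for 2.4.0); the shapes and dtypes are illustrative only:

    import torch
    import torch.distributed as dist
    from colossalai.quantization.fp8 import all_gather_fp8, all_gather_fp8_ring

    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
    outputs = [torch.empty_like(x) for _ in range(world_size)]

    # one fused collective over the packed uint8 buffers
    all_gather_fp8(outputs, x, fp8_format="e5m2")

    # ring schedule: world_size - 1 point-to-point steps that overlap each
    # recv with dequantization of the previously received chunk
    all_gather_fp8_ring(outputs, x, fp8_format="e5m2")

Both calls dequantize into the preallocated outputs list; the trade-off is a single NCCL all-gather versus a pipelined ring that hides the fp8-to-bf16 casts behind communication.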