From 75d1d1cca44f1d0855c49fb15a0ccb16683ba323 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 7 Aug 2024 09:51:18 +0000
Subject: [PATCH 1/9] fix

---
 colossalai/quantization/fp8.py | 22 ----------------------
 1 file changed, 22 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index bc8c3ced4cdd..805824a896c7 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
         output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
 
 
-def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
-
-    world_size = dist.get_world_size(group)
-
-    per_slice_len = input_tensor.size(0) // world_size
-    input_type = input_tensor.dtype
-    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
-    fp8_type = ret.dtype
-    input_tensor = ret.view(torch.uint8)
-    tensor = torch.empty_like(input_tensor)
-    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
-    dist.all_to_all_single(tensor, input_tensor, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    cast_tensor_list = []
-
-    for i in range(world_size):
-        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
-        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
-        cast_tensor_list.append(output_part)
-    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
-
-
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
 
     world_size = dist.get_world_size(group)

From f081275993dde620a4fcc9823a1f43926e808d8e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 08:02:30 +0000
Subject: [PATCH 2/9] fix

---
 .github/workflows/example_check_on_pr.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 56fa006b1633..1ccdd59afefd 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -107,7 +107,7 @@ jobs:
 
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v .
+          BUILD_EXT=1 pip install -v -e .
 
       - name: Store Colossal-AI Cache
         run: |

From 4f127135b2c33cde6056d251d0c32d845276c2ac Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 08:04:55 +0000
Subject: [PATCH 3/9] fix

---
 .github/workflows/example_check_on_pr.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 1ccdd59afefd..7a906738cb96 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -9,6 +9,7 @@ on:
     paths:
       - "examples/**"
       - "!examples/**.md"
+      - ".github/workflows/example_check_on_pr.yml"
 
 jobs:
   # This is for changed example files detect and output a matrix containing all the corresponding directory name.
From 52d4b0bbce79d4c819fd801057439ffff3adc8f8 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 13 Aug 2024 08:16:38 +0000
Subject: [PATCH 4/9] support async all2all

---
 colossalai/quantization/fp8.py           | 74 ++++++++++++++++--------
 tests/test_fp8/test_all_to_all_single.py | 22 +++++---
 2 files changed, 64 insertions(+), 32 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 52bb8cc9bc33..158cd7475a30 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -7,6 +7,17 @@
 from torch.distributed import ReduceOp
 
 
+class Handle:
+    def __init__(self, handles, remain_ops) -> None:
+        self.handles = handles
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
@@ -60,7 +71,9 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -93,9 +106,9 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     inp = ret.view(torch.uint8)
     input_chunks = list(torch.chunk(inp, world_size, dim=0))
     output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
+    scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
@@ -117,7 +130,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
 
 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> None:
+) -> Handle:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@@ -155,20 +168,27 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
 
-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
-    cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
-    ]
+    scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    tensor_out = torch.cat(cast_output_chunk, dim=0)
-    outputs_shape = list(input_shape)
-    if output_split_sizes is not None:
-        outputs_shape[0] = sum(output_split_sizes)
+    def cast_op():
+        cast_output_chunk = [
+            cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
+        ]
+
+        tensor_out = torch.cat(cast_output_chunk, dim=0)
+        outputs_shape = list(input_shape)
+        if output_split_sizes is not None:
+            outputs_shape[0] = sum(output_split_sizes)
+        else:
+            outputs_shape = input_shape
+        output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_hanle], cast_op)
     else:
-        outputs_shape = input_shape
-    output.data = tensor_out.view(outputs_shape).to(input_type)
+        cast_op()
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
@@ -236,7 +256,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["fp8_scale"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False) -> None:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -263,14 +283,20 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
-
-    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
-    for scale, out in zip(output_scale_list, output_chunks):
-        out = out.view(fp8_type)
-        summed_out += cast_from_fp8(out, scale, input_type)
-    output.data = summed_out
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
+
+    def cast_op():
+        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+        for scale, out in zip(output_scale_list, output_chunks):
+            out = out.view(fp8_type)
+            summed_out += cast_from_fp8(out, scale, input_type)
+        output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def split_chunk_by_channel(
diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 88becd3f07fc..8e1031e93f09 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -10,19 +10,23 @@
 
 
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
-@parameterize("dtype", [torch.bfloat16])
-def check_all2all(shape, dtype):
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+@parameterize("async_op", [True, False])
+def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
 
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_all2all_uneven(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
         output_split_sizes = [3, 3, 1, 1]
     else:
         output_split_sizes = [1, 1, 3, 3]
     output_shape = list(shape)
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
-    all_to_all_single_fp8(
+    handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
+    if async_op:
+        handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
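Note on the API introduced by this patch: when `async_op=True`, the wrappers return a `Handle` whose `wait()` first waits on the underlying `torch.distributed` work objects and then runs the deferred FP8-to-original-dtype cast. A minimal usage sketch, assuming an already-initialized process group, an FP8-capable torch build, and an input whose leading dimension is divisible by the world size (the helper name and the overlapped workload are illustrative, not part of the patch):

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_single_fp8


def exchange_with_overlap(x: torch.Tensor):
    output = torch.empty_like(x)
    # async_op=True returns a Handle instead of blocking.
    handle = all_to_all_single_fp8(output, x, group=None, async_op=True)
    local = x.float().mul(2).sum()  # illustrative compute overlapped with the collectives
    handle.wait()  # waits on comm, then casts the received chunks back into `output`
    return output, local
```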
From 351969cabb7af0d5a2fbfb2ae1467a9666aae337 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 13 Aug 2024 08:33:24 +0000
Subject: [PATCH 5/9] support async op for all gather

---
 colossalai/quantization/fp8.py           | 22 ++++++++++++++--------
 tests/test_fp8/test_all_to_all_single.py | 18 ++++++++++--------
 tests/test_fp8/test_fp8_gather.py        | 10 +++++++---
 3 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 158cd7475a30..faece70d1c68 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -130,7 +130,7 @@ def all_reduce_fp8(
 
 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> Handle:
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@@ -402,7 +402,7 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
         output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
 
 
-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
 
     world_size = dist.get_world_size(group)
 
@@ -412,13 +412,19 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        output = tensor_list[i].view(fp8_type)
-        scale = scale_list[i]
-        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            output = tensor_list[i].view(fp8_type)
+            scale = scale_list[i]
+            output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_hanle], cast_op)
+    else:
+        cast_op()
 
 
 class _LinearFp8(torch.autograd.Function):
diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 8e1031e93f09..05f6513c1abf 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -11,21 +11,22 @@
 
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-@parameterize("async_op", [True, False])
+@parameterize("async_op", [True])
 def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
-    handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    origin_hanle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
     if async_op:
-        handle.wait()
+        origin_hanle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
 
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-@parameterize("async_op", [True, False])
+@parameterize("async_op", [True])
 def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
         output_split_sizes = [3, 3, 1, 1]
     else:
         output_split_sizes = [1, 1, 3, 3]
     output_shape = list(shape)
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(
+    origin_hanle = dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
         async_op=async_op,
     )
-    handle = all_to_all_single_fp8(
+    fp8_handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
         async_op=async_op,
     )
     if async_op:
-        handle.wait()
+        origin_hanle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py
index 79d1d4ea49e6..40c2ccb9a17b 100644
--- a/tests/test_fp8/test_fp8_gather.py
+++ b/tests/test_fp8/test_fp8_gather.py
@@ -24,13 +24,17 @@
 )
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
    world_size = dist.get_world_size()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output_list = [torch.empty_like(x) for _ in range(world_size)]
    output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
-    gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
-    dist.all_gather(output_list, x, group=_get_default_group())
+    fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
+    origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        fp8_handle.wait()
+        origin_hanle.wait()
    assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
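`gather_fp8` now follows the same convention: the tensor all_gather and the per-rank scale all_gather are launched together, and the per-sender dequantization is deferred into the handle. A short sketch, assuming an initialized default process group (the variable names are illustrative):

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import gather_fp8


def gather_async(x: torch.Tensor):
    world_size = dist.get_world_size()
    outputs = [torch.empty_like(x) for _ in range(world_size)]
    handle = gather_fp8(outputs, x, group=None, fp8_format="e4m3", async_op=True)
    handle.wait()  # waits on both all_gather calls, then dequantizes each rank's chunk into `outputs`
    return outputs
```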
From adc62ff98f02228cab3fd8dcebf1c71e95182065 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 13 Aug 2024 09:56:26 +0000
Subject: [PATCH 6/9] fix

---
 colossalai/quantization/fp8.py           | 35 +++++++++++++++++++++++++----------
 tests/test_fp8/test_all_to_all_single.py |  4 ++--
 tests/test_fp8/test_fp8_allreduce.py     | 17 ++++++++++++-----
 3 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index faece70d1c68..8b49a1b66339 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -8,14 +8,15 @@
 
 
 class Handle:
-    def __init__(self, handles, remain_ops) -> None:
+    def __init__(self, handles=[], remain_ops=None) -> None:
         self.handles = handles
         self.remain_ops = remain_ops
 
     def wait(self):
         for handle in self.handles:
             handle.wait()
-        self.remain_ops()
+        if self.remain_ops:
+            self.remain_ops()
 
 
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
@@ -106,10 +107,15 @@ def all_reduce_fp8(
     inp = ret.view(torch.uint8)
     input_chunks = list(torch.chunk(inp, world_size, dim=0))
     output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
-    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
+    chunck_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+
+    if async_op:
+        chunck_handle.wait()
+        scale_handle.wait()
+
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -118,14 +124,23 @@ def all_reduce_fp8(
         summed_out.div_(world_size)
 
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
-    for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
-    out = torch.cat(tensor_list, dim=0)
-    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cat_op():
+        for i in range(world_size):
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        out = torch.cat(tensor_list, dim=0)
+        tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()
 
 
 def all_to_all_single_fp8(
diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 05f6513c1abf..722cbce9ac02 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -11,7 +11,7 @@
 
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-@parameterize("async_op", [True])
+@parameterize("async_op", [True, False])
 def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
@@ -26,7 +26,7 @@ def check_all2all(shape, dtype, async_op):
 
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-@parameterize("async_op", [True])
+@parameterize("async_op", [True, False])
 def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py
index c23959b5d0da..ccc43ed2979f 100644
--- a/tests/test_fp8/test_fp8_allreduce.py
+++ b/tests/test_fp8/test_fp8_allreduce.py
@@ -22,15 +22,22 @@
 )
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     x_fp8 = x.clone()
-    dist.all_reduce(x)
-    all_reduce_fp8(x_fp8, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)
-    dist.all_reduce(x, op=dist.ReduceOp.AVG)
-    all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)
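With the defaults added in this patch, the deferred-op slot of `Handle` is optional. A self-contained sketch of the contract, needing no process group (the fake work object is illustrative and only stands in for the work objects returned by async collectives):

```python
from colossalai.quantization.fp8 import Handle


class _FakeWork:
    """Mimics the wait() interface of a torch.distributed work object."""

    def wait(self):
        pass


ran = []
h = Handle(handles=[_FakeWork()], remain_ops=lambda: ran.append("cast"))
h.wait()  # waits on every wrapped work object, then runs the deferred op once
assert ran == ["cast"]

Handle().wait()  # with the new defaults, an empty Handle is a harmless no-op
```

Any zero-argument callable works as `remain_ops`; the collectives above pass their local `cast_op` closures.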
From bffa302f77e3783877f121a2b1a22f928c1f69f5 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 13 Aug 2024 10:00:03 +0000
Subject: [PATCH 7/9] fix

---
 colossalai/quantization/fp8.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 8b49a1b66339..0f5fe4c76135 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -74,7 +74,7 @@ def cast_from_fp8(
 
 def all_reduce_fp8(
     tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
-) -> None:
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -271,7 +271,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["fp8_scale"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False) -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
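Because the signatures now advertise `Optional[Handle]` (the synchronous path returns `None`), a call site that toggles `async_op` at runtime can guard the wait. A sketch, assuming an initialized process group:

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8


def reduce_maybe_async(x: torch.Tensor, async_op: bool) -> torch.Tensor:
    maybe_handle = all_reduce_fp8(x, fp8_format="e4m3", group=None, async_op=async_op)
    if maybe_handle is not None:  # None in the synchronous path
        maybe_handle.wait()
    return x  # reduced in place
```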
From 34bb6cae357f12ac9eb4b227ad540e8e52d99ded Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 13 Aug 2024 10:05:18 +0000
Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/quantization/fp8.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index b72a06aee0b2..038bf77b3a17 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -9,6 +9,7 @@
 
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
 
+
 class Handle:
     def __init__(self, handles=[], remain_ops=None) -> None:
         self.handles = handles

From 92a3a5f4ef95ad0809064f731dad37b52e9f46a7 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 14 Aug 2024 03:44:41 +0000
Subject: [PATCH 9/9] fix

---
 colossalai/quantization/fp8.py | 35 +++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 0f5fe4c76135..80c9e853b492 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -107,15 +107,11 @@ def all_reduce_fp8(
     inp = ret.view(torch.uint8)
     input_chunks = list(torch.chunk(inp, world_size, dim=0))
     output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
-    chunck_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
+    dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
+    dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
 
-    if async_op:
-        chunck_handle.wait()
-        scale_handle.wait()
-
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -336,7 +332,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format
 
     Args:
@@ -383,15 +380,17 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
         fp8_input = (scale * input_tensor.float()).to(fp8_type)
         scale_inv = 1.0 / scale
+
     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
     dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
+
     numel = np.prod(output_shape)
     valid_buffer = buffer[:numel].reshape(output_shape)
     valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
     output_tensor[:numel].copy_(valid_buffer.view(-1))
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
 
     world_size = dist.get_world_size(group)
 
@@ -409,14 +408,20 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]
 
-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        scale = output_scale_list[i]
-        tensor = output_tensor_list[i]
-        tensor = tensor.view(fp8_type)
-        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            scale = output_scale_list[i]
+            tensor = output_tensor_list[i]
+            tensor = tensor.view(fp8_type)
+            output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_hanle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
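After this last patch, `all_to_all_fp8` follows the same async convention as the other collectives. An end-to-end sketch, assuming it runs under `torchrun` with one CUDA device per rank and an FP8-capable torch build (the shapes are illustrative):

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_fp8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

world_size = dist.get_world_size()
x = torch.rand(world_size, 16, dtype=torch.bfloat16, device="cuda")
inputs = list(torch.chunk(x, world_size, dim=0))  # one slice per peer
outputs = [torch.empty_like(c) for c in inputs]

handle = all_to_all_fp8(outputs, inputs, group=None, fp8_format="e5m2", async_op=True)
handle.wait()  # completes both all_to_all calls, then dequantizes into `outputs`
dist.destroy_process_group()
```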