From 9043fba1502f28018fdbede88e0940700cbf543c Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 30 Jul 2024 09:06:55 +0000
Subject: [PATCH 1/7] support all2all fp8

---
 colossalai/quantization/fp8.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index fe5bd5744e69..48bcb4d621ad 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -102,6 +102,42 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     tensor_out = torch.cat(tensor_list, dim=0)
     tensor.data = tensor_out.view(input_shape).to(input_type)
 
+def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False) -> None:
+    r"""
+    This is an in-place operation for compressed all_reduce using fp8.
+    It works like dist.all_reduce but during communication the data is cast to fp8 format.
+    Args:
+        tensor: torch.Tensor in fp32, fp16, bf16 datatype.
+        fp8_format: e4m3 or e5m2
+    Returns:
+        None
+    """
+
+    world_size = dist.get_world_size(group=group)
+    input_type = input.dtype
+    input_shape = input.shape
+    input_device = input.device
+    input = input.flatten()
+
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
+
+    inp = ret.view(torch.uint8)
+    input_chunks = torch.split(inp, input_tensor_list)
+
+    output_chunks = [torch.empty((output_tensor_list[i]*np.prod(input_shape[1:]),), device=input_device, dtype=input_type) for i in range(world_size)]
+
+    dist.all_to_all(output_chunks, input_chunks, group=group)
+    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
+    dist.all_gather(scale_list, scale, group=group)
+    for scale, out in zip(scale_list, output_chunks):
+        out = out.view(fp8_type)
+        out = cast_from_fp8(out, scale, input_type)
+
+    tensor_out = torch.cat(output_chunks, dim=0)
+    output.data = tensor_out.to(input_type)
+
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """

From 41a9aca399c3ea712aa66b2a571f3eb4a74a4662 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 30 Jul 2024 09:10:37 +0000
Subject: [PATCH 2/7] fix

---
 colossalai/quantization/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 48bcb4d621ad..9674c1e4798d 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -105,7 +105,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
 def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
-    It works like dist.all_reduce but during communication the data is cast to fp8 format.
+    It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
     Args:
         tensor: torch.Tensor in fp32, fp16, bf16 datatype.
         fp8_format: e4m3 or e5m2
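For context on PATCH 1: it builds the fp8 all-to-all out of three pre-existing pieces — cast_to_fp8, which returns a quantized tensor plus its per-tensor scale; dist.all_to_all on the uint8 payload; and an all_gather of the scales so each receiver can decode. The cast helpers are defined elsewhere in colossalai/quantization/fp8.py and are not part of this series; the following is only an illustrative sketch of the (tensor, scale) contract they appear to implement, with hypothetical _sketch names rather than the library API:

    import torch

    def cast_to_fp8_sketch(x: torch.Tensor, fp8_format: str = "e5m2"):
        # Pick the fp8 dtype and its largest representable magnitude.
        fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
        fp8_max = torch.finfo(fp8_type).max
        # Per-tensor scale: map the largest |value| onto the fp8 range.
        amax = x.abs().max().float().clamp(min=1e-12)
        scale = fp8_max / amax
        # Return the quantized payload plus the inverse scale for decoding.
        return (x.float() * scale).to(fp8_type), scale.reciprocal()

    def cast_from_fp8_sketch(x: torch.Tensor, scale_inv: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Dequantize in fp32, then cast back to the caller's dtype.
        return (x.float() * scale_inv).to(dtype)

Note also that the dequantization loop in this first version only rebinds the loop variable (out = cast_from_fp8(...)) without writing back into output_chunks, so the concatenated result is still the raw fp8 bytes; PATCH 5 later replaces the loop with a list comprehension that keeps the cast chunks.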
From 214bc6c7eeb7ddb22ea0da1d2b48a183fac6a3a7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 30 Jul 2024 09:13:23 +0000
Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/quantization/fp8.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 9674c1e4798d..fca789ce811c 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -102,7 +102,10 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     tensor_out = torch.cat(tensor_list, dim=0)
     tensor.data = tensor_out.view(input_shape).to(input_type)
 
-def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False) -> None:
+
+def all_to_all_single_fp8(
+    output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False
+) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@@ -126,7 +129,10 @@ def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list,
     inp = ret.view(torch.uint8)
     input_chunks = torch.split(inp, input_tensor_list)
 
-    output_chunks = [torch.empty((output_tensor_list[i]*np.prod(input_shape[1:]),), device=input_device, dtype=input_type) for i in range(world_size)]
+    output_chunks = [
+        torch.empty((output_tensor_list[i] * np.prod(input_shape[1:]),), device=input_device, dtype=input_type)
+        for i in range(world_size)
+    ]
 
     dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]

From 7fc7e159b898d638f2d3963d9a76cf431761fbcd Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 30 Jul 2024 09:18:33 +0000
Subject: [PATCH 4/7] fix

---
 colossalai/quantization/fp8.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 9674c1e4798d..ae87079838fe 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -136,7 +136,9 @@ def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list,
         out = cast_from_fp8(out, scale, input_type)
 
     tensor_out = torch.cat(output_chunks, dim=0)
-    output.data = tensor_out.to(input_type)
+    outputs_shape = list(input_shape)
+    outputs_shape[0] = sum(output_tensor_list)
+    output.data = tensor_out.view(outputs_shape).to(input_type)
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
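PATCH 4 repairs the output shape: with uneven splits, the number of rows a rank receives, sum(output_tensor_list), generally differs from the number it sends, so the result cannot be viewed with input_shape. A worked example with invented sizes:

    # A rank sends 8 rows of an (8, 8, 16) tensor but receives 3 rows
    # from each of 4 peers; only dim 0 of the output shape changes.
    input_shape = (8, 8, 16)
    output_tensor_list = [3, 3, 3, 3]

    outputs_shape = list(input_shape)
    outputs_shape[0] = sum(output_tensor_list)
    assert outputs_shape == [12, 8, 16]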
From d50dd3b3a4da4a6d4c189c0827160e9465a9deb3 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 31 Jul 2024 07:32:00 +0000
Subject: [PATCH 5/7] fix

---
 colossalai/quantization/fp8.py | 37 +++++++++++++++++++++++-------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 622b785f75cf..1f3cda2728d9 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -104,7 +104,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
 
 
 def all_to_all_single_fp8(
-    output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False
+    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
 ) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
@@ -115,7 +115,6 @@ def all_to_all_single_fp8(
     Returns:
         None
     """
-
     world_size = dist.get_world_size(group=group)
     input_type = input.dtype
     input_shape = input.shape
@@ -127,23 +126,37 @@ def all_to_all_single_fp8(
     ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
 
     inp = ret.view(torch.uint8)
-    input_chunks = torch.split(inp, input_tensor_list)
+    if input_split_sizes is not None:
+        input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
+        input_chunks = list(torch.split(inp, input_split_sizes))
+    else:
+        input_chunks = list(torch.chunk(inp, world_size, dim=0))
 
-    output_chunks = [
-        torch.empty((output_tensor_list[i] * np.prod(input_shape[1:]),), device=input_device, dtype=input_type)
-        for i in range(world_size)
-    ]
+    if output_split_sizes is not None:
+        output_chunks = [
+            torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
+            for i in range(world_size)
+        ]
+    else:
+        if dist.get_rank() == world_size - 1:
+            output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
+        else:
+            output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
 
     dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
-    for scale, out in zip(scale_list, output_chunks):
-        out = out.view(fp8_type)
-        out = cast_from_fp8(out, scale, input_type)
+    cast_output_chunk = [
+        cast_from_fp8(out.view(fp8_type), scale, input_type)
+        for scale, out in zip(scale_list, output_chunks)
+    ]
 
-    tensor_out = torch.cat(output_chunks, dim=0)
+    tensor_out = torch.cat(cast_output_chunk, dim=0)
     outputs_shape = list(input_shape)
-    outputs_shape[0] = sum(output_tensor_list)
+    if output_split_sizes is not None:
+        outputs_shape[0] = sum(output_split_sizes)
+    else:
+        outputs_shape = input_shape
     output.data = tensor_out.view(outputs_shape).to(input_type)
 
 
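Because the quantized payload is flattened to one dimension before dist.all_to_all, PATCH 5 has to convert row-based split sizes into flat element counts before calling torch.split; when no split sizes are given, the even path simply chunks the flat buffer into world_size pieces. A worked example using the same sizes as the uneven test added next:

    import numpy as np

    # Rows destined for each of 4 peer ranks, as in check_all2all_uneven below.
    input_shape = (8, 8, 16)
    input_split_sizes = [3, 3, 1, 1]

    elems_per_row = int(np.prod(input_shape[1:]))        # 8 * 16 = 128
    flat_sizes = [s * elems_per_row for s in input_split_sizes]
    assert flat_sizes == [384, 384, 128, 128]            # sizes handed to torch.split
    assert sum(flat_sizes) == int(np.prod(input_shape))  # covers the whole flat buffer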
From ffba42328029187cbaa8dcdd2e588b16e5212205 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 31 Jul 2024 07:33:52 +0000
Subject: [PATCH 6/7] fix

---
 tests/test_fp8/test_all_to_all_single.py | 52 +++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 tests/test_fp8/test_all_to_all_single.py

diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
new file mode 100644
index 000000000000..4c9ea495828a
--- /dev/null
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -0,0 +1,52 @@
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.testing import assert_close
+
+from colossalai import launch
+from colossalai.accelerator import get_accelerator
+from colossalai.quantization.fp8 import all_to_all_single_fp8
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
+@parameterize("dtype", [torch.bfloat16])
+def check_all2all(shape, dtype):
+    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
+    output = torch.empty_like(x)
+    output_fp8 = torch.empty_like(x)
+    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
+    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    assert_close(output, output_fp8, rtol=0.1, atol=0.1)
+
+@parameterize("shape", [(8, 8, 16)])
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+def check_all2all_uneven(shape, dtype):
+    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
+    input_split_sizes = [3, 3, 1, 1]
+    if dist.get_rank() in [0, 1]:
+        output_split_sizes = [3, 3, 3, 3]
+    else:
+        output_split_sizes = [1, 1, 1, 1]
+    output_shape = list(shape)
+    output_shape[0] = sum(output_split_sizes)
+    output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
+    output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
+    dist.all_to_all_single(output, x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=_get_default_group(), async_op=False)
+    all_to_all_single_fp8(output_fp8, x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=_get_default_group(), async_op=False)
+    assert_close(output, output_fp8, rtol=0.1, atol=0.1)
+
+
+def run_dist(rank, world_size, port):
+    launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    check_all2all()
+    check_all2all_uneven()
+
+
+@rerun_if_address_is_in_use()
+def test_all_to_all_single():
+    spawn(run_dist, 4)
+
+
+if __name__ == "__main__":
+    test_all_to_all_single()
\ No newline at end of file
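PATCH 6's test compares the fp8 path against the exact dist.all_to_all_single under loose tolerances (rtol=atol=0.1), consistent with e4m3/e5m2 precision. Outside the test harness, a minimal call might look like the sketch below (assumes a 4-rank process group with one CUDA device per rank is already initialized; sizes are illustrative):

    import torch
    from colossalai.quantization.fp8 import all_to_all_single_fp8

    # Even split: 4 rows across 4 ranks means one row exchanged per peer,
    # as in check_all2all above; the default group is used when group=None.
    x = torch.rand(4, 8, 16, dtype=torch.bfloat16, device="cuda")
    out = torch.empty_like(x)
    all_to_all_single_fp8(out, x, fp8_format="e5m2")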
From d67c59eab6f4f6896b5bdd06c1411f2564feaffb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 31 Jul 2024 07:36:04 +0000
Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/quantization/fp8.py           |  3 +--
 tests/test_fp8/test_all_to_all_single.py | 21 ++++++++++++++++++---
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 1f3cda2728d9..840af1f3f7f7 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -147,8 +147,7 @@ def all_to_all_single_fp8(
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
     cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type)
-        for scale, out in zip(scale_list, output_chunks)
+        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
     ]
 
     tensor_out = torch.cat(cast_output_chunk, dim=0)

diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 4c9ea495828a..88becd3f07fc 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -19,6 +19,7 @@ def check_all2all(shape, dtype):
     all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
+
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 def check_all2all_uneven(shape, dtype):
@@ -32,8 +33,22 @@ def check_all2all_uneven(shape, dtype):
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(output, x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=_get_default_group(), async_op=False)
+    dist.all_to_all_single(
+        output,
+        x,
+        output_split_sizes=output_split_sizes,
+        input_split_sizes=input_split_sizes,
+        group=_get_default_group(),
+        async_op=False,
+    )
+    all_to_all_single_fp8(
+        output_fp8,
+        x,
+        output_split_sizes=output_split_sizes,
+        input_split_sizes=input_split_sizes,
+        group=_get_default_group(),
+        async_op=False,
+    )
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
 
@@ -49,4 +64,4 @@ def test_all_to_all_single():
 
 
 if __name__ == "__main__":
-    test_all_to_all_single()
\ No newline at end of file
+    test_all_to_all_single()
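One design point worth noting in the final version: every rank quantizes its whole flattened input with a single per-tensor scale, so a receiver needs each sender's scale before it can decode the chunk that sender produced — hence the all_gather on scale right after the all_to_all. A reduced sketch of that exchange (assumes an initialized process group; the scale value is a stand-in for what cast_to_fp8 returns):

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    scale = torch.ones(1, device="cuda")  # placeholder for this rank's scale
    scale_list = [torch.ones(1, device="cuda") for _ in range(world_size)]
    dist.all_gather(scale_list, scale)
    # scale_list[i] now holds rank i's scale; the chunk received from rank i
    # is decoded with it, mirroring zip(scale_list, output_chunks) above.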