From 75d1d1cca44f1d0855c49fb15a0ccb16683ba323 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 7 Aug 2024 09:51:18 +0000
Subject: [PATCH 1/4] fix

---
 colossalai/quantization/fp8.py | 22 ----------------------
 1 file changed, 22 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index bc8c3ced4cdd..805824a896c7 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
         output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
 
 
-def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
-
-    world_size = dist.get_world_size(group)
-
-    per_slice_len = input_tensor.size(0) // world_size
-    input_type = input_tensor.dtype
-    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
-    fp8_type = ret.dtype
-    input_tensor = ret.view(torch.uint8)
-    tensor = torch.empty_like(input_tensor)
-    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
-    dist.all_to_all_single(tensor, input_tensor, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    cast_tensor_list = []
-
-    for i in range(world_size):
-        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
-        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
-        cast_tensor_list.append(output_part)
-    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
-
-
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
 
     world_size = dist.get_world_size(group)

From f081275993dde620a4fcc9823a1f43926e808d8e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 08:02:30 +0000
Subject: [PATCH 2/4] fix

---
 .github/workflows/example_check_on_pr.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 56fa006b1633..1ccdd59afefd 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -107,7 +107,7 @@ jobs:
 
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v .
+          BUILD_EXT=1 pip install -v -e .
 
       - name: Store Colossal-AI Cache
         run: |

From 4f127135b2c33cde6056d251d0c32d845276c2ac Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 08:04:55 +0000
Subject: [PATCH 3/4] fix

---
 .github/workflows/example_check_on_pr.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 1ccdd59afefd..7a906738cb96 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -9,6 +9,7 @@ on:
     paths:
       - "examples/**"
       - "!examples/**.md"
+      - ".github/workflows/example_check_on_pr.yml"
 
 jobs:
   # This is for changed example files detect and output a matrix containing all the corresponding directory name.
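Aside on patch 1: the deleted all_to_all_single_fp8 helper follows the standard FP8-collective recipe, cast to FP8 with a per-tensor scale, exchange the raw bytes, all-gather the scales, then cast back. Below is a minimal standalone sketch of that scale-and-cast round trip in stock PyTorch (float8 dtypes need PyTorch >= 2.1). It is not ColossalAI's cast_to_fp8/cast_from_fp8 implementation, and the helper names to_fp8/from_fp8 are hypothetical.

import torch

def to_fp8(x: torch.Tensor, fp8_format: str = "e5m2"):
    # Pick the FP8 flavor: e4m3 keeps more precision, e5m2 more range.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale for the whole tensor: map |x|max onto the FP8 maximum.
    scale = x.abs().max().float().clamp(min=1e-12) / fp8_max
    return (x.float() / scale).to(fp8_dtype), scale

def from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    # Upcast first, then undo the scaling.
    return (x_fp8.float() * scale).to(dtype)

x = torch.rand(1024, dtype=torch.bfloat16)
y, scale = to_fp8(x, "e4m3")
x_rt = from_fp8(y, scale, torch.bfloat16)
# FP8 is lossy, hence the loose tolerance, matching the rtol=0.1/atol=0.1
# used by the FP8 tests in this series.
torch.testing.assert_close(x, x_rt, rtol=0.1, atol=0.1)

The uint8 view in the deleted helper exists because collectives ship opaque bytes; each receiver reinterprets its slice as the FP8 dtype and dequantizes with the sender's scale.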
From 27ab889057b9588cb0e5ae526ad8ac0fe7ec15d2 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 14 Aug 2024 06:47:51 +0000
Subject: [PATCH 4/4] fix

---
 tests/test_fp8/test_fp8_reduce_scatter.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py
index c18446e39ea0..e0b558a257ed 100644
--- a/tests/test_fp8/test_fp8_reduce_scatter.py
+++ b/tests/test_fp8/test_fp8_reduce_scatter.py
@@ -13,14 +13,20 @@
 @parameterize("scatter_dim", [0, 1, 2])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, scatter_dim, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, scatter_dim, dtype, fp8_format, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
     input_list = [t.contiguous() for t in input_list]
     output_origin = torch.empty_like(input_list[0])
     output_fp8 = torch.empty_like(input_list[0])
-    reduce_scatter(output_origin, input_list, group=_get_default_group())
-    reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format)
+    origin_handle = reduce_scatter(output_origin, input_list, group=_get_default_group(), async_op=async_op)
+    fp8_handle = reduce_scatter_fp8(
+        output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
+    )
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)
 
 
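Aside on patch 4: with async_op=True, torch.distributed collectives return a Work handle immediately instead of blocking, and the output buffer is only safe to read after handle.wait(); the updated test exercises exactly this for both the reference reduce_scatter and reduce_scatter_fp8. A minimal sketch of the pattern follows; the wrapper name is hypothetical, and a process group is assumed to be initialized with each rank supplying world_size input chunks.

import torch
import torch.distributed as dist

def reduce_scatter_maybe_async(output, input_list, group=None, async_op=False):
    # async_op=True returns a Work handle without blocking;
    # async_op=False blocks and returns None.
    handle = dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
    if async_op:
        # Independent work could be overlapped with the collective here.
        handle.wait()  # output is only valid after wait() completes
    return output

Parameterizing the test over async_op checks both code paths of reduce_scatter_fp8 against the full-precision result under the same loose FP8 tolerance.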