From 3df5bcad960580819c16427db020d811c160b18a Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 8 Jun 2023 14:09:43 +0800 Subject: [PATCH 1/5] fix bug in slicer, add slicer unit test --- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/shard/slicer.py | 16 ++-- .../test_module/test_slicer.py | 78 +++++++++++++++++++ 3 files changed, 87 insertions(+), 9 deletions(-) create mode 100644 tests/test_shardformer/test_module/test_slicer.py diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 89b32f065c27..f37fc4ac28f1 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -141,7 +141,7 @@ def unembedding() -> List: weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - # gather_output=True, + gather_output=True, ) ] diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 6d35bd193fed..09e3219f87a2 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -3,7 +3,7 @@ from ..policies.basepolicy import Col_Layer, Layer, Row_Layer from .shard_config import ShardConfig -dim_mapping = {Col_Layer: 1, Row_Layer: 0} +dim_mapping = {Col_Layer: 0, Row_Layer: 1} class Slicer(): @@ -40,7 +40,7 @@ def slice_weight_bias( # print(weight.shape, dim) if policy_layer_cls == Col_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) - bias = self.slice_tensor(bias, 0, True) + bias = self.slice_tensor(bias, 0, True, n_cast) elif policy_layer_cls == Row_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) else: @@ -129,13 +129,13 @@ def slice_col( """ if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, 
dim=0) + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) chunk_list = [ tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) ] - return torch.cat(chunk_list, dim=0).contiguous() + return torch.cat(chunk_list, dim=1).contiguous() def slice_row( self, @@ -152,10 +152,10 @@ def slice_row( :class:`torch.Tensor`: The sliced tensor """ if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) chunk_list = [ tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) ] - return torch.cat(chunk_list, dim=1).contiguous() + return torch.cat(chunk_list, dim=0).contiguous() diff --git a/tests/test_shardformer/test_module/test_slicer.py b/tests/test_shardformer/test_module/test_slicer.py new file mode 100644 index 000000000000..c72a0357573b --- /dev/null +++ b/tests/test_shardformer/test_module/test_slicer.py @@ -0,0 +1,78 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer +from colossalai.shardformer.shard.shard_config import ShardConfig +from colossalai.shardformer.shard.slicer import Slicer +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_slicer(rank, world_size, port, in_feature, out_feature): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + # initialize slicer + 
shardconfig = ShardConfig(rank=rank, world_size=world_size) + slicer = Slicer(shardconfig) + # initialize test data + weight = torch.randn(in_feature, out_feature) + bias = torch.randn(out_feature) + policy_layer_cls_list = [Layer, Col_Layer, Row_Layer] + n_cast_list = [None, 2, 3, 4] + # weight and bias + for n_cast in n_cast_list: + sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast) + expected_sliced_weight = weight + expected_sliced_bias = bias + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + + sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast) + if (n_cast is None): + expected_sliced_weight = weight.chunk(world_size, dim=0)[rank] + expected_sliced_bias = bias.chunk(world_size)[rank] + else: + chunks = weight.chunk(world_size * n_cast, dim=0) + expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0) + chunks = bias.chunk(world_size * n_cast, dim=0) + expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)]) + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}" + + sliced_weight, 
sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast) + if (n_cast is None): + expected_sliced_weight = weight.chunk(world_size, dim=1)[rank] + expected_sliced_bias = bias + else: + chunks = weight.chunk(world_size * n_cast, dim=1) + expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1) + expected_sliced_bias = bias + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_slicer(): + args = dict(in_feature=24, out_feature=48) + spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature']) + + +if __name__ == '__main__': + test_slicer() From b4fac4e56f321a77b23f8283789e498b980b927d Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 8 Jun 2023 14:40:23 +0800 Subject: [PATCH 2/5] add dropout test --- .../test_module/test_dropout.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/test_shardformer/test_module/test_dropout.py diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py new file mode 100644 index 000000000000..b03195eaf319 --- /dev/null +++ b/tests/test_shardformer/test_module/test_dropout.py @@ -0,0 +1,41 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, 
pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_dropout(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + from colossalai.shardformer.layer.dropout import Dropout1D + + # prepare data + input = torch.randn(5, 4).to('cuda') + dropout = Dropout1D(p=0.4).to('cuda') + output_list = [] + for i in range(2): + output = dropout(input) + output_list.append(output) + dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] + torch.distributed.all_gather(dist_output_list, output) + print(dist_output_list) + for j in range(world_size): + for k in range(world_size): + if j != k: + mask = torch.eq(dist_output_list[i], 0.0) == torch.eq(dist_output_list[j], 0.0) + assert torch.all(mask) == False, f"{mask}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(check_dropout, 2) + + +if __name__ == '__main__': + test_dropout() From 27e3f7d95bf5203c8967714b4e9c35a44d010330 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 8 Jun 2023 16:01:17 +0800 Subject: [PATCH 3/5] use pid as dropout seed --- colossalai/shardformer/layer/dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index acc114029ac1..9af8ede93d77 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -14,7 +14,7 @@ class SeedManager: def __init__(self): original_state = torch.cuda.get_rng_state() - seed = int(f"{int(time.time())}{os.environ['RANK']}") + seed = os.getpid() torch.cuda.manual_seed(int(seed)) self.dropout_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(original_state) From cf2a65a430647313e11c09350aea81cd2a0cca21 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 8 Jun 2023 16:01:58 +0800 Subject: [PATCH 4/5] updata 
dropout test with local pattern --- .../test_module/test_dropout.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py index b03195eaf319..4a13eb61c1fc 100644 --- a/tests/test_shardformer/test_module/test_dropout.py +++ b/tests/test_shardformer/test_module/test_dropout.py @@ -4,6 +4,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.dropout import Dropout1D from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) @@ -12,23 +13,32 @@ def check_dropout(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') - from colossalai.shardformer.layer.dropout import Dropout1D # prepare data input = torch.randn(5, 4).to('cuda') dropout = Dropout1D(p=0.4).to('cuda') output_list = [] + # compare the dropout pattern in each device for i in range(2): output = dropout(input) output_list.append(output) dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] torch.distributed.all_gather(dist_output_list, output) - print(dist_output_list) for j in range(world_size): for k in range(world_size): if j != k: - mask = torch.eq(dist_output_list[i], 0.0) == torch.eq(dist_output_list[j], 0.0) - assert torch.all(mask) == False, f"{mask}" + mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0) + assert torch.all( + mask + ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}" + # compare the dropout pattern in loacl device + for i in range(len(output_list)): + for j in range(len(output_list)): + if i != j: + mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0) + assert 
torch.all( + mask + ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}" @pytest.mark.dist From d260db7cd0f4657f1adccf0066713ecf1cd1dadb Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 11:55:46 +0800 Subject: [PATCH 5/5] add todo --- colossalai/shardformer/layer/dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 9af8ede93d77..0f653a9be780 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,5 +1,4 @@ import os -import time from contextlib import contextmanager import torch @@ -14,6 +13,7 @@ class SeedManager: def __init__(self): original_state = torch.cuda.get_rng_state() + # TODO: unify this seed manager with the colossalai.context.random seed = os.getpid() torch.cuda.manual_seed(int(seed)) self.dropout_state = torch.cuda.get_rng_state()