From 591255b5b49243bf8a3a618c0b916a1486adc71f Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 6 Jun 2023 14:32:07 +0800
Subject: [PATCH 01/11] add bert align test, fix dist loss bug

---
 .../shardformer/layer/dist_crossentropy.py |  4 +-
 colossalai/shardformer/test/align_bert.py  | 99 +++++++++++++++
 colossalai/shardformer/test/module_test.py |  6 +-
 3 files changed, 105 insertions(+), 4 deletions(-)
 create mode 100644 colossalai/shardformer/test/align_bert.py

diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py
index 1869594670ce..721fa55fdf5c 100644
--- a/colossalai/shardformer/layer/dist_crossentropy.py
+++ b/colossalai/shardformer/layer/dist_crossentropy.py
@@ -75,8 +75,8 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
 
         # calculate the loss
         # loss = log(sum(exp(x[i]))) - x[class]
-        loss = torch.log(sum_exp_logits) - pred_logits
-        loss = torch.sum(loss).div_(loss.numel())
+        loss = torch.where(target == -100, 0.0, torch.log(sum_exp_logits) - pred_logits)
+        loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
 
         # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

diff --git a/colossalai/shardformer/test/align_bert.py b/colossalai/shardformer/test/align_bert.py
new file mode 100644
index 000000000000..cf26f779fe2e
--- /dev/null
+++ b/colossalai/shardformer/test/align_bert.py
@@ -0,0 +1,99 @@
+import os
+import random
+
+import torch
+import torch.nn as nn
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import (
+    AutoTokenizer,
+    BertConfig,
+    BertForMaskedLM,
+    DataCollatorForLanguageModeling,
+    GPT2LMHeadModel,
+    get_scheduler,
+)
+
+import colossalai
+from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.utils import get_current_device, print_rank_0
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def get_args():
+    parser = colossalai.get_default_parser()
+    parser.add_argument("--mode", type=str, default='inference')
+    parser.add_argument("--save_model", action='store_true')
+    parser.add_argument("--model", type=str, default='bert-base-uncased')
+    return parser.parse_args()
+
+
+def forward_verify():
+    # launch dist
+    colossalai.launch_from_torch(config=get_args().config)
+    input = "Hello, my dog is cute"
+    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+    # forward
+    # origin model
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased').to('cuda')
+    org_model.eval()
+    org_out = org_model(**tokenized_input)
+    print_rank_0(org_out[0])
+    # shard model
+    shardconfig = ShardConfig(
+        rank=int(os.environ['RANK']),
+        world_size=int(os.environ['WORLD_SIZE']),
+    )
+    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased'), shardconfig).to('cuda')
+    sharded_model.eval()
+    shard_out = sharded_model(**tokenized_input)
+    print_rank_0(shard_out[0])
+
+    assert torch.allclose(org_out[0], shard_out[0], atol=1e-5), "shard model output is not equal to origin model output"
+    print("[OK] shard model output is equal to origin model output")
+
+
+def backward_verify():
+    # launch dist
+    colossalai.launch_from_torch(config=get_args().config)
+    input = "Hello, my dog is cute"
+    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    tokenized_input = tokenizer(input, padding='max_length', return_tensors='pt').to('cuda')
+    labels = tokenized_input['input_ids'].clone()
+    labels[labels == tokenizer.pad_token_id] = -100
+    tokenized_input['labels'] = labels
+    # disable dropout
+    config = BertConfig.from_pretrained('bert-base-uncased')
+    config.hidden_dropout_prob = 0
+    config.attention_probs_dropout_prob = 0
+    # backward
+    # origin model
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
+    org_model.train()
+    org_out = org_model(**tokenized_input)
+    org_loss = org_out.loss
+    print_rank_0(org_loss)
+    org_loss.backward()
+
+    # shard model
+    shardconfig = ShardConfig(
+        rank=int(os.environ['RANK']),
+        world_size=int(os.environ['WORLD_SIZE']),
+    )
+    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
+                                shardconfig).to('cuda')
+    sharded_model.train()
+    shard_out = sharded_model(**tokenized_input)
+    shard_loss = shard_out.loss
+    print_rank_0(shard_loss)
+    shard_loss.backward()
+
+    assert torch.allclose(org_out[0], shard_out[0], atol=1e-5), "shard model output is not equal to origin model output"
+    print("[OK] shard model output is equal to origin model output")
+
+
+if __name__ == '__main__':
+    backward_verify()

diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py
index 83dc7ec6cf4a..6f3b43496f01 100644
--- a/colossalai/shardformer/test/module_test.py
+++ b/colossalai/shardformer/test/module_test.py
@@ -17,8 +17,10 @@ def get_args():
 
 def test_dist_crossentropy():
     pred = torch.randn(2, 4, 8, requires_grad=True)
-    labels = torch.randint(8, (1, 4)).repeat(2, 1)
-
+    print(pred)
+    labels = torch.randint(8, (2, 4))
+    labels[0, -1] = -100
+    print(labels)
     pred_ = pred.view(-1, 8)
     labels_ = labels.view(-1)
     loss = F.cross_entropy(pred_, labels_)
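Note on the loss fix in patch 01: the old code averaged `log(sum_exp) - x[class]` over *all* tokens, so ignored positions (target `-100`) polluted the mean. The new code zeroes them out and divides by the number of non-zero entries, which matches `F.cross_entropy`'s default behaviour as long as no valid token happens to land on exactly zero loss. A standalone, single-process sketch of the masked-mean idea (names invented here, not part of the patch):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 32)                              # (tokens, vocab)
target = torch.randint(32, (8,))
target[0] = -100                                         # one ignored token

# per-token loss: log(sum(exp(x[i]))) - x[class], zeroed where the target is ignored
sum_exp = torch.exp(logits).sum(dim=-1)
picked = logits[torch.arange(8), target.clamp(min=0)]    # clamp so -100 can index safely
per_token = torch.where(target == -100, torch.zeros_like(sum_exp), torch.log(sum_exp) - picked)

# average over the non-ignored tokens only
loss = per_token.sum() / (target != -100).sum()

assert torch.allclose(loss, F.cross_entropy(logits, target), atol=1e-5)
```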
From e3bae196fc3d5df5cff242b7041d8600985b0dbe Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 7 Jun 2023 11:08:50 +0800
Subject: [PATCH 02/11] forward and backward align

---
 colossalai/shardformer/policies/bert.py   |  5 +--
 colossalai/shardformer/shard/sharder.py   |  2 ++
 colossalai/shardformer/test/align_bert.py | 43 ++++++++++++++++------
 3 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 89b32f065c27..5d489f41986c 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -141,7 +141,7 @@ def unembedding() -> List:
                 weight="decoder.weight",
                 bias="decoder.bias",
                 replace_layer=col_nn.Linear1D_Col,
-                # gather_output=True,
+                gather_output=True,
             )
         ]
 
@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
 
     @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
-        return (BertForMaskedLM, BertForMaskedLM_)
+        # return (BertForMaskedLM, BertForMaskedLM_)
+        return None
 
 
 class BertForSequenceClassificationPolicy(BertPolicy):

diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 1ada75e06b67..941199ef5ac3 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -65,6 +65,8 @@ def inject_model(
             BertForMaskedLM.forward -> BertForMaskedLM_.forward
         """
         inject_policy = self.policy.inject_policy()
+        if inject_policy is None:
+            return
 
         if inject_policy is None:
             return

diff --git a/colossalai/shardformer/test/align_bert.py b/colossalai/shardformer/test/align_bert.py
index cf26f779fe2e..ecdd720cea07 100644
--- a/colossalai/shardformer/test/align_bert.py
+++ b/colossalai/shardformer/test/align_bert.py
@@ -32,7 +32,7 @@ def get_args():
 
 def forward_verify():
     # launch dist
-    colossalai.launch_from_torch(config=get_args().config)
+    # colossalai.launch_from_torch(config=get_args().config)
     input = "Hello, my dog is cute"
     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
     tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
@@ -41,7 +41,7 @@ def forward_verify():
     org_model = BertForMaskedLM.from_pretrained('bert-base-uncased').to('cuda')
     org_model.eval()
     org_out = org_model(**tokenized_input)
-    print_rank_0(org_out[0])
+    # print_rank_0(org_out[0])
     # shard model
     shardconfig = ShardConfig(
         rank=int(os.environ['RANK']),
@@ -50,18 +50,20 @@ def forward_verify():
     sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased'), shardconfig).to('cuda')
     sharded_model.eval()
     shard_out = sharded_model(**tokenized_input)
-    print_rank_0(shard_out[0])
+    # print_rank_0(shard_out[0])
 
-    assert torch.allclose(org_out[0], shard_out[0], atol=1e-5), "shard model output is not equal to origin model output"
-    print("[OK] shard model output is equal to origin model output")
+    assert torch.allclose(
+        org_out[0], shard_out[0],
+        atol=1e-5), f"shard model output is not equal to origin model output\n{org_out[0]}\n{shard_out[0]}"
+    print_rank_0("[OK] shard model output is equal to origin model output")
 
 
 def backward_verify():
     # launch dist
-    colossalai.launch_from_torch(config=get_args().config)
+    # colossalai.launch_from_torch(config=get_args().config)
     input = "Hello, my dog is cute"
     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
-    tokenized_input = tokenizer(input, padding='max_length', return_tensors='pt').to('cuda')
+    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
     labels = tokenized_input['input_ids'].clone()
     labels[labels == tokenizer.pad_token_id] = -100
     tokenized_input['labels'] = labels
@@ -75,9 +77,10 @@ def backward_verify():
     org_model.train()
     org_out = org_model(**tokenized_input)
     org_loss = org_out.loss
-    print_rank_0(org_loss)
+    # print_rank_0(org_loss)
     org_loss.backward()
-
+    org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
+    # print_rank_0(f"grad: {org_grad}")
     # shard model
     shardconfig = ShardConfig(
         rank=int(os.environ['RANK']),
@@ -88,12 +91,26 @@ def backward_verify():
     sharded_model.train()
     shard_out = sharded_model(**tokenized_input)
     shard_loss = shard_out.loss
-    print_rank_0(shard_loss)
+    # print_rank_0(shard_loss)
     shard_loss.backward()
-
-    assert torch.allclose(org_out[0], shard_out[0], atol=1e-5), "shard model output is not equal to origin model output"
-    print("[OK] shard model output is equal to origin model output")
+    shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
+    # print(f"grad: {shard_grad}")
+    # all gather
+    gather_grad = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(gather_grad, shard_grad)
+    all_shard_grad = torch.cat(gather_grad, dim=0)
+    # print(all_shard_grad)
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    print_rank_0("[OK] shard model loss is equal to origin model loss")
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
+    print_rank_0("[OK] shard model grad is equal to origin model grad")
 
 
 if __name__ == '__main__':
+    colossalai.launch_from_torch(config=get_args().config)
+    print_rank_0("\n-------------------forward--------------------")
+    forward_verify()
+    print_rank_0("\n-------------------backward--------------------")
     backward_verify()
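The gradient check introduced in patch 02 is the core of the alignment test: the query projection is column-sharded, so each rank holds a slice of the weight (and its gradient) along dim 0, and the full gradient is recovered by gathering and concatenating. A minimal sketch of the pattern; note that `dist.all_gather` returns `None` when `async_op=False`, so only the gathered list should be used afterwards (the patch's reassignment of `shard_grad` is harmless but misleading):

```python
import torch
import torch.distributed as dist

def gather_and_compare(org_grad: torch.Tensor, shard_grad: torch.Tensor, world_size: int) -> None:
    # one buffer per rank, filled in rank order by all_gather
    bucket = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(bucket, shard_grad)
    all_shard_grad = torch.cat(bucket, dim=0)    # undo the column split along dim 0
    assert torch.allclose(org_grad, all_shard_grad, atol=1e-5)
```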
From 4e000cad00750b09de6b8f0d43489827b297cf98 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 7 Jun 2023 15:41:10 +0800
Subject: [PATCH 03/11] add ignore index

---
 colossalai/shardformer/layer/dist_crossentropy.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py
index 721fa55fdf5c..05c04bb545c1 100644
--- a/colossalai/shardformer/layer/dist_crossentropy.py
+++ b/colossalai/shardformer/layer/dist_crossentropy.py
@@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
     """
 
     @staticmethod
-    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
+    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
         loss = -log(exp(x[class])/sum(exp(x[i])))
@@ -75,7 +75,7 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
 
         # calculate the loss
         # loss = log(sum(exp(x[i]))) - x[class]
-        loss = torch.where(target == -100, 0.0, torch.log(sum_exp_logits) - pred_logits)
+        loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
         loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
 
         # calculate the softmax
@@ -101,5 +101,5 @@ def backward(ctx, grad_output):
         return grad_logits, None, None
 
 
-def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels)
+def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
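Threading `ignore_index` through a `torch.autograd.Function` is why `backward` returns `grad_logits, None, None`: autograd expects one gradient slot per `forward` argument, and the integer target and `ignore_index` get `None`. A toy illustration of the pattern (a hypothetical `MaskedSum`, not the patch's class):

```python
import torch
from torch.autograd import Function

class MaskedSum(Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor, target: torch.Tensor, ignore_index: int):
        mask = (target != ignore_index).to(x.dtype)
        ctx.save_for_backward(mask)
        return (x * mask).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # one slot per forward input: x gets a grad, target and ignore_index get None
        return grad_output * mask, None, None
```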
print_rank_0("\n-------------------forward--------------------") + forward_verify() + print_rank_0("\n-------------------backward--------------------") backward_verify() From 4e000cad00750b09de6b8f0d43489827b297cf98 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 7 Jun 2023 15:41:10 +0800 Subject: [PATCH 03/11] add ignore index --- colossalai/shardformer/layer/dist_crossentropy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 721fa55fdf5c..05c04bb545c1 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -14,7 +14,7 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -75,7 +75,7 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] - loss = torch.where(target == -100, 0.0, torch.log(sum_exp_logits) - pred_logits) + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) # caculate the softmax @@ -101,5 +101,5 @@ def backward(ctx, grad_output): return grad_logits, None, None -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels) +def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index) From cfe5bc71efafd0c6bd0bee8efd361835ae21a528 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 7 Jun 2023 16:10:31 +0800 Subject: [PATCH 04/11] add shardformer CI --- tests/test_shardformer/bert/bert_test.py | 102 +++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/test_shardformer/bert/bert_test.py diff --git a/tests/test_shardformer/bert/bert_test.py b/tests/test_shardformer/bert/bert_test.py new file mode 100644 index 000000000000..071d7fc6be16 --- /dev/null +++ b/tests/test_shardformer/bert/bert_test.py @@ -0,0 +1,102 @@ +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, BertConfig, BertForMaskedLM + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + +def build_model(rank, world_size): + config = BertConfig.from_pretrained('bert-base-uncased') + config.hidden_dropout_prob = 0 + config.attention_probs_dropout_prob = 0 + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + ) + sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), + shardconfig).to('cuda') + + return org_model, sharded_model + + +def 
From e32490b551f4f92d3584ccd314cd8e10fa786267 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 7 Jun 2023 16:28:01 +0800
Subject: [PATCH 05/11] add gather_output optional for user in shardconfig

---
 colossalai/shardformer/shard/shard_config.py | 18 ++++++++----------
 colossalai/shardformer/shard/sharder.py      |  2 +-
 tests/test_shardformer/bert/bert_test.py     |  1 +
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 4cf9162b9548..e8d6f3408c76 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -5,16 +5,14 @@
 
 @dataclass
 class ShardConfig:
-    """
-    The config for sharding the huggingface model for test
+    r"""
+    The config for sharding the huggingface model
+
+    Args:
+        rank (int): The rank of the local process
+        world_size (int): The world size of the distributed process group
+        gather_output (bool): Whether to gather the output of the last layer of the model
     """
     rank: int
-    fp16: bool = True
-    num_gpus: int = 2
     world_size: int = 2
-    backend = "nccl"
-    verbose: str = 'simple'
-    seed: int = None
-    require_grad: bool = False
-    master_addr: str = "127.0.0.1"
-    master_port: int = 29500
+    gather_output: bool = True

diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 941199ef5ac3..159bebccd02d 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -150,7 +150,7 @@ def shard_one_layer(
         n_cast = policy_layer.n_cast
         reversed = policy_layer.reversed
         if policy_layer.__class__.__name__ == "Col_Layer":
-            gather_output = policy_layer.gather_output
+            gather_output = policy_layer.gather_output and self.shard_config.gather_output
 
         if weight_attr is not None:
             if hasattr_(org_layer, weight_attr):

diff --git a/tests/test_shardformer/bert/bert_test.py b/tests/test_shardformer/bert/bert_test.py
index 071d7fc6be16..55b78d040505 100644
--- a/tests/test_shardformer/bert/bert_test.py
+++ b/tests/test_shardformer/bert/bert_test.py
@@ -25,6 +25,7 @@ def build_model(rank, world_size):
     shardconfig = ShardConfig(
         rank=rank,
         world_size=world_size,
+        gather_output=True,
     )
     sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
                                 shardconfig).to('cuda')
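What `gather_output` toggles, in terms of the underlying column-parallel linear: with it off, each rank keeps its `out_features / world_size` slice (fine when the next layer is row-parallel); with it on, the slices are concatenated back to full width, which the last layer needs so the loss sees full-vocab logits. A forward-only sketch of the two modes, under the usual tensor-parallel convention (a real training implementation would use a differentiable gather):

```python
import torch
import torch.distributed as dist

def col_parallel_forward(x, local_weight, gather_output: bool, world_size: int):
    local_out = x @ local_weight.t()                  # (..., out_features // world_size)
    if not gather_output:
        return local_out                              # stay sharded for a following row-parallel layer
    bucket = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(bucket, local_out)
    return torch.cat(bucket, dim=-1)                  # full (..., out_features)
```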
From 64764b1d9d9136fc3b6a196b7ff795f9230099b7 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 7 Jun 2023 16:51:46 +0800
Subject: [PATCH 06/11] update readme with optional gather_output

---
 colossalai/shardformer/README.md | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 93a4f1e578e4..699b4b7da3ef 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -20,7 +20,7 @@ The sample API usage is given below:
 
 ``` python
-from colossalai.shardformer import shard_model
+from colossalai.shardformer.shard import ShardConfig, shard_model
 from transformers import BertForMaskedLM
 
 # create huggingface model as normal
@@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 
 # make the huggingface model paralleled to ShardModel
 # auto policy:
-sharded_model = shard_model(model)
+shardconfig = ShardConfig(
+    rank=rank,
+    world_size=world_size,
+    gather_output=True,
+)
+sharded_model = shard_model(model, config=shardconfig)
 
 # custom policy:
 from xxx import
@@ -235,7 +240,7 @@ CustomPolicy(Policy):
 This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
 
 CLASS `Col_Layer(Layer)`:
-  - gather_output (bool): Whether to gather the output of the layer
+  - gather_output (bool): Whether to gather the output of this layer. Usually only the last layer (e.g. the unembedding) needs its output gathered; intermediate layers of the model generally do not.
 
 This class inherits from `Layer` and represents a layer sliced along the column dimension.
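Putting the README's pieces together, a custom policy would look roughly like the sketch below. It reuses only names that appear in this series (`Policy`, `Col_Layer`, and the `unembedding` hook from `policies/bert.py`); anything beyond those hunks is a guess, not verified API:

```python
from typing import List, Tuple

import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.shardformer.policies.basepolicy import Col_Layer, Policy

class MyPolicy(Policy):

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        # no forward replacement, mirroring BertForMaskedLMPolicy above
        return None

    @staticmethod
    def unembedding() -> List:
        # shard the unembedding by column and gather its output so the loss
        # sees full-vocab logits (cf. the Col_Layer hunk in policies/bert.py)
        return [
            Col_Layer(
                weight="decoder.weight",
                bias="decoder.bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            )
        ]
```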
From 5a1cdefb079527b6e79cfce14c2a342b6f71520f Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 7 Jun 2023 17:28:22 +0800
Subject: [PATCH 07/11] add dist crossentropy loss test, remove unused files

---
 colossalai/shardformer/test/align_bert.py          | 116 ----------------
 colossalai/shardformer/test/test.py                | 124 ------------------
 .../{bert/bert_test.py => test_bert/test_bert.py}  |   0
 .../test_module/test_distcrossentropy.py           |  42 ++++++
 4 files changed, 42 insertions(+), 240 deletions(-)
 delete mode 100644 colossalai/shardformer/test/align_bert.py
 delete mode 100644 colossalai/shardformer/test/test.py
 rename tests/test_shardformer/{bert/bert_test.py => test_bert/test_bert.py} (100%)
 create mode 100644 tests/test_shardformer/test_module/test_distcrossentropy.py

diff --git a/colossalai/shardformer/test/align_bert.py b/colossalai/shardformer/test/align_bert.py
deleted file mode 100644
index ecdd720cea07..000000000000
--- a/colossalai/shardformer/test/align_bert.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import os
-import random
-
-import torch
-import torch.nn as nn
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-from transformers import (
-    AutoTokenizer,
-    BertConfig,
-    BertForMaskedLM,
-    DataCollatorForLanguageModeling,
-    GPT2LMHeadModel,
-    get_scheduler,
-)
-
-import colossalai
-from colossalai.shardformer.shard import ShardConfig, shard_model
-from colossalai.utils import get_current_device, print_rank_0
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-
-
-def get_args():
-    parser = colossalai.get_default_parser()
-    parser.add_argument("--mode", type=str, default='inference')
-    parser.add_argument("--save_model", action='store_true')
-    parser.add_argument("--model", type=str, default='bert-base-uncased')
-    return parser.parse_args()
-
-
-def forward_verify():
-    # launch dist
-    # colossalai.launch_from_torch(config=get_args().config)
-    input = "Hello, my dog is cute"
-    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    # forward
-    # origin model
-    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased').to('cuda')
-    org_model.eval()
-    org_out = org_model(**tokenized_input)
-    # print_rank_0(org_out[0])
-    # shard model
-    shardconfig = ShardConfig(
-        rank=int(os.environ['RANK']),
-        world_size=int(os.environ['WORLD_SIZE']),
-    )
-    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased'), shardconfig).to('cuda')
-    sharded_model.eval()
-    shard_out = sharded_model(**tokenized_input)
-    # print_rank_0(shard_out[0])
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-5), f"shard model output is not equal to origin model output\n{org_out[0]}\n{shard_out[0]}"
-    print_rank_0("[OK] shard model output is equal to origin model output")
-
-
-def backward_verify():
-    # launch dist
-    # colossalai.launch_from_torch(config=get_args().config)
-    input = "Hello, my dog is cute"
-    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    tokenized_input['labels'] = labels
-    # disable dropout
-    config = BertConfig.from_pretrained('bert-base-uncased')
-    config.hidden_dropout_prob = 0
-    config.attention_probs_dropout_prob = 0
-    # backward
-    # origin model
-    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
-    org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
-    # print_rank_0(org_loss)
-    org_loss.backward()
-    org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
-    # print_rank_0(f"grad: {org_grad}")
-    # shard model
-    shardconfig = ShardConfig(
-        rank=int(os.environ['RANK']),
-        world_size=int(os.environ['WORLD_SIZE']),
-    )
-    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
-                                shardconfig).to('cuda')
-    sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
-    # print_rank_0(shard_loss)
-    shard_loss.backward()
-    shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
-    # print(f"grad: {shard_grad}")
-    # all gather
-    gather_grad = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(gather_grad, shard_grad)
-    all_shard_grad = torch.cat(gather_grad, dim=0)
-    # print(all_shard_grad)
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
-    print_rank_0("[OK] shard model loss is equal to origin model loss")
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
-    print_rank_0("[OK] shard model grad is equal to origin model grad")
-
-
-if __name__ == '__main__':
-    colossalai.launch_from_torch(config=get_args().config)
-    print_rank_0("\n-------------------forward--------------------")
-    forward_verify()
-    print_rank_0("\n-------------------backward--------------------")
-    backward_verify()

diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
deleted file mode 100644
index e2d5a94c782a..000000000000
--- a/colossalai/shardformer/test/test.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import os
-import random
-
-import torch
-import torch.nn as nn
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
-
-import colossalai
-from colossalai.shardformer.shard import ShardConfig, shard_model
-from colossalai.utils import get_current_device, print_rank_0
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-
-
-def get_args():
-    parser = colossalai.get_default_parser()
-    parser.add_argument("--mode", type=str, default='inference')
-    parser.add_argument("--save_model", action='store_true')
-    parser.add_argument("--model", type=str, default='bert-base-uncased')
-    return parser.parse_args()
-
-
-def load_data(args):
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
-    if tokenizer.pad_token is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        # tokenizer.pad_token_id = 0
-    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
-    # datasets=load_dataset("yelp_review_full")
-    tokenized_datasets = datasets.map(
-        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
-    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
-    # tokenized_datasets=tokenized_datasets.rename_column("label","labels")
-    tokenized_datasets.set_format("torch")
-
-    train_dataset = tokenized_datasets["train"]
-    test_dataset = tokenized_datasets["test"]
-
-    datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
-    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
-    eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
-    return train_dataloader, eval_dataloader
-
-
-def inference(model: nn.Module, args):
-    print(model)
-    # print(model.wte.weight.shape)
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
-    if tokenizer.pad_token is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        tokenizer.pad_token_id = 0
-    token = "Hello, my dog is cute"
-    inputs = tokenizer(token, return_tensors="pt")
-    inputs.to("cuda")
-    model.eval()
-    model.to("cuda")
-    outputs = model(**inputs)
-    print(outputs[0])
-
-
-def train(model: nn.Module, args, num_epoch: int = 3):
-    train_dataloader, eval_dataloader = load_data(args)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
-    num_training = num_epoch * len(train_dataloader)
-    progress_bar = tqdm(range(num_training))
-    lr_scheduler = get_scheduler(name="linear",
-                                 optimizer=optimizer,
-                                 num_warmup_steps=0,
-                                 num_training_steps=num_training)
-    best_test_loss = float("inf")
-    model.to("cuda")
-    model.train()
-    for epoch in range(num_epoch):
-        progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
-        for batch in train_dataloader:
-            optimizer.zero_grad()
-            batch = {k: v.to('cuda') for k, v in batch.items()}
-            outputs = model(**batch)
-            loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            progress_bar.update(1)
-        train_loss = loss
-
-        loss = 0.0
-        for batch in eval_dataloader:
-            batch = {k: v.to('cuda') for k, v in batch.items()}
-            outputs = model(**batch)
-            # loss = outputs.loss
-            assert not torch.isnan(outputs.loss), f"{batch}"
-            loss += outputs.loss.item()
-            # loss = criterion(outputs.logits, batch["input_ids"])
-        test_loss = loss / len(eval_dataloader)
-        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
-        if args.save_model and test_loss < best_test_loss:
-            best_test_loss = test_loss
-            torch.save(model.state_dict(), "./checkpoints/best_model.pth")
-
-
-if __name__ == "__main__":
-    args = get_args()
-    colossalai.launch_from_torch(config=args.config)
-    if args.model == 'bert-base-uncased':
-        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
-    elif args.model == 'gpt2':
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-    else:
-        raise AttributeError("model not supported")
-    shard_config = ShardConfig(
-        rank=int(str(get_current_device()).split(':')[-1]),
-        world_size=int(os.environ['WORLD_SIZE']),
-    )
-    sharded_model = shard_model(model, shard_config)
-
-    if args.mode == "train":
-        train(sharded_model, args)
-    elif args.mode == "inference":
-        inference(sharded_model, args)
-    else:
-        raise NotImplementedError

diff --git a/tests/test_shardformer/bert/bert_test.py b/tests/test_shardformer/test_bert/test_bert.py
similarity index 100%
rename from tests/test_shardformer/bert/bert_test.py
rename to tests/test_shardformer/test_bert/test_bert.py
diff --git a/tests/test_shardformer/test_module/test_distcrossentropy.py b/tests/test_shardformer/test_module/test_distcrossentropy.py
new file mode 100644
index 000000000000..9a19ec57821d
--- /dev/null
+++ b/tests/test_shardformer/test_module/test_distcrossentropy.py
@@ -0,0 +1,42 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_dist_crossentropy(rank, world_size, port, ignore_index):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True)
+    labels = torch.randint(8, (2, 4))
+    # set some labels to -100 to test the ignore index
+    labels[0, -1] = ignore_index
+
+    org_pred = pred.view(-1, 8)
+    org_labels = labels.view(-1)
+    org_loss = F.cross_entropy(org_pred, org_labels)
+
+    dist_pred = pred.chunk(world_size, -1)[rank]
+    dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
+
+    assert torch.allclose(org_loss, dist_loss,
+                          atol=1e-5), f"dist cross entropy loss is not equal to origin loss\n{org_loss}\n{dist_loss}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_crossentropy():
+    ignore_index = -100
+    spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
+
+
+if __name__ == '__main__':
+    test_dist_crossentropy()
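The new test mimics vocab parallelism on a single tensor: `pred.chunk(world_size, -1)[rank]` gives each rank the classes `[rank * V/world_size, (rank + 1) * V/world_size)`, so `DistCrossEntropy` has to map global targets into that local range. A single-process, communication-free illustration of the index bookkeeping involved (the out-of-slice `x[class]` terms would presumably arrive via the communication inside `DistCrossEntropy`):

```python
import torch

vocab, world_size = 8, 2
pred = torch.randn(2, 4, vocab)
labels = torch.randint(vocab, (2, 4))

part = vocab // world_size
for rank in range(world_size):
    lo, hi = rank * part, (rank + 1) * part
    local_logits = pred.chunk(world_size, -1)[rank]      # what the test feeds each rank
    in_slice = (labels >= lo) & (labels < hi)            # targets this rank can score locally
    local_labels = torch.where(in_slice, labels - lo, torch.zeros_like(labels))
    # out-of-slice targets contribute nothing to this rank's x[class] term
```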
From 739fbb09ad45cb8331348cac2e5952f44c863921 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Thu, 8 Jun 2023 09:31:59 +0800
Subject: [PATCH 08/11] remove unused file

---
 colossalai/shardformer/test/module_test.py | 52 ----------------------
 1 file changed, 52 deletions(-)
 delete mode 100644 colossalai/shardformer/test/module_test.py

diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py
deleted file mode 100644
index 6f3b43496f01..000000000000
--- a/colossalai/shardformer/test/module_test.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
-from colossalai.shardformer.layer.dropout import Dropout1D
-
-
-def get_args():
-    parser = colossalai.get_default_parser()
-    parser.add_argument("--module", type=str, default='distloss')
-    return parser.parse_args()
-
-
-def test_dist_crossentropy():
-    pred = torch.randn(2, 4, 8, requires_grad=True)
-    print(pred)
-    labels = torch.randint(8, (2, 4))
-    labels[0, -1] = -100
-    print(labels)
-    pred_ = pred.view(-1, 8)
-    labels_ = labels.view(-1)
-    loss = F.cross_entropy(pred_, labels_)
-    loss.backward()
-    print(f"normal loss:{loss}")
-
-    pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
-    loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
-    loss.backward()
-    print(f"dist loss:{loss}")
-
-
-def test_dropout():
-    input = torch.randn(5, 4).to("cuda")
-    m = Dropout1D(p=0.2).to("cuda")
-    for i in range(2):
-        print(f"Output: {m(input)}")
-        print(torch.randn(1))
-
-
-if __name__ == '__main__':
-    args = get_args()
-    colossalai.launch_from_torch(config={})
-    if args.module == 'distloss':
-        test_dist_crossentropy()
-    elif args.module == 'dropout':
-        test_dropout()
-    else:
-        print("not implemented yet")

From 93cc9e2ad20a0548cec100550c93494c8062cc16 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Thu, 8 Jun 2023 09:54:03 +0800
Subject: [PATCH 09/11] remove unused file

---
 colossalai/shardformer/test/config.py | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 colossalai/shardformer/test/config.py

diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py
deleted file mode 100644
index 2b80d8b3ca12..000000000000
--- a/colossalai/shardformer/test/config.py
+++ /dev/null
@@ -1 +0,0 @@
-parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
From ccec046cd0154142a8ab679b5134d0e8890e567d Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Thu, 8 Jun 2023 14:18:10 +0800
Subject: [PATCH 10/11] rename the file

---
 .../{test_bert/test_bert.py => test_model/test_shard_bert.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/test_shardformer/{test_bert/test_bert.py => test_model/test_shard_bert.py} (100%)

diff --git a/tests/test_shardformer/test_bert/test_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
similarity index 100%
rename from tests/test_shardformer/test_bert/test_bert.py
rename to tests/test_shardformer/test_model/test_shard_bert.py

From 931c8f0a377e28422bde4624f3971a75f21b1d8a Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Fri, 9 Jun 2023 11:53:20 +0800
Subject: [PATCH 11/11] polish code

---
 colossalai/shardformer/README.md   | 4 ++--
 colossalai/shardformer/__init__.py | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 699b4b7da3ef..222626db3e9d 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -20,7 +20,7 @@ The sample API usage is given below:
 
 ``` python
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, shard_model
 from transformers import BertForMaskedLM
 
 # create huggingface model as normal
@@ -77,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py
 ``` python
 from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
 
-CustomPolicy(Policy):
+class CustomPolicy(Policy):
     @staticmethod
     def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
         r"""

diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py
index e69de29bb2d1..50c92738077a 100644
--- a/colossalai/shardformer/__init__.py
+++ b/colossalai/shardformer/__init__.py
@@ -0,0 +1 @@
+from .shard import ShardConfig, shard_model
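With the package-level re-export from patch 11, end-user code for the whole series condenses to something like the following minimal sketch (it assumes a `torchrun` launch so that `RANK`/`WORLD_SIZE` are set, mirroring how `align_bert.py` read them):

```python
import os

from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, shard_model

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shardconfig = ShardConfig(
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
    gather_output=True,    # gather the last layer's logits on every rank
)
sharded_model = shard_model(model, shardconfig).to("cuda")
```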