From 19b83404ff4692c47f2aee8d8fb546dc4ce97b8f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 7 Jul 2023 18:50:36 +0800 Subject: [PATCH 1/8] add pipeline forward --- colossalai/shardformer/policies/bert.py | 27 +++--- colossalai/shardformer/shard/sharder.py | 5 +- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/test_shardformer/test_model/_utils.py | 19 ++++ .../test_model/test_shard_bert_pipeline.py | 86 +++++++++++++++++++ 5 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_bert_pipeline.py diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index e18cb6ece674..cd2e784fa457 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -134,7 +134,6 @@ def module_policy(self): ], policy=policy, target_key=BertLayer) - # handle embedding layer self.append_or_create_submodule_replacement( description=[SubModuleReplacementDescription( @@ -143,6 +142,7 @@ def module_policy(self): )], policy=policy, target_key=BertEmbeddings) + return policy def add_lm_head_policy(self, base_policy): @@ -176,6 +176,13 @@ class BertModelPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertModel + module_policy[BertModel] = ModulePolicyDescription( + method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) + return module_policy + def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" module = self.model @@ -428,7 +435,8 @@ def bert_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False - + if stage_manager.stage == 1: + print('hidden_states', hidden_states) if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -440,6 +448,13 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -462,14 +477,6 @@ def bert_model_forward( if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ca2f46a187d1..a488fb5c7ca5 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,3 +1,4 @@ +from types import MethodType from typing import Any, Callable, Dict, List, Union import torch.nn as nn @@ -134,7 +135,9 @@ def _replace_param( def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): for method_name, new_method in method_replacement.items(): # bind the new method to the module - setattr(module, method_name, new_method.__get__(module, module.__class__)) + bound_method = MethodType(new_method, module) + original_method = getattr(module, method_name) + setattr(module, method_name, bound_method) def _replace_sub_module( self, diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e03014f3f234..a95460d920bf 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,5 +1,6 @@ import copy +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -16,6 +17,24 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle return org_model, sharded_model.cuda() +def build_pipeline_model(model_fn, + stage_manager=None, + enable_fused_normalization=False, + enable_tensor_parallelism=False): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager) + + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model, sharded_model.cuda() + + def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # prepare input data = data_gen_fn() diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py new file mode 100644 index 000000000000..4b605dae27a1 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -0,0 +1,86 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', 
[False]) +@parameterize('enable_tensor_parallelism', [False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_bert': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism) + + if stage_manager.stage == 2: + attention_mask = torch.ones_like(x) + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) + print('----------------------') + elif stage_manager.stage == 1: + print('----------------------') + attention_mask = torch.ones((2, 3)) + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('----------------------') + + torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 4) + + +if __name__ == "__main__": + test_bert() From 41b76e530951fcc1823873a0022ce50c3f941434 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 7 Jul 2023 22:35:33 +0800 Subject: [PATCH 2/8] complete pipeline forward check --- colossalai/shardformer/policies/bert.py | 3 +-- .../test_model/test_shard_bert_pipeline.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index cd2e784fa457..48d96858b200 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -435,8 +435,7 @@ def bert_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False - if stage_manager.stage == 1: - print('hidden_states', hidden_states) + if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 4b605dae27a1..d26a86d92345 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -43,28 +43,26 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + x = torch.randint(0, 1000, (2, 3)).cuda() + 
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if name == 'transformers_bert': org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism) - if stage_manager.stage == 2: - attention_mask = torch.ones_like(x) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) print(output['hidden_states'].shape) - assert output['hidden_states'].shape == (2, 3, 768) - print('----------------------') - elif stage_manager.stage == 1: - print('----------------------') - attention_mask = torch.ones((2, 3)) + assert output['hidden_states'].shape == (2, 3, 128) + print('end of the first stage') + else: + attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) print(output[0].shape) - assert output[0].shape == (2, 3, 768) - print('----------------------') + assert output[0].shape == (2, 3, 128) torch.cuda.empty_cache() From 5b10144bd53341b7950865245ad914b0714906f3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 10:44:23 +0800 Subject: [PATCH 3/8] fix bert forward without pipeline --- colossalai/shardformer/policies/bert.py | 136 +++++++++++++++++++++++- 1 file changed, 134 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 48d96858b200..cec200033556 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -13,6 +13,8 @@ CausalLMOutputWithCrossAttentions, ) from transformers.models.bert.modeling_bert import ( + BertForMaskedLM, + BertForNextSentencePrediction, BertForPreTraining, BertForPreTrainingOutput, BertLMHeadModel, @@ -179,8 +181,10 @@ def __init__(self) -> None: def module_policy(self): module_policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel - module_policy[BertModel] = ModulePolicyDescription( - method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) + if self.pipeline_stage_manager: + # set None as default + module_policy[BertModel] = ModulePolicyDescription( + method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) return module_policy def get_held_layers(self) -> List[Module]: @@ -780,3 +784,131 @@ def bert_lmhead_forward(self: BertLMHeadModel, hidden_states = outputs.get('hidden_states') # intermediate stage always return dict return {'hidden_states': hidden_states} + + +def bert_for_masked_lm_forward( + self: BertForMaskedLM, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +): + #-> Union[Tuple[torch.Tensor], 
MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + pass + + +def bert_for_next_sentence_prediction_forward( + self: BertForNextSentencePrediction, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + **kwargs, +): + #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} From 2eef95c58391f2199afe835730d2eef71c75c4df Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 10:51:49 +0800 Subject: [PATCH 4/8] fix comments --- tests/kit/model_zoo/torchrec/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * From 205f5f0fa60b83e58b41a26c91552e1a40f610bd Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 10:57:41 +0800 Subject: [PATCH 5/8] discard useless line --- colossalai/shardformer/shard/sharder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index a488fb5c7ca5..3e1ebe9687bb 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -136,7 +136,6 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla for method_name, new_method in method_replacement.items(): # bind the new method to the module bound_method = MethodType(new_method, module) - original_method = getattr(module, method_name) setattr(module, method_name, bound_method) def _replace_sub_module( From 76d75c748627432095de7a31d3ac07efb316e745 Mon Sep 17 
00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 11:19:41 +0800 Subject: [PATCH 6/8] add todo --- tests/test_shardformer/test_model/test_shard_bert_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index d26a86d92345..f0e23f756e99 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -24,6 +24,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [False]) @parameterize('enable_tensor_parallelism', [False]) +#TODO: merge this into test_shard_bert def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 From f0d361129635a2e8b684bb953b2d013fc6d4cb00 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 11:50:47 +0800 Subject: [PATCH 7/8] clean prints --- .../test_shardformer/test_model/test_shard_bert_pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index f7eb09775e9d..9cca5ec8bc51 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -55,15 +55,14 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output['hidden_states'].shape) + # print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 128) - print('end of the first stage') else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) + # print(output[0].shape) assert output[0].shape == (2, 3, 128) torch.cuda.empty_cache() From e768aeb3e7c45f9046558e59470b93ee17804919 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 13:42:59 +0800 Subject: [PATCH 8/8] fix distribute layers --- colossalai/shardformer/policies/base_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 65aee13861ee..aac86eb20a56 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -191,7 +191,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: # deal with the rest layers if remainder > 0: - start_position = num_layers // 2 - remainder // 2 + start_position = num_stages // 2 - remainder // 2 for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage
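A note on the method-replacement plumbing that PATCH 1/8 leans on: the policy registers `partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)` under `method_replacement`, and the sharder's `_replace_method` turns it into a bound method with `types.MethodType` so the function receives the module as `self`. Below is a minimal, self-contained sketch of that pattern; `TinyModel`, `staged_forward`, and the `stage_tag` keyword are hypothetical stand-ins for the real BERT pieces and the real `PipelineStageManager`.

```python
from functools import partial
from types import MethodType

import torch
import torch.nn as nn


def staged_forward(self, x, stage_tag=None):
    # `self` is whatever module the function is later bound to.
    return self.linear(x), stage_tag


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


model = TinyModel()
# Mirrors the policy's method_replacement plus the sharder's _replace_method:
# partial() pre-binds the extra argument, MethodType() binds `self`.
method_replacement = {'forward': partial(staged_forward, stage_tag='stage-0')}
for method_name, new_method in method_replacement.items():
    setattr(model, method_name, MethodType(new_method, model))

out, tag = model(torch.randn(2, 4))
assert out.shape == (2, 4) and tag == 'stage-0'
```

Because `nn.Module.__call__` resolves `self.forward` as an instance attribute, the instance-level `setattr` shadows the class method, which is why the sharder can swap forwards without subclassing.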
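The contract exercised by `test_shard_bert_pipeline.py` is that the first stage consumes `input_ids` and hands a `{'hidden_states': ...}` dict to the next stage, while later stages consume `hidden_states` and return an ordinary output tuple. A toy sketch of that contract follows; it is not the real `bert_model_forward`, which additionally threads attention masks, head masks, and per-stage layer ranges.

```python
import torch
import torch.nn as nn


class ToyStagedEncoder(nn.Module):
    """Two-stage toy encoder illustrating the pipeline I/O contract."""

    def __init__(self, vocab_size=1000, hidden=128, layers_per_stage=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.blocks = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(2 * layers_per_stage)])
        self.layers_per_stage = layers_per_stage

    def forward(self, input_ids=None, hidden_states=None, stage=0):
        if stage == 0:
            # first stage: embed the token ids, run the first half of blocks
            hidden_states = self.embed(input_ids)
            blocks = self.blocks[:self.layers_per_stage]
        else:
            # later stage: pick up the hidden states from the previous stage
            blocks = self.blocks[self.layers_per_stage:]
        for block in blocks:
            hidden_states = torch.tanh(block(hidden_states))
        if stage == 0:
            return {'hidden_states': hidden_states}    # handed to the next stage
        return (hidden_states,)


model = ToyStagedEncoder()
x = torch.randint(0, 1000, (2, 3))
mid = model(input_ids=x, stage=0)
assert mid['hidden_states'].shape == (2, 3, 128)
out = model(hidden_states=mid['hidden_states'], stage=1)
assert out[0].shape == (2, 3, 128)
```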
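The mesh bookkeeping hard-coded in the test (`RANK_TO_COORDINATE`, `PP_RANKS_IN_GROUP`) follows from the `(DP_SIZE, PP_SIZE)` shape alone, assuming the usual row-major rank layout; the helpers below are illustrative, not the `ProcessGroupMesh` API.

```python
DP_SIZE, PP_SIZE = 2, 2


def coordinate(rank):
    # row-major layout: data-parallel index first, pipeline index second
    return (rank // PP_SIZE, rank % PP_SIZE)


RANK_TO_COORDINATE = {r: coordinate(r) for r in range(DP_SIZE * PP_SIZE)}
assert RANK_TO_COORDINATE == {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}

# each pipeline group collects the ranks that share a data-parallel row
PP_RANKS_IN_GROUP = {
    r: [(r // PP_SIZE) * PP_SIZE + pp for pp in range(PP_SIZE)]
    for r in range(DP_SIZE * PP_SIZE)
}
assert PP_RANKS_IN_GROUP == {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}
```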
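Finally, PATCH 8/8 is easy to check by hand: `layers_per_stage` has length `num_stages`, so the start index for spreading the remainder must be derived from `num_stages`, not `num_layers`. A standalone sketch of the fixed function follows; the quotient/remainder setup is reconstructed from context, since only the remainder-handling lines appear in the diff.

```python
from typing import List


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    quotient = num_layers // num_stages
    remainder = num_layers % num_stages
    # every stage gets the base share first
    layers_per_stage = [quotient] * num_stages

    # deal with the remaining layers: center them among the stages
    if remainder > 0:
        start_position = num_stages // 2 - remainder // 2    # the fixed line
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage


assert distribute_layers(10, 4) == [2, 3, 3, 2]
```

With the old expression, the same call would compute `start_position = 10 // 2 - 2 // 2 = 4` and the loop would index past the end of the length-4 list; the fix both prevents that and keeps the extra layers centered across the stages.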