From 5d3b97c8147dcd686b9fac53abbc8eb6d0562eda Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 10:13:56 +0800 Subject: [PATCH 1/9] add forward for GPTLMHeadModel --- colossalai/shardformer/policies/gpt2.py | 138 +++++++++++++++++- .../test_model/test_shard_gpt2.py | 2 + .../test_model/test_shard_gpt2_pipeline.py | 45 +++--- 3 files changed, 155 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ffba27a50e72..6b3c98c9844c 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,7 +5,7 @@ import torch from torch import Tensor -from torch.nn import Module +from torch.nn import CrossEntropyLoss, Module import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager @@ -149,6 +149,7 @@ def module_policy(self): def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None module = self.model stage_manager = self.pipeline_stage_manager held_layers = [] @@ -163,8 +164,7 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - # TODO: check whether there is shared param in gpt2model - """No shared params in gpt2 model.""" + """No shared params in GPT2Model.""" return [] @@ -188,10 +188,52 @@ def module_policy(self): ]) } module_policy.update(addon_module) + + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.transformer.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=module_policy, + target_key=GPT2LMHeadModel) + return module_policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + module = self.model + sub_module = self.model.transformer + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(sub_module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(sub_module.wte) + held_layers.append(sub_module.wpe) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(sub_module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(sub_module.ln_f) + held_layers.append(module.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''The weights of wte and lm_head are shared.''' + module = self.model + stage_manager = self.pipeline_stage_manager + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism \ + and self.pipeline_stage_manager is None: binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -199,7 +241,7 @@ def postprocess(self): return self.model -# GPT22DoubleHeadsModel +# 
GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): def __init__(self) -> None: @@ -299,8 +341,7 @@ def gpt2_model_forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_length) else: - if hidden_states is None: - raise ValueError("hidden_states shouln't be None for stages other than the first stage.") + assert hidden_states is not None input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -462,3 +503,86 @@ def custom_forward(*inputs): else: # always return dict for intermediate stage return {'hidden_states': hidden_states} + + @staticmethod + def gpt2_lmhead_model_forward( + self: 'GPT2LMHeadModel', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. 
+ """ + + from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + torch.cuda.set_device(hidden_states.device) + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 552c6e2f4d53..cc41b158db15 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -67,6 +67,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" torch.cuda.empty_cache() + torch.cuda.empty_cache() + @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 5f92f638f863..18dc5cb51b51 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -21,8 +21,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo pass +@parameterize('enable_tensor_parallelism', [False, True]) @parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_gpt2 def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): @@ -33,29 +33,28 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_gpt": - continue - - inputs = data_gen_fn() 
- inputs = {k: v.cuda() for k, v in inputs.items()} - - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - org_model.train() - org_output = org_model(**inputs) - hidden_state_shape = org_output['last_hidden_state'].shape - - if stage_manager.is_first_stage(): - output = sharded_model(**inputs) - assert output['hidden_states'].shape == hidden_state_shape - else: - attention_mask = inputs['attention_mask'] - hidden_states = torch.zeros(*hidden_state_shape).cuda() - output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) - if stage_manager.is_last_stage(): - assert output['last_hidden_state'].shape == hidden_state_shape - else: + if name == 'transformers_gpt': + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + batch_size, seq_len = input_ids.shape + hidden_size = 768 + hidden_state_shape = (batch_size, seq_len, hidden_size) + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + + sharded_model.train() + if stage_manager.is_first_stage(): + output = sharded_model(**inputs) assert output['hidden_states'].shape == hidden_state_shape + else: + hidden_states = torch.zeros(*hidden_state_shape).cuda() + output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) + if stage_manager.is_last_stage(): + assert output['last_hidden_state'].shape == hidden_state_shape + else: + assert output['hidden_states'].shape == hidden_state_shape torch.cuda.empty_cache() From c65697b5bcce9ca4d30284731a34c60fefd863cf Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 11:40:23 +0800 Subject: [PATCH 2/9] add test for gpt_lm --- .../test_model/test_shard_gpt2_pipeline.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 18dc5cb51b51..9ba2726508ed 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -21,7 +21,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo pass -@parameterize('enable_tensor_parallelism', [False, True]) +@parameterize('enable_tensor_parallelism', [False]) @parameterize('enable_fused_normalization', [False]) @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_gpt2 @@ -33,7 +33,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_gpt': + if name in ['transformers_gpt', 'transformers_gpt_lm']: inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] @@ -48,13 +48,10 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz if stage_manager.is_first_stage(): output = sharded_model(**inputs) assert output['hidden_states'].shape == hidden_state_shape - else: + elif not stage_manager.is_last_stage(): hidden_states = torch.zeros(*hidden_state_shape).cuda() output = sharded_model(hidden_states=hidden_states, 
attention_mask=attention_mask) - if stage_manager.is_last_stage(): - assert output['last_hidden_state'].shape == hidden_state_shape - else: - assert output['hidden_states'].shape == hidden_state_shape + assert output['hidden_states'].shape == hidden_state_shape torch.cuda.empty_cache() From 455d21455576fadd807e8bbc8481e043c4c75aba Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 15:39:52 +0800 Subject: [PATCH 3/9] arranging get_held_layers method --- colossalai/shardformer/policies/gpt2.py | 54 +++++++++---------- .../test_model/test_shard_gpt2_pipeline.py | 2 - 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6b3c98c9844c..e88caebc993d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -120,6 +120,27 @@ def module_policy(self): def postprocess(self): return self.model + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + # GPT2Model class GPT2ModelPolicy(GPT2Policy): @@ -148,20 +169,7 @@ def module_policy(self): return policy def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - module = self.model - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.wpe) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - return held_layers + return super().get_held_layers() def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in GPT2Model.""" @@ -207,21 +215,9 @@ def module_policy(self): return module_policy def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - module = self.model - sub_module = self.model.transformer - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(sub_module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(sub_module.wte) - held_layers.append(sub_module.wpe) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(sub_module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(sub_module.ln_f) - held_layers.append(module.lm_head) + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: 
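
The `get_held_layers` refactor above slices `module.h` with `Policy.distribute_layers` and `Policy.get_stage_index`. As a rough illustration of that split, here is a minimal, simplified sketch assuming an even division with any leftover blocks assigned to the later stages; the actual ColossalAI helpers may allocate layers differently.

```python
# Illustrative stand-ins for Policy.distribute_layers / Policy.get_stage_index.
# Assumption: even split, remainder goes to the last stages (real helpers may differ).
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Return how many transformer blocks each pipeline stage holds."""
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Return the [start, end) slice of blocks owned by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


if __name__ == "__main__":
    # GPT-2 small has 12 blocks; with 4 stages each stage keeps h[start:end]
    per_stage = distribute_layers(12, 4)
    print(per_stage)                      # [3, 3, 3, 3]
    print(get_stage_index(per_stage, 2))  # (6, 9)
```

Each stage then holds only `h[start_idx:end_idx]` plus the embeddings on the first stage and `ln_f` (and any task head) on the last stage.
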
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 9ba2726508ed..f4c099452b73 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -40,10 +40,8 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz batch_size, seq_len = input_ids.shape hidden_size = 768 hidden_state_shape = (batch_size, seq_len, hidden_size) - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - sharded_model.train() if stage_manager.is_first_stage(): output = sharded_model(**inputs) From 8817726ff7fc80109b4201ce829df67300f8d73e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 16:08:51 +0800 Subject: [PATCH 4/9] arrange forward replacement --- colossalai/shardformer/policies/gpt2.py | 54 +++++++++++-------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index e88caebc993d..9806d1da8457 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,7 +1,7 @@ import logging from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor @@ -141,6 +141,23 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.ln_f) return held_layers + def set_pipeline_forward(self, model_cls: Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + # GPT2Model class GPT2ModelPolicy(GPT2Policy): @@ -152,20 +169,9 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Model policy = super().module_policy() - if self.pipeline_stage_manager: - # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(GPT2PipelineForwards.gpt2_model_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=GPT2Model) + self.set_pipeline_forward(model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: @@ -197,21 +203,9 @@ def module_policy(self): } module_policy.update(addon_module) - if self.pipeline_stage_manager: - # set None as default - 
stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.transformer.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=module_policy, - target_key=GPT2LMHeadModel) - + self.set_pipeline_forward(model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy) return module_policy def get_held_layers(self) -> List[Module]: From 83e5e852fe0d4542e1a025f72035bf2997684391 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 17:08:05 +0800 Subject: [PATCH 5/9] add forward for GPT2ForTokenClassification --- colossalai/shardformer/policies/gpt2.py | 106 +++++++++++++++++- .../test_model/test_shard_gpt2_pipeline.py | 21 +++- 2 files changed, 119 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 9806d1da8457..439e46514ec6 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -135,6 +135,7 @@ def get_held_layers(self) -> List[Module]: if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.wpe) + held_layers.append(module.drop) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): @@ -268,6 +269,37 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2ForTokenClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + # GPT2ForSequenceClassification class GPT2ForSequenceClassificationPolicy(GPT2Policy): @@ -550,7 +582,6 @@ def gpt2_lmhead_model_forward( if not stage_manager.is_last_stage(): return {'hidden_states': outputs['hidden_states']} - torch.cuda.set_device(hidden_states.device) hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) loss = None @@ -563,7 +594,6 @@ def gpt2_lmhead_model_forward( # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -576,3 +606,75 @@ def gpt2_lmhead_model_forward( 
attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) + + @staticmethod + def gpt2_for_token_classification_forward( + self: 'GPT2ForTokenClassification', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. + # Please refer to original code of transformers for more details. + """ + + from transformers.modeling_outputs import TokenClassifierOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index f4c099452b73..4fe946e5e6d4 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -33,23 +33,32 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name in ['transformers_gpt', 'transformers_gpt_lm']: + if name in ['transformers_gpt', 
'transformers_gpt_lm', 'transformers_gpt_for_token_classification']: inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] batch_size, seq_len = input_ids.shape hidden_size = 768 hidden_state_shape = (batch_size, seq_len, hidden_size) + + if not stage_manager.is_first_stage(): + # change inputs if not the first stage + hidden_states = torch.zeros(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) sharded_model.train() + output = sharded_model(**inputs) if stage_manager.is_first_stage(): - output = sharded_model(**inputs) - assert output['hidden_states'].shape == hidden_state_shape - elif not stage_manager.is_last_stage(): - hidden_states = torch.zeros(*hidden_state_shape).cuda() - output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) assert output['hidden_states'].shape == hidden_state_shape + else: + if stage_manager.is_last_stage(): + if name != 'transformers_gpt': + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape torch.cuda.empty_cache() From 5ebb069725f818731811e39df8ee9bc9ca4fe8c6 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 17:39:02 +0800 Subject: [PATCH 6/9] add forward for GPT2ForSequenceClassification --- colossalai/shardformer/policies/gpt2.py | 153 +++++++++++++++++- .../test_model/test_shard_gpt2_pipeline.py | 5 +- 2 files changed, 150 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 439e46514ec6..42f589904cff 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -4,8 +4,8 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager @@ -120,7 +120,7 @@ def module_policy(self): def postprocess(self): return self.model - def get_held_layers(self) -> List[Module]: + def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None @@ -142,7 +142,7 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.ln_f) return held_layers - def set_pipeline_forward(self, model_cls: Module, new_forward: Callable, policy: Dict) -> None: + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: @@ -175,7 +175,7 @@ def module_policy(self): policy=policy) return policy - def get_held_layers(self) -> List[Module]: + def get_held_layers(self) -> List[nn.Module]: return super().get_held_layers() def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -209,7 +209,7 @@ def module_policy(self): policy=module_policy) return module_policy - def get_held_layers(self) -> List[Module]: + def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() if 
self.pipeline_stage_manager.is_last_stage(): held_layers.append(self.model.lm_head) @@ -289,7 +289,7 @@ def module_policy(self): policy=module_policy) return module_policy - def get_held_layers(self) -> List[Module]: + def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() if self.pipeline_stage_manager.is_last_stage(): held_layers.append(self.model.dropout) @@ -307,6 +307,36 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2ForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + class GPT2PipelineForwards: ''' @@ -654,6 +684,7 @@ def gpt2_for_token_classification_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index) + # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): return {'hidden_states': outputs['hidden_states']} @@ -678,3 +709,111 @@ def gpt2_for_token_classification_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + @staticmethod + def gpt2_for_sequence_classification_forward( + self: 'GPT2ForSequenceClassification', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. + # Please refer to original code of transformers for more details. 
+ """ + from transformers.modeling_outputs import SequenceClassifierOutputWithPast + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert (self.config.pad_token_id is not None + or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logging.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 4fe946e5e6d4..c281e38faf71 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -33,7 +33,10 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name in ['transformers_gpt', 'transformers_gpt_lm', 'transformers_gpt_for_token_classification']: + 
if name in [ + 'transformers_gpt', 'transformers_gpt_lm', 'transformers_gpt_for_token_classification', + 'transformers_gpt_for_sequence_classification' + ]: inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] From 6ce317da745fc4c2958c253b820b92ebca4b9367 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 14 Jul 2023 17:47:40 +0800 Subject: [PATCH 7/9] fix test_shard_gpt2.py --- tests/test_shardformer/test_model/test_shard_gpt2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index cc41b158db15..552c6e2f4d53 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -67,8 +67,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" torch.cuda.empty_cache() - torch.cuda.empty_cache() - @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) From 4a86da146c9e71b932b7414e3be5ec6a9c206087 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 17 Jul 2023 11:44:38 +0800 Subject: [PATCH 8/9] add GPT2DoubleHeadsmodel & fix bugs --- colossalai/shardformer/policies/gpt2.py | 142 ++++++++++++++++-- .../test_model/test_shard_gpt2_pipeline.py | 52 +++---- 2 files changed, 150 insertions(+), 44 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 42f589904cff..c1d72dd8aefa 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -48,6 +48,10 @@ def module_policy(self): suffix="wte", target_module=col_nn.VocabParallelEmbedding1D, ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), ]) policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -217,8 +221,10 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: '''The weights of wte and lm_head are shared.''' - module = self.model stage_manager = self.pipeline_stage_manager + if stage_manager is None: + return [] + module = self.model first_stage, last_stage = 0, stage_manager.num_stages - 1 return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] @@ -252,10 +258,37 @@ def module_policy(self): ]) } module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy) + return module_policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''The weights of wte and lm_head are shared.''' + stage_manager = self.pipeline_stage_manager + if stage_manager is 
None: + return [] + module = self.model + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism \ + and self.pipeline_stage_manager is None: binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -278,8 +311,7 @@ def module_policy(self): addon_module = { GPT2ForTokenClassification: ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) ]) } module_policy.update(addon_module) @@ -311,17 +343,6 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2ForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) - self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, policy=module_policy) @@ -637,6 +658,97 @@ def gpt2_lmhead_model_forward( cross_attentions=outputs.cross_attentions, ) + @staticmethod + def gpt2_double_heads_model_forward( + self: 'GPT2DoubleHeadsModel', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. 
(see *input_ids* above) + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. + Please refer to original code of transformers for more details. + ```""" + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + @staticmethod def gpt2_for_token_classification_forward( self: 'GPT2ForTokenClassification', diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index c281e38faf71..91aea433e114 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -33,35 +33,29 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name in [ - 'transformers_gpt', 'transformers_gpt_lm', 'transformers_gpt_for_token_classification', - 'transformers_gpt_for_sequence_classification' - ]: - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] - batch_size, seq_len = input_ids.shape - hidden_size = 768 - hidden_state_shape = (batch_size, seq_len, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - hidden_states = torch.zeros(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - 
sharded_model.train() - output = sharded_model(**inputs) - if stage_manager.is_first_stage(): - assert output['hidden_states'].shape == hidden_state_shape - else: - if stage_manager.is_last_stage(): - if name != 'transformers_gpt': - assert output.loss is not None - else: - assert output['hidden_states'].shape == hidden_state_shape + + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + batch_size, seq_len = input_ids.shape + hidden_size = 768 + hidden_state_shape = (batch_size, seq_len, hidden_size) + + if not stage_manager.is_first_stage(): + # change inputs if not the first stage + hidden_states = torch.zeros(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + sharded_model.train() + output = sharded_model(**inputs) + if stage_manager.is_last_stage(): + if name != 'transformers_gpt': + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape torch.cuda.empty_cache() From e153172ad1e7fc4c10bf9234be26d542a2d30619 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 17 Jul 2023 15:08:48 +0800 Subject: [PATCH 9/9] add id checking in get_shared_params --- colossalai/shardformer/policies/gpt2.py | 18 ++++++++++-------- .../test_model/test_shard_gpt2_pipeline.py | 14 ++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index c1d72dd8aefa..5d6f47636587 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -221,12 +221,13 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: '''The weights of wte and lm_head are shared.''' + module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager is None: + if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + else: return [] - module = self.model - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] def postprocess(self): if self.shard_config.enable_tensor_parallelism \ @@ -279,12 +280,13 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: '''The weights of wte and lm_head are shared.''' + module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager is None: + if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + else: return [] - module = self.model - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] def postprocess(self): if self.shard_config.enable_tensor_parallelism \ diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 91aea433e114..dd439a394827 100644 --- 
a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -5,15 +5,9 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward +from tests.test_shardformer.test_model._utils import build_pipeline_model def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -32,11 +26,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + input_ids, _ = inputs['input_ids'], inputs['attention_mask'] batch_size, seq_len = input_ids.shape hidden_size = 768 hidden_state_shape = (batch_size, seq_len, hidden_size)
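
The final patch guards `get_shared_params()` with an `id()` check so that `wte` and `lm_head` only count as a shared parameter when they are literally the same tensor. A minimal sketch of why that check passes for a default GPT-2 (assuming the standard tied-embedding configuration; with `tie_word_embeddings=False` the method would instead return an empty list):

```python
# Sketch: GPT2LMHeadModel ties lm_head to wte by default, so the id()/is
# comparison used in get_shared_params() sees one shared tensor spanning
# the first and last pipeline stages. Config values here are illustrative.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=128)
model = GPT2LMHeadModel(config)

print(model.lm_head.weight is model.transformer.wte.weight)  # True: weights are tied
```

Because of this tying, `postprocess()` only re-binds `transformer.wte.weight` to `lm_head.weight` in the tensor-parallel-only case, while the pipeline path reports the pair through `get_shared_params()` for gradient synchronization between the first and last stages.
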