From 7e81ba28118acd2059cddd0248030c240b9db8b2 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 5 Sep 2023 18:09:35 +0800 Subject: [PATCH 1/9] [shardformer] GPT-J policy dev 5 Sep --- colossalai/shardformer/policies/gptj.py | 122 ++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 colossalai/shardformer/policies/gptj.py diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py new file mode 100644 index 000000000000..06d720f29327 --- /dev/null +++ b/colossalai/shardformer/policies/gptj.py @@ -0,0 +1,122 @@ +import colossalai.shardformer.layer as col_nn + +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [] + + +class GPTJPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJBlock, GPTJModel + + policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[GPTJModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPTJBlock] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.k_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.out_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_in", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_out", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + +""" +# GPTJModel +class GPTJModelPolicy(GPTJPolicy): + +# GPTJForCausalLM +class GPTJForCausalLMPolicy(GPTJPolicy): + +# GPTJForSequenceClassification +class 
GPTJForSequenceClassificationPolicy(GPTJPolicy): + +# GPTJForQuestionAnswering +class GPTJForQuestionAnsweringPolicy(GPTJPolicy): + +# TFGPTJForQuestionAnswering +class TFGPTJPolicy(GPTJPolicy): + +# TFGPTJForCausalLM +class TFGPTJCausalLMPolicy(GPTJPolicy): + +# TFGPTJForSequenceClassification +class TFGPTJForSequenceClassificationPolicy(GPTJPolicy): + +# TFGPTJForQuestionAnswering +class TFGPTJForQuestionAnsweringPolicy(GPTJPolicy): + +# FlaxGPTJModel +class FlaxGPTJPolicy(GPTJPolicy): + +# FlaxGPTJForCausalLMModel +class FlaxGPTJForCausalLMPolicy(GPTJPolicy): +""" From 0204ef3b726a588a114f8d0fe9fbfbbfd6e3d0a5 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Wed, 6 Sep 2023 17:50:46 +0800 Subject: [PATCH 2/9] [shardformer] implement shard policy for base gpt-j model --- colossalai/shardformer/modeling/gptj.py | 509 ++++++++++++++++++++++++ 1 file changed, 509 insertions(+) create mode 100644 colossalai/shardformer/modeling/gptj.py diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py new file mode 100644 index 000000000000..e344c6d549fa --- /dev/null +++ b/colossalai/shardformer/modeling/gptj.py @@ -0,0 +1,509 @@ +class GPTJPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of GPTJ models + under pipeline setting. + ''' + + @staticmethod + def gptj_model_forward( + self: GPTJModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward. + # Please refer to original code of transformers for more details. + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
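+        # Pipeline stages only exchange `hidden_states` here, so kv-caches and
+        # per-layer outputs cannot be collected across stages yet; each unsupported
+        # flag below is warned about and forced off instead of raising an error.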
+ if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + input_shape = input_ids.size() + input_ids = input_ids.view(-1, seq_length) + + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
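+            # e.g. a padding mask [1, 1, 0] becomes the additive bias
+            # [0.0, 0.0, torch.finfo(dtype).min] once the two lines below run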
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=None, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for intermediate stage + return {'hidden_states': hidden_states} + + + + +def get_gptj_flash_attention_forward(): + + from transformers.models.gptj.modeling_gptj import GPTJAttention + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + if rotary or len(tensor.shape) in [4,5]: + return tensor + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def forward( + self: GPTJAttention, + hidden_states: torch.FloatTensor, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self.split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self.split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self.split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + # use AttnMaskType and ColoAttention + if not self.is_cross_attention: + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None + if attention_mask != None: + if 
attn_mask_type == AttnMaskType.causal:
+                attn_mask_type = AttnMaskType.paddedcausal
+            else:
+                attn_mask_type = AttnMaskType.padding
+            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+
+        # use coloattention
+        attention = ColoAttention(embed_dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  dropout=self.attn_dropout.p,
+                                  scale=scale)
+
+        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+        outputs = (attn_output, present, None)
+
+        return outputs  # a, present, (attentions)
+
+def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1]).long()
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
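+            # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len]: the two singleton
+            # dims broadcast over heads and query positions without copying the mask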
+ attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + 
str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward + \ No newline at end of file From 472faa587d6669062b52e09e55746ae3d6e164fa Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Wed, 6 Sep 2023 17:51:34 +0800 Subject: [PATCH 3/9] [shardformer] implement shard policy for base gpt-j model 06 Sep --- colossalai/shardformer/policies/gptj.py | 122 ++++++++++++++++++++---- 1 file changed, 105 insertions(+), 17 deletions(-) diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 06d720f29327..771b2eeabfae 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -2,7 +2,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = [] +__all__ = [ + "GPTJPolicy", + "GPTJModelPolicy", + "GPTJForCausalLMPolicy", + "GPTJForSequenceClassificationPolicy", + "GPTJForQuestionAnsweringPolicy", + "FlaxGPTJPolicy", + "FlaxGPTJForCausalLMPolicy", +] class GPTJPolicy(Policy): @@ -23,7 +31,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gptj.modeling_gptj import GPTJBlock, GPTJModel + from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism @@ -45,12 +53,23 @@ def module_policy(self): policy[GPTJBlock] = ModulePolicyDescription( attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attn.rotary_dim": self.model.config.rotary_dim // self.shard_config.tensor_parallel_size, + "attn.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, }, sub_module_replacement=[ SubModuleReplacementDescription( - suffix="attn.k_attn", + suffix="attn.k_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, ), @@ -88,6 +107,87 @@ def module_policy(self): ], ) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + 
description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPTJModel, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=GPTJBlock, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_gptj_flash_attention_forward(), + }, + policy=policy, + target_key=GPTJAttention, + ) + + if self.shard_config.enable_sequence_parallelism: + policy[GPTJModel].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + """ # GPTJModel @@ -102,18 +202,6 @@ class GPTJForSequenceClassificationPolicy(GPTJPolicy): # GPTJForQuestionAnswering class GPTJForQuestionAnsweringPolicy(GPTJPolicy): -# TFGPTJForQuestionAnswering -class TFGPTJPolicy(GPTJPolicy): - -# TFGPTJForCausalLM -class TFGPTJCausalLMPolicy(GPTJPolicy): - -# TFGPTJForSequenceClassification -class TFGPTJForSequenceClassificationPolicy(GPTJPolicy): - -# TFGPTJForQuestionAnswering -class TFGPTJForQuestionAnsweringPolicy(GPTJPolicy): - # FlaxGPTJModel class FlaxGPTJPolicy(GPTJPolicy): From a879ccbabf9635ab65255ba7ffec2712dfe6de76 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Thu, 7 Sep 2023 18:05:37 +0800 Subject: [PATCH 4/9] [shardformer] implement policy for all GPT-J models and test --- colossalai/shardformer/modeling/gptj.py | 470 +++++++++++++++--- colossalai/shardformer/policies/gptj.py | 127 ++++- tests/kit/model_zoo/transformers/gptj.py | 110 ++++ .../test_model/test_shard_gptj.py | 235 +++++++++ 4 files changed, 
854 insertions(+), 88 deletions(-) create mode 100644 tests/kit/model_zoo/transformers/gptj.py create mode 100644 tests/test_shardformer/test_model/test_shard_gptj.py diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index e344c6d549fa..0ba854e2ec51 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -1,8 +1,33 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.gptj.modeling_gptj import ( + GPTJForCausalLM, + GPTJForQuestionAnswering, + GPTJForSequenceClassification, + GPTJModel, + apply_rotary_pos_emb, + get_embed_positions, +) +from transformers.utils import is_torch_fx_proxy, logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + + class GPTJPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPTJ models under pipeline setting. - ''' + """ @staticmethod def gptj_model_forward( @@ -19,13 +44,14 @@ def gptj_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPast]: # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward. # Please refer to original code of transformers for more details. - + # GPTJ has no cross attention in comparison to GPT2 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict logger = logging.get_logger(__name__) @@ -33,42 +59,42 @@ def gptj_model_forward( # Preprocess passed in arguments # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
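         # Pipeline stages only exchange `hidden_states` here, so kv-caches and
         # per-layer outputs cannot be collected across stages yet; each unsupported
         # flag below is warned about and forced off instead of raising an error.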
if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - input_shape = input_ids.size() - input_ids = input_ids.view(-1, seq_length) - - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_length) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + input_shape = input_ids.size() + input_ids = input_ids.view(-1, seq_length) + + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device # Attention mask. 
if attention_mask is not None: @@ -102,7 +128,7 @@ def gptj_model_forward( else: position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - + if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -115,7 +141,7 @@ def gptj_model_forward( hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -126,20 +152,20 @@ def gptj_model_forward( presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): block = self.h[i] torch.cuda.set_device(hidden_states.device) - + # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -185,9 +211,9 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -198,39 +224,328 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) else: # always return dict for intermediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} + + @staticmethod + def gptj_causallm_model_forward( + self: GPTJForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForCausalLM.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + transformer_outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": transformer_outputs["hidden_states"]} + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) -def get_gptj_flash_attention_forward(): + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def gptj_for_sequence_classification_forward( + self: GPTJForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + r""" + labels 
(`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + logger = logging.get_logger(__name__) + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": transformer_outputs["hidden_states"]} + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(pooled_logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def gptj_for_question_answering_forward( + self: GPTJForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. 
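+        # Note: start/end positions that fall outside the model inputs are clamped
+        # to the sequence length below and excluded from the loss via
+        # CrossEntropyLoss(ignore_index=ignored_index).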
+ + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - def split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + def split_heads(tensor, num_attention_heads, attn_head_size, rotary): """ Splits hidden dim into attn_head_size and num_attention_heads """ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) tensor = tensor.view(new_shape) - if rotary or len(tensor.shape) in [4,5]: + if rotary or len(tensor.shape) in [4, 5]: return tensor else: raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - + def forward( self: GPTJAttention, hidden_states: torch.FloatTensor, @@ -248,9 +563,9 @@ def forward( key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - query = self.split_heads(query, self.num_attention_heads, self.head_dim, True) - key = self.split_heads(key, self.num_attention_heads, self.head_dim, True) - value = self.split_heads(value, self.num_attention_heads, self.head_dim, False) + query = split_heads(query, self.num_attention_heads, self.head_dim, True) + key = split_heads(key, self.num_attention_heads, self.head_dim, True) + value = split_heads(value, self.num_attention_heads, self.head_dim, False) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not 
be traced, so we do this @@ -305,21 +620,21 @@ def forward( flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() # use coloattention - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.attn_dropout.p, - scale=scale) + scale = value.size(-1) ** -0.5 + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - + attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) return outputs # a, present, (attentions) - -def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): + +def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -420,13 +735,13 @@ def forward( presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) - + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel if self.model_parallel: @@ -482,11 +797,11 @@ def custom_forward(*inputs): for k, v in self.device_map.items(): if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) - + # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) hidden_states = self.ln_f(hidden_states) @@ -506,4 +821,3 @@ def custom_forward(*inputs): ) return forward - \ No newline at end of file diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 771b2eeabfae..4b371ca4f1ae 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -1,5 +1,6 @@ import colossalai.shardformer.layer as col_nn +from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward, gptj_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -139,7 +140,7 @@ def module_policy(self): ) if self.shard_config.enable_sequence_parallelism: - policy[GPTJModel].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + policy[GPTJModel].method_replacement = {"forward": gptj_sequence_parallel_forward_fn(self.shard_config)} return policy @@ -150,7 +151,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == "GPT2Model": + if self.model.__class__.__name__ == "GPTJModel": module = self.model else: module = self.model.transformer @@ -160,7 +161,7 @@ def get_held_layers(self) -> 
List[nn.Module]: layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.wte) - held_layers.append(module.wpe) + # held_layers.append(module.wpe) held_layers.append(module.drop) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.h[start_idx:end_idx]) @@ -174,7 +175,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "GPT2Model": + if self.model.__class__.__name__ == "GPTJModel": module = self.model else: module = self.model.transformer @@ -189,22 +190,128 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) -""" # GPTJModel class GPTJModelPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJModel + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2Model.""" + return [] + # GPTJForCausalLM class GPTJForCausalLMPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPTJForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) + } + policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + # GPTJForSequenceClassification class GPTJForSequenceClassificationPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForCausalLM, + 
new_forward=GPTJPipelineForwards.gptj_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPTJForSequenceClassification.""" + return [] + # GPTJForQuestionAnswering class GPTJForQuestionAnsweringPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForQuestionAnswering + + policy = super().module_policy() -# FlaxGPTJModel -class FlaxGPTJPolicy(GPTJPolicy): + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForQuestionAnswering, + new_forward=GPTJPipelineForwards.gptj_for_question_answering_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers -# FlaxGPTJForCausalLMModel -class FlaxGPTJForCausalLMPolicy(GPTJPolicy): -""" + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForQuestionAnswering.""" + return [] diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py new file mode 100644 index 000000000000..f91cfeda507b --- /dev/null +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -0,0 +1,110 @@ +import copy + +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence GPT +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoTokenizer + # input = 'Hello, my dog is cute' + # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_question_answering(): + # question answering data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data["start_positions"] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data["end_positions"] = end_positions + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_gptj_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss + +config = transformers.GPTJConfig( + n_layer=2, + n_head=4, + vocab_size=50258, + 
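+    # all dropout probabilities below are zeroed so that the sharded and unsharded
+    # models remain numerically comparable in the equivalence tests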
attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256, +) + +config_for_token_classification = copy.deepcopy(config) +config_for_token_classification.num_labels = 2 + +# register the following models +model_zoo.register( + name="transformers_gptj", + model_fn=lambda: transformers.GPTJModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gptj_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_lm", + model_fn=lambda: transformers.GPTJForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_for_question_answering", + model_fn=lambda: transformers.GPTJForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_for_sequence_classification", + model_fn=lambda: transformers.GPTJForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py new file mode 100644 index 000000000000..34c2de66392c --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -0,0 +1,235 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + gptj = unwrap_model(org_model, "GPTJModel", "transformer") + sharded_gptj = unwrap_model(sharded_model, "GPTJModel", "transformer") + + col_layer_for_check = ["h[0].mlp.fc_in"] + row_layer_for_check = ["wte", "h[0].mlp.fc_out"] + + # Save gradient tensors for comparison between the original model and the sharded model. 
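+    # Tensor-parallel layers hold only a shard of each weight, so the helpers below
+    # gather gradients along each layer's shard dimension (the `dim` argument) before comparing.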
+ grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + col_layer_grads = get_grad_tensors_for_check( + gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "GPTJModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "enable_sequence_parallelism": True, + "precision": "fp32", + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "enable_sequence_parallelism": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gptj_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + 
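+            # same 2-way TP x 2-way PP grid as the fp32 config above, now in fp16 with ZeRO stage 1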
"precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gptj_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_gptj(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gptj_test() + + +def check_gptj_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gptj_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gptj(): + spawn(check_gptj, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gptj_3d(): + spawn(check_gptj_3d, 8) + + +if __name__ == "__main__": + test_gptj() + test_gptj_3d() From d0e6939a02599a5ef9fd4d93a815cb274a052c64 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Fri, 8 Sep 2023 18:25:08 +0800 Subject: [PATCH 5/9] [shardformer] test GPT-J sharding policy --- colossalai/shardformer/modeling/gptj.py | 16 ++++++++++------ .../shardformer/policies/auto_policy.py | 11 +++++++++++ .../test_model/test_shard_gptj.py | 19 ++++++++++--------- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 0ba854e2ec51..79e5616fd7b2 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -594,8 +594,10 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) + # key = key.permute(0, 2, 1, 3) + # query = query.permute(0, 2, 1, 3) + key = key.to(dtype=value.dtype) # fp16 compatability + query = query.to(dtype=value.dtype) if layer_past is not None: past_key = layer_past[0] @@ -609,9 +611,8 @@ def forward( present = None # use AttnMaskType and ColoAttention - if not self.is_cross_attention: - attn_mask_type = AttnMaskType.causal - flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None if attention_mask != None: if attn_mask_type == AttnMaskType.causal: attn_mask_type == AttnMaskType.paddedcausal @@ -621,8 +622,9 @@ def forward( # use coloattention scale = value.size(-1) ** -0.5 + attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale ) attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) @@ -633,6 +635,8 @@ def forward( return outputs # a, present, (attentions) + return forward + def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index f3587de15f86..c5a4586d323e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -86,6 +86,17 @@ class PolicyLocation: 
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" ), + # GPTJ + "transformers.models.gptj.modeling_gptj.GPTJModel": PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"), + "transformers.models.gptj.modeling_gptj.GPTJForCausalLM": PolicyLocation( + file_name="gptj", class_name="GPTJForCausalLMPolicy" + ), + "transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering": PolicyLocation( + file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy" + ), + "transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification": PolicyLocation( + file_name="gptj", class_name="GPTJForSequenceClassificationPolicy" + ), # ViT "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py index 34c2de66392c..2933a5d5f12c 100644 --- a/tests/test_shardformer/test_model/test_shard_gptj.py +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -89,13 +89,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "test_config", [ { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, + "tp_size": 1, + "pp_size": 1, + #'num_microbatches': 4, + "enable_all_optimization": False, + #'use_lazy_init': False, + "precision": "fp32", + #'initial_scale': 1, }, { "tp_size": 1, @@ -161,7 +161,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @clear_cache_before_run() def run_gptj_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - + print("===test config===") + print(test_config) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -220,7 +221,7 @@ def check_gptj_3d(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gptj(): - spawn(check_gptj, 4) + spawn(check_gptj, 2) @pytest.mark.largedist From 1f237cccf420383bf1a796b36314f64ddd4f7b8a Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Sat, 9 Sep 2023 17:48:48 +0800 Subject: [PATCH 6/9] [shardformer] finished testing pp gpt-j sharding --- colossalai/shardformer/modeling/gptj.py | 41 ++++---- colossalai/shardformer/policies/gptj.py | 118 ++++++++++++------------ 2 files changed, 78 insertions(+), 81 deletions(-) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 79e5616fd7b2..ad51bf2c709b 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -89,12 +89,12 @@ def gptj_model_forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_length) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], 
input_shape[1]
+        device = hidden_states.device

     # Attention mask.
     if attention_mask is not None:
@@ -122,23 +122,20 @@
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

+        # position ids need to be assigned for every stage (not just the first one),
+        # since the attention blocks consume them
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, seq_length)
+        else:
+            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         if stage_manager.is_first_stage():
-            if position_ids is not None:
-                position_ids = position_ids.view(-1, seq_length)
-            else:
-                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
-
-            if inputs_embeds is None:
-                inputs_embeds = self.wte(input_ids)
-
-            hidden_states = inputs_embeds
-
-            if token_type_ids is not None:
-                token_type_embeds = self.wte(token_type_ids)
-                hidden_states = hidden_states + token_type_embeds
-
-            hidden_states = self.drop(hidden_states)
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
+            hidden_states = inputs_embeds
+            if token_type_ids is not None:
+                token_type_embeds = self.wte(token_type_ids)
+                hidden_states = hidden_states + token_type_embeds
+            hidden_states = self.drop(hidden_states)

         output_shape = input_shape + (hidden_states.size(-1),)
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 4b371ca4f1ae..b0a8ff5e5d8b 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -51,62 +51,62 @@ def module_policy(self):
                 ]
             )

-        policy[GPTJBlock] = ModulePolicyDescription(
-            attribute_replacement={
-                "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "attn.rotary_dim": self.model.config.rotary_dim // self.shard_config.tensor_parallel_size,
-                "attn.num_attention_heads": self.model.config.num_attention_heads
-                // self.shard_config.tensor_parallel_size,
-            },
-            sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="attn.k_proj",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                    kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attn.q_proj",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                    kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attn.v_proj",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                    kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attn.out_proj",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    kwargs={
-                        "seq_parallel": use_sequence_parallel,
-                    },
-                ),
-                SubModuleReplacementDescription(
-                    suffix="mlp.fc_in",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                    kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
-                ),
-                SubModuleReplacementDescription(
-                    suffix="mlp.fc_out",
-                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    kwargs={
-                        "seq_parallel": use_sequence_parallel,
-                    },
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attn.attn_dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="attn.resid_dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
-                SubModuleReplacementDescription(
suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ], - ) + policy[GPTJBlock] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.rotary_dim": self.model.config.rotary_dim // self.shard_config.tensor_parallel_size, + "attn.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.k_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.out_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_in", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_out", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -265,13 +265,13 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + from transformers.models.gptj.modeling_gptj import GPTJForSequenceClassification policy = super().module_policy() if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPTJForCausalLM, + model_cls=GPTJForSequenceClassification, new_forward=GPTJPipelineForwards.gptj_for_sequence_classification_forward, policy=policy, ) @@ -309,7 +309,7 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: From c013a0d1c9496177aefa877c16ebf8d157cc4970 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Mon, 11 Sep 2023 11:09:35 +0800 Subject: [PATCH 7/9] [shardformer] clean up for pr --- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/test_shardformer/test_model/test_shard_gptj.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 2a492361b13b..aa5044cc1b94 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -4,6 +4,7 @@ from .bloom import * from 
.chatglm2 import * from .gpt import * +from .gptj import * from .llama import * from .opt import * from .sam import * diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py index 2933a5d5f12c..ce72212d5803 100644 --- a/tests/test_shardformer/test_model/test_shard_gptj.py +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -90,8 +90,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, [ { "tp_size": 1, - "pp_size": 1, - #'num_microbatches': 4, + "pp_size": 2, + "num_microbatches": 1, "enable_all_optimization": False, #'use_lazy_init': False, "precision": "fp32", @@ -160,9 +160,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) @clear_cache_before_run() def run_gptj_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - print("===test config===") print(test_config) + + sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -232,5 +233,6 @@ def test_gptj_3d(): if __name__ == "__main__": + print("===hello===") test_gptj() test_gptj_3d() From ec9cba2dacdeac26b4e49d38f341c628690213ef Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Mon, 11 Sep 2023 11:09:35 +0800 Subject: [PATCH 8/9] [shardformer] support all GPT-J shard former techniques except lazy init --- colossalai/shardformer/policies/gptj.py | 34 +++++++------- tests/kit/model_zoo/registry.py | 8 +++- tests/kit/model_zoo/transformers/gpt.py | 2 +- tests/kit/model_zoo/transformers/gptj.py | 7 ++- .../test_model/test_shard_gptj.py | 45 +++++++------------ 5 files changed, 44 insertions(+), 52 deletions(-) diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b0a8ff5e5d8b..343df4e09777 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -1,3 +1,8 @@ +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + import colossalai.shardformer.layer as col_nn from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward, gptj_sequence_parallel_forward_fn @@ -54,44 +59,39 @@ def module_policy(self): policy[GPTJBlock] = ModulePolicyDescription( attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.rotary_dim": self.model.config.rotary_dim // self.shard_config.tensor_parallel_size, "attn.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.k_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attn.q_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attn.v_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel": 
use_sequence_parallel, "overlap": overlap}, + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attn.out_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }, + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="mlp.fc_in", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="mlp.fc_out", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }, + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index b90972291870..bb522778bb5d 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -71,8 +71,12 @@ def get_sub_registry(self, keyword: str): new_dict = dict() for k, v in self.items(): - if keyword in k: - new_dict[k] = v + if keyword == "transformers_gpt": + if keyword in k and not "gptj" in k: # ensure GPT2 does not retrieve GPTJ models + new_dict[k] = v + else: + if keyword in k: + new_dict[k] = v assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 2af6176fbe4a..4ddadae85f3b 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -14,7 +14,7 @@ def data_gen(): # Generated from following code snippet # # from transformers import GPT2Tokenizer - # input = 'Hello, my dog is cute' + # input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement) # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index f91cfeda507b..9eefbb43dad8 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -14,13 +14,13 @@ def data_gen(): # Generated from following code snippet # # from transformers import AutoTokenizer - # input = 'Hello, my dog is cute' + # input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement) # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -66,7 +66,6 @@ def data_gen_for_sequence_classification(): attn_pdrop=0, embd_pdrop=0, resid_pdrop=0, - summary_first_dropout=0, hidden_dropout=0, problem_type="single_label_classification", pad_token_id=50256, diff 
--git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py index ce72212d5803..a946aacfd7ed 100644 --- a/tests/test_shardformer/test_model/test_shard_gptj.py +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -35,8 +35,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, gptj = unwrap_model(org_model, "GPTJModel", "transformer") sharded_gptj = unwrap_model(sharded_model, "GPTJModel", "transformer") - col_layer_for_check = ["h[0].mlp.fc_in"] - row_layer_for_check = ["wte", "h[0].mlp.fc_out"] + col_layer_for_check = ["h[0].attn.k_proj"] + row_layer_for_check = ["h[0].mlp.fc_out"] # use dim=0 for wte get_grad_tensors_for_check # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} @@ -46,10 +46,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False ) + row_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -76,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) # check grads check_all_grad_tensors(grads_to_check) @@ -89,20 +90,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "test_config", [ { - "tp_size": 1, + "tp_size": 2, "pp_size": 2, - "num_microbatches": 1, - "enable_all_optimization": False, - #'use_lazy_init': False, - "precision": "fp32", - #'initial_scale': 1, + "num_microbatches": 4, + "enable_all_optimization": True, + #'use_lazy_init': True, GPTJ currently do not support lazy init; model training has issue even without sharding + "precision": "fp16", + "initial_scale": 1, }, { "tp_size": 1, "pp_size": 2, "num_microbatches": 4, "enable_all_optimization": True, - "use_lazy_init": True, + #'use_lazy_init': True, "precision": "fp16", "initial_scale": 1, }, @@ -125,23 +126,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "enable_all_optimization": True, - "use_lazy_init": True, - "enable_sequence_parallelism": True, - "precision": "fp32", - }, - { - "tp_size": 4, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "enable_sequence_parallelism": True, + #'use_lazy_init': True, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, "enable_all_optimization": True, - "use_lazy_init": True, + #'use_lazy_init': True, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, @@ -151,7 +143,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, - "use_lazy_init": True, + #'use_lazy_init': True, "zero_stage": 1, "precision": "fp16", "initial_scale": 
1, @@ -160,8 +152,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) @clear_cache_before_run() def run_gptj_test(test_config): - print(test_config) - sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): @@ -222,7 +212,7 @@ def check_gptj_3d(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gptj(): - spawn(check_gptj, 2) + spawn(check_gptj, 4) @pytest.mark.largedist @@ -233,6 +223,5 @@ def test_gptj_3d(): if __name__ == "__main__": - print("===hello===") test_gptj() test_gptj_3d() From 341effc25aced9c3c5da66174499dc0c4e59d626 Mon Sep 17 00:00:00 2001 From: ppt0011 Date: Thu, 12 Oct 2023 14:49:07 +0800 Subject: [PATCH 9/9] [shardformer] sync gptj config with hf due to flash attn head dim requirement --- tests/kit/model_zoo/transformers/gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index 9eefbb43dad8..263978512a02 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -61,7 +61,7 @@ def data_gen_for_sequence_classification(): config = transformers.GPTJConfig( n_layer=2, - n_head=4, + n_head=16, vocab_size=50258, attn_pdrop=0, embd_pdrop=0,
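
With the four GPT-J wrappers registered in auto_policy.py above, sharding a GPT-J model needs no model-specific calls. The snippet below is an illustrative sketch, not part of the patch series: it assumes torch.distributed is already initialized (for example through colossalai.launch, as in the tests), uses the default world process group as the tensor-parallel group, and assumes the ShardFormer.optimize API, which returns the sharded model together with the list of shared parameters. The tiny config values are placeholders.

    import torch.distributed as dist
    import transformers

    from colossalai.shardformer import ShardConfig, ShardFormer

    # tiny illustrative config; a real run would load weights via from_pretrained
    model = transformers.GPTJForCausalLM(transformers.GPTJConfig(n_layer=2, n_head=16))

    # with no explicit policy passed to optimize(), ShardFormer resolves
    # GPTJForCausalLMPolicy from the model's fully qualified class name
    shard_config = ShardConfig(
        tensor_parallel_process_group=dist.group.WORLD,
        enable_tensor_parallelism=True,
    )
    sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)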