From 57d4fabdf036ca9b6c59f1fe53d41b3d15d70b64 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 30 Jun 2023 19:19:24 +0800 Subject: [PATCH 01/27] add pipeline policy and bert forward to be done --- colossalai/pipeline/policy/__init__.py | 20 ++ colossalai/pipeline/policy/base.py | 108 +++++++ colossalai/pipeline/policy/bert.py | 295 +++++++++++++++++++ colossalai/pipeline/policy/llama.py | 258 ++++++++++++++++ tests/test_pipeline/test_policy/test_bert.py | 57 ++++ tests/test_pipeline/test_stage_manager.py | 2 +- 6 files changed, 739 insertions(+), 1 deletion(-) create mode 100644 colossalai/pipeline/policy/__init__.py create mode 100644 colossalai/pipeline/policy/base.py create mode 100644 colossalai/pipeline/policy/bert.py create mode 100644 colossalai/pipeline/policy/llama.py create mode 100644 tests/test_pipeline/test_policy/test_bert.py diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py new file mode 100644 index 000000000000..cd372a28b79c --- /dev/null +++ b/colossalai/pipeline/policy/__init__.py @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Type + +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy +from .llama import LlamaForCausalLM, LlamaForCausalLMPolicy + +POLICY_MAP: Dict[Type[Module], Type[Policy]] = { + LlamaForCausalLM: LlamaForCausalLMPolicy, +} + + +def pipeline_parallelize(model: Module, stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + if type(model) not in POLICY_MAP: + raise NotImplementedError(f"Policy for {type(model)} not implemented") + policy = POLICY_MAP[type(model)](stage_manager) + return policy.parallelize_model(model) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py new file mode 100644 index 000000000000..ad595a04b1b0 --- /dev/null +++ b/colossalai/pipeline/policy/base.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Optional, Tuple + +from colossalai.lazy import LazyTensor +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]: + """Setup model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers + """ + hold_params = set() + hold_buffers = set() + + def init_layer(layer: Module): + for p in layer.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + hold_params.add(p) + for b in layer.buffers(): + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + hold_buffers.add(b) + + hold_layers = self.get_hold_layers(module) + + for layer in hold_layers: + init_layer(layer) + + hold_params_dict = {} + hold_buffers_dict = {} + + # release other tensors + for n, p in module.named_parameters(): + if p in hold_params: + hold_params_dict[n] = p + else: + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + p.storage().resize_(0) + for n, b in module.named_buffers(): + if b in hold_buffers: + hold_buffers_dict[n] = b + else: + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + # FIXME(ver217): use meta tensor may be better + b.storage().resize_(0) + return hold_params_dict, hold_buffers_dict + + def replace_forward(self, module: Module) -> None: + """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict + + Args: + module (Module): _description_ + """ + raise NotImplementedError + + def get_hold_layers(self, module: Module) -> List[Module]: + """Get layers that should be hold in current stage. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + raise NotImplementedError + + def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + """Parallelize model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters + """ + hold_params, hold_buffers = self.setup_model(module) + self.replace_forward(module) + shared_params = self.get_shared_params(module) + return hold_params, hold_buffers, shared_params diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py new file mode 100644 index 000000000000..00aabf3984ef --- /dev/null +++ b/colossalai/pipeline/policy/bert.py @@ -0,0 +1,295 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import (BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions) +from transformers.models.bert.modeling_bert import BertModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + +logger = logging.get_logger(__name__) + +def bert_model_forward(self:BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + #labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage + ) : + #TODO: add explaination of the output here. + + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # assure that the input is embedding_output and is the hidden_states of previous stages. + + hidden_states = input_ids if input_ids is not None else None + if stage_manager.is_first_stage(): + hidden_states= self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + + encoder_outputs = None + #inherit from bert_layer + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + next_decoder_cache = () if use_cache else None + + #calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + attention_mask = extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + ### + print('where is the model now',start_layer,idx,end_layer) + print('what stage is now',stage_manager.stage) + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + if stage_manager.stage == 1: + if hidden_states is not None : + print('shape of hidden_states',hidden_states.shape) + if attention_mask is not None : + print('shape of attention_mask',attention_mask.shape) + ## TODO: check for this layer_head_mask + if layer_head_mask is not None : + print('shape of layer_head_mask',layer_head_mask.shape) + if encoder_hidden_states is not None : + print('shape of encoder_hidden_states',encoder_hidden_states.shape) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + #end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + + #output of non-first and non-last stages: + if not return_dict: + return tuple(v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + + #return dict is not supported at this moment + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# class BertModelPolicy(Policy): +# def get_hold_layers(self, module: BertModel) -> List[Module]: +# # get pipeline layers for curerent stage +# hold_layers = [] +# if self.stage_manager.is_first_stage(): +# hold_layers.append(module.embeddings) +# #Fix: num_layers_per_stage should be calculated based on the number of layers in the model +# num_layers_per_stage = len(module.encoder.layer) // self.stage_manager.num_stages + +# hold_layers.extend(module.encoder.layer[self.stage_manager.stage* +# num_layers_per_stage : (self.stage_manager.stage+1)* num_layers_per_stage]) +# if self.stage_manager.is_last_stage(): +# hold_layers.append(module.pooler) + +# return hold_layers + +# def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: +# if id(module.embeddings.parameters) == id(module.pooler.parameters) +# return [dict(module.embeddings.named_parameters())] +# return [] +# def replace_forward(self, module: Module) -> None: +# return super().replace_forward(module) + +''' +def bert_pretraining_model_forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + pass +''' \ No newline at end of file diff --git a/colossalai/pipeline/policy/llama.py b/colossalai/pipeline/policy/llama.py new file mode 100644 index 000000000000..d83683ccb264 --- /dev/null +++ b/colossalai/pipeline/policy/llama.py @@ -0,0 +1,258 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutput, + CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import (LlamaForCausalLM, + LlamaModel) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + +logger = logging.get_logger(__name__) + + +def llama_model_forward(self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, # this is set by partial + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + ) -> Union[CausalLMOutput, Tuple]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if output_attentions: + logger.warning_once('`output_attentions=True` is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('`output_hidden_states=True` is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('`use_cache=True` is not supported for pipeline models at the moment.') + use_cache = False + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + if stage_manager.is_first_stage(): + inputs_embeds = self.embed_tokens(input_ids) + else: + inputs_embeds = hidden_states + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + # this function only uses inputs_embeds' device, dtype, and shape, it's safe to use hidden_state + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + num_layers_per_stage = len(self.layers) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for idx, decoder_layer in enumerate(self.layers[start_layer:end_layer], start=start_layer): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + # TODO(ver217): return_dict is not supported for pipeline models at the moment. + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def llama_for_causal_lm_forward(self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, # this is set by partial + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + ) + + hidden_states = outputs[0] + if not stage_manager.is_last_stage(): + return dict(hidden_states=hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + ) + + +class LlamaForCausalLMPolicy(Policy): + def get_hold_layers(self, module: LlamaForCausalLM) -> List[Module]: + hold_layers = [] + + if self.stage_manager.is_first_stage(): + hold_layers.append(module.model.embed_tokens) + num_layers_per_stage = len(module.model.layers) // self.stage_manager.num_stages + hold_layers.extend(module.model.layers[self.stage_manager.stage * + num_layers_per_stage: (self.stage_manager.stage + 1) * num_layers_per_stage]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.model.norm) + hold_layers.append(module.lm_head) + + return hold_layers + + def get_shared_params(self, module: LlamaForCausalLM) -> List[Dict[int, Tensor]]: + if id(module.model.embed_tokens.weight) == id(module.lm_head.weight): + # tie weights + return [{0: module.model.embed_tokens.weight, self.stage_manager.num_stages - 1: module.lm_head.weight}] + return [] + + def replace_forward(self, module: LlamaForCausalLM) -> None: + module.model.forward = MethodType(partial(llama_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = MethodType(partial(llama_for_causal_lm_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py new file mode 100644 index 000000000000..0e27802da13e --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -0,0 +1,57 @@ +import torch +import pytest +import torch.distributed as dist +from colossalai.cluster import ProcessGroupMesh +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn + +from colossalai.pipeline.policy.bert import bert_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from transformers.models.bert.modeling_bert import BertModel + +def check_bert_model_forward(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + #print(rank) + + x = torch.randint(0, 1000, (2, 3)) + attention_mask = torch.ones_like(x) + + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output) + assert output[0].shape == (2, 3, 768) + # assert output[1].shape == (2, 768) + + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_forward() + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_forward(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_bert_model_forward() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..67a2e90532e2 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -21,7 +21,7 @@ def check_stage_manager(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() From 8300f451863649375db9fdb062e9cfe0990e5a6f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 14:53:32 +0800 Subject: [PATCH 02/27] add bertmodel pipeline forward and make tests --- colossalai/pipeline/policy/bert.py | 97 ++++--- colossalai/pipeline/policy/llama.py | 258 ------------------- tests/test_pipeline/test_policy/test_bert.py | 23 +- 3 files changed, 61 insertions(+), 317 deletions(-) delete mode 100644 colossalai/pipeline/policy/llama.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 00aabf3984ef..1b9cdaecf9eb 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -57,6 +57,7 @@ def bert_model_forward(self:BertModel, If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ + # debugging # preprocess: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -69,15 +70,26 @@ def bert_model_forward(self:BertModel, else: use_cache = False - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -88,8 +100,7 @@ def bert_model_forward(self:BertModel, logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') use_cache = False - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 @@ -105,10 +116,24 @@ def bert_model_forward(self:BertModel, else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): + hidden_states= self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: @@ -120,27 +145,7 @@ def bert_model_forward(self:BertModel, else: encoder_extended_attention_mask = None - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - # assure that the input is embedding_output and is the hidden_states of previous stages. - hidden_states = input_ids if input_ids is not None else None - if stage_manager.is_first_stage(): - hidden_states= self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - - - encoder_outputs = None #inherit from bert_layer all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -159,22 +164,19 @@ def bert_model_forward(self:BertModel, start_layer = stage_manager.stage * num_layers_per_stage end_layer = (stage_manager.stage + 1) * num_layers_per_stage + #layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): if stage_manager.is_first_stage() and idx == 0: - attention_mask = extended_attention_mask + encoder_attention_mask=encoder_extended_attention_mask if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[idx] if head_mask is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None - - ### - print('where is the model now',start_layer,idx,end_layer) - print('what stage is now',stage_manager.stage) - - if self.encoder.gradient_checkpointing and self.encoder.training: - + + if self.encoder.gradient_checkpointing and self.encoder.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -190,16 +192,6 @@ def custom_forward(*inputs): encoder_attention_mask, ) else: - if stage_manager.stage == 1: - if hidden_states is not None : - print('shape of hidden_states',hidden_states.shape) - if attention_mask is not None : - print('shape of attention_mask',attention_mask.shape) - ## TODO: check for this layer_head_mask - if layer_head_mask is not None : - print('shape of layer_head_mask',layer_head_mask.shape) - if encoder_hidden_states is not None : - print('shape of encoder_hidden_states',encoder_hidden_states.shape) layer_outputs = encoder_layer( hidden_states, attention_mask, @@ -226,9 +218,8 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] + return (sequence_output, pooled_output) + layer_outputs[1:] - #output of non-first and non-last stages: if not return_dict: return tuple(v diff --git a/colossalai/pipeline/policy/llama.py b/colossalai/pipeline/policy/llama.py deleted file mode 100644 index d83683ccb264..000000000000 --- a/colossalai/pipeline/policy/llama.py +++ /dev/null @@ -1,258 +0,0 @@ -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutput, - CausalLMOutputWithPast) -from transformers.models.llama.modeling_llama import (LlamaForCausalLM, - LlamaModel) -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def llama_model_forward(self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, # this is set by partial - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - ) -> Union[CausalLMOutput, Tuple]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if output_attentions: - logger.warning_once('`output_attentions=True` is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('`output_hidden_states=True` is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('`use_cache=True` is not supported for pipeline models at the moment.') - use_cache = False - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - if stage_manager.is_first_stage(): - inputs_embeds = self.embed_tokens(input_ids) - else: - inputs_embeds = hidden_states - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - # this function only uses inputs_embeds' device, dtype, and shape, it's safe to use hidden_state - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - num_layers_per_stage = len(self.layers) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for idx, decoder_layer in enumerate(self.layers[start_layer:end_layer], start=start_layer): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - # TODO(ver217): return_dict is not supported for pipeline models at the moment. - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def llama_for_causal_lm_forward(self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, # this is set by partial - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - ) - - hidden_states = outputs[0] - if not stage_manager.is_last_stage(): - return dict(hidden_states=hidden_states) - - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - ) - - -class LlamaForCausalLMPolicy(Policy): - def get_hold_layers(self, module: LlamaForCausalLM) -> List[Module]: - hold_layers = [] - - if self.stage_manager.is_first_stage(): - hold_layers.append(module.model.embed_tokens) - num_layers_per_stage = len(module.model.layers) // self.stage_manager.num_stages - hold_layers.extend(module.model.layers[self.stage_manager.stage * - num_layers_per_stage: (self.stage_manager.stage + 1) * num_layers_per_stage]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.model.norm) - hold_layers.append(module.lm_head) - - return hold_layers - - def get_shared_params(self, module: LlamaForCausalLM) -> List[Dict[int, Tensor]]: - if id(module.model.embed_tokens.weight) == id(module.lm_head.weight): - # tie weights - return [{0: module.model.embed_tokens.weight, self.stage_manager.num_stages - 1: module.lm_head.weight}] - return [] - - def replace_forward(self, module: LlamaForCausalLM) -> None: - module.model.forward = MethodType(partial(llama_model_forward, stage_manager=self.stage_manager), module.model) - module.forward = MethodType(partial(llama_for_causal_lm_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 0e27802da13e..4f9af46c485e 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -30,15 +30,26 @@ def check_bert_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - #print(rank) + # print(rank) x = torch.randint(0, 1000, (2, 3)) - attention_mask = torch.ones_like(x) + hidden_states = torch.randint(0,1000,(2,3,768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2,12,3,3)) + output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('end the training') + print(output) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, - stage_manager=stage_manager) - print(output) - assert output[0].shape == (2, 3, 768) # assert output[1].shape == (2, 768) From 246b6d3b7a3eba548f88eb4be604b85ff516bf60 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 16:27:18 +0800 Subject: [PATCH 03/27] add Bert_Policy and test for policy --- colossalai/pipeline/policy/__init__.py | 5 +- colossalai/pipeline/policy/bert.py | 78 +++++++++++++++----- tests/test_pipeline/test_policy/test_bert.py | 42 ++++++++++- 3 files changed, 100 insertions(+), 25 deletions(-) diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py index cd372a28b79c..cb4b99803119 100644 --- a/colossalai/pipeline/policy/__init__.py +++ b/colossalai/pipeline/policy/__init__.py @@ -6,10 +6,9 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .base import Policy -from .llama import LlamaForCausalLM, LlamaForCausalLMPolicy - +from .bert import BertModel,BertModelPolicy POLICY_MAP: Dict[Type[Module], Type[Policy]] = { - LlamaForCausalLM: LlamaForCausalLMPolicy, + BertModel: BertModelPolicy, } diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 1b9cdaecf9eb..d9ee53748126 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -240,29 +240,69 @@ def custom_forward(*inputs): cross_attentions=all_cross_attentions, ) +# The layer partition policy for bertmodel +class BertModelPolicy(Policy): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int,num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers,num_stages) -# class BertModelPolicy(Policy): -# def get_hold_layers(self, module: BertModel) -> List[Module]: -# # get pipeline layers for curerent stage -# hold_layers = [] -# if self.stage_manager.is_first_stage(): -# hold_layers.append(module.embeddings) -# #Fix: num_layers_per_stage should be calculated based on the number of layers in the model -# num_layers_per_stage = len(module.encoder.layer) // self.stage_manager.num_stages + def get_hold_layers(self, module: BertModel) -> List[Module]: + # get pipeline layers for current stage + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.embeddings) + num_layers_per_stage_accumulated = self.convert_into_accumulated() + hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) -# hold_layers.extend(module.encoder.layer[self.stage_manager.stage* -# num_layers_per_stage : (self.stage_manager.stage+1)* num_layers_per_stage]) -# if self.stage_manager.is_last_stage(): -# hold_layers.append(module.pooler) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.pooler) -# return hold_layers + return hold_layers -# def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: -# if id(module.embeddings.parameters) == id(module.pooler.parameters) -# return [dict(module.embeddings.named_parameters())] -# return [] -# def replace_forward(self, module: Module) -> None: -# return super().replace_forward(module) + def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_model_forward,stage_manager=self.stage_manager), module.model) + + # divide layers into stages + def distribute_layers(self, num, stage_num) -> List[int]: + quotient = num // stage_num + remainder = num % stage_num + + # calculate the num_layers per stage + layers_per_stage = [quotient] * stage_num + + # deal with the rest layers + if remainder > 0: + middle_stages = (stage_num-1) // 2 + right_extra = remainder // 2 + left_extra = remainder - right_extra + + #divide the rest part + left=0 + right=0 + while left_extra > 0: + layers_per_stage[middle_stages - left] += 1 + left_extra -= 1 + left+= 1 + while right_extra > 0 : + layers_per_stage[middle_stages + right + 1] += 1 + right_extra -= 1 + right+=1 + return layers_per_stage + def convert_into_accumulated(self) -> List[int]: + '''convert a array into accumulated array''' + acc = 0 + layers_per_stage_accumulated=[] + for num in self.layers_per_stage: + acc += num + layers_per_stage_accumulated.append(acc) + return layers_per_stage_accumulated + + ''' def bert_pretraining_model_forward( diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 4f9af46c485e..4545bc795d40 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -5,7 +5,7 @@ import colossalai from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.pipeline.policy.bert import bert_model_forward +from colossalai.pipeline.policy.bert import bert_model_forward,BertModelPolicy from colossalai.pipeline.stage_manager import PipelineStageManager from transformers.models.bert.modeling_bert import BertModel @@ -52,17 +52,53 @@ def check_bert_model_forward(): # assert output[1].shape == (2, 768) +def check_bert_model_policy(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertModelPolicy(stage_manager,len(model.encoder.layer),2) + assert model_policy.layers_per_stage == [6,6] + layers=model_policy.get_hold_layers(model) + for layer in layers: + print(layer) -def run_dist(rank, world_size, port): +def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') check_bert_model_forward() +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_policy() + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_forward(): - spawn(run_dist, 4) + spawn(run_dist_model, 4) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_policy(): + spawn(run_dist_policy, 4) if __name__ == "__main__": test_bert_model_forward() + test_bert_model_policy() \ No newline at end of file From db0a1f150df079c6027f430b32293f12ca3018ac Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 17:05:17 +0800 Subject: [PATCH 04/27] update formatting --- colossalai/pipeline/policy/__init__.py | 7 +- colossalai/pipeline/policy/bert.py | 405 ++++++++++--------- tests/test_pipeline/test_policy/test_bert.py | 48 ++- 3 files changed, 237 insertions(+), 223 deletions(-) diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py index cb4b99803119..fd9e6e04588e 100644 --- a/colossalai/pipeline/policy/__init__.py +++ b/colossalai/pipeline/policy/__init__.py @@ -6,13 +6,16 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .base import Policy -from .bert import BertModel,BertModelPolicy +from .bert import BertModel, BertModelPolicy + POLICY_MAP: Dict[Type[Module], Type[Policy]] = { BertModel: BertModelPolicy, } -def pipeline_parallelize(model: Module, stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: +def pipeline_parallelize( + model: Module, + stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: if type(model) not in POLICY_MAP: raise NotImplementedError(f"Policy for {type(model)} not implemented") policy = POLICY_MAP[type(model)](stage_manager) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index d9ee53748126..9fab35241767 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -5,10 +5,12 @@ import torch from torch import Tensor from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import (BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions) -from transformers.models.bert.modeling_bert import BertModel +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.bert.modeling_bert import BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -17,7 +19,9 @@ logger = logging.get_logger(__name__) -def bert_model_forward(self:BertModel, + +def bert_model_forward( + self: BertModel, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -27,17 +31,16 @@ def bert_model_forward(self:BertModel, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - #labels: Optional[torch.LongTensor] = None, + #labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage - ) : - #TODO: add explaination of the output here. - - r""" + hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage +): + #TODO: add explaination of the output here. + r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. @@ -57,197 +60,195 @@ def bert_model_forward(self:BertModel, If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ - # debugging - # preprocess: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache + # debugging + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] else: - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded else: - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - hidden_states = hidden_states if hidden_states is not None else None - if stage_manager.is_first_stage(): - hidden_states= self.embeddings( + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + #inherit from bert_layer + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + next_decoder_cache = () if use_cache else None + #calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage - #inherit from bert_layer - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + #layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask - if self.encoder.gradient_checkpointing and self.encoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - next_decoder_cache = () if use_cache else None - - #calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - #layer_outputs - layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): - if stage_manager.is_first_stage() and idx == 0: - encoder_attention_mask=encoder_extended_attention_mask - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[idx] if head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) - #end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + return custom_forward - if stage_manager.is_last_stage(): - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + layer_outputs[1:] - - #output of non-first and non-last stages: + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + #end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return tuple(v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - #return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + return (sequence_output, pooled_output) + layer_outputs[1:] + + #output of non-first and non-last stages: + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + + #return dict is not supported at this moment + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int,num_stages: int): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers,num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: - # get pipeline layers for current stage + """ + get pipeline layers for current stage + """ hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) @@ -255,53 +256,55 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: num_layers_per_stage_accumulated[self.stage_manager.stage]]) - + if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) return hold_layers - + def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' pass + def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward,stage_manager=self.stage_manager), module.model) + module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - # divide layers into stages def distribute_layers(self, num, stage_num) -> List[int]: - quotient = num // stage_num - remainder = num % stage_num + """ + divide layers into stages + """ + quotient = num // stage_num + remainder = num % stage_num # calculate the num_layers per stage layers_per_stage = [quotient] * stage_num # deal with the rest layers if remainder > 0: - middle_stages = (stage_num-1) // 2 - right_extra = remainder // 2 - left_extra = remainder - right_extra - + middle_stages = (stage_num - 1) // 2 + right_extra = remainder // 2 + left_extra = remainder - right_extra + #divide the rest part - left=0 - right=0 + left = 0 + right = 0 while left_extra > 0: layers_per_stage[middle_stages - left] += 1 left_extra -= 1 - left+= 1 - while right_extra > 0 : - layers_per_stage[middle_stages + right + 1] += 1 + left += 1 + while right_extra > 0: + layers_per_stage[middle_stages + right + 1] += 1 right_extra -= 1 - right+=1 + right += 1 return layers_per_stage + def convert_into_accumulated(self) -> List[int]: - '''convert a array into accumulated array''' acc = 0 - layers_per_stage_accumulated=[] + layers_per_stage_accumulated = [] for num in self.layers_per_stage: acc += num layers_per_stage_accumulated.append(acc) return layers_per_stage_accumulated - ''' @@ -323,4 +326,4 @@ def bert_pretraining_model_forward( ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: pass -''' \ No newline at end of file +''' diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert.py index 4545bc795d40..c92f7f6c34c0 100644 --- a/tests/test_pipeline/test_policy/test_bert.py +++ b/tests/test_pipeline/test_policy/test_bert.py @@ -1,13 +1,14 @@ -import torch import pytest +import torch import torch.distributed as dist -from colossalai.cluster import ProcessGroupMesh +from transformers.models.bert.modeling_bert import BertModel + import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.pipeline.policy.bert import bert_model_forward,BertModelPolicy -from colossalai.pipeline.stage_manager import PipelineStageManager -from transformers.models.bert.modeling_bert import BertModel def check_bert_model_forward(): model = BertModel.from_pretrained('bert-base-uncased') @@ -24,34 +25,36 @@ def check_bert_model_forward(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) #print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0,1000,(2,3,768)).to(torch.float32) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, - stage_manager=stage_manager) + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2,12,3,3)) - output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, + attention_mask = torch.ones((2, 12, 3, 3)) + output = bert_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 768) print('end the training') print(output) - + # assert output[1].shape == (2, 768) + def check_bert_model_policy(): model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 @@ -67,16 +70,16 @@ def check_bert_model_policy(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) #print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager,len(model.encoder.layer),2) - assert model_policy.layers_per_stage == [6,6] - layers=model_policy.get_hold_layers(model) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) for layer in layers: print(layer) @@ -85,20 +88,25 @@ def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') check_bert_model_forward() + def run_dist_policy(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_policy() + check_bert_model_policy() + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_forward(): spawn(run_dist_model, 4) + @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_policy(): spawn(run_dist_policy, 4) + if __name__ == "__main__": + """test the bert model forward and bert model policy""" test_bert_model_forward() - test_bert_model_policy() \ No newline at end of file + test_bert_model_policy() From 9f57067d72a39f95ebbea01e6a2208bcd532caa2 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 17:12:31 +0800 Subject: [PATCH 05/27] update formatting --- colossalai/pipeline/policy/bert.py | 35 ++++++++++++++---------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 9fab35241767..c862e9297044 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,7 +10,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertModel +from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -307,23 +307,20 @@ def convert_into_accumulated(self) -> List[int]: return layers_per_stage_accumulated -''' def bert_pretraining_model_forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - - ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: pass -''' From dac6427377aadd62632834fdecb51a534284c28f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 3 Jul 2023 18:19:13 +0800 Subject: [PATCH 06/27] update the code --- colossalai/pipeline/policy/bert.py | 37 ++++++++---------------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index c862e9297044..15be48b47b4e 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -2,6 +2,7 @@ from types import MethodType from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch from torch import Tensor from torch.nn import CrossEntropyLoss, Module @@ -252,7 +253,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = self.convert_into_accumulated() + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: num_layers_per_stage_accumulated[self.stage_manager.stage]]) @@ -269,43 +270,23 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - def distribute_layers(self, num, stage_num) -> List[int]: + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """ divide layers into stages """ - quotient = num // stage_num - remainder = num % stage_num + quotient = num_layers // num_stages + remainder = num_layers % num_stages # calculate the num_layers per stage - layers_per_stage = [quotient] * stage_num + layers_per_stage = [quotient] * num_stages # deal with the rest layers if remainder > 0: - middle_stages = (stage_num - 1) // 2 - right_extra = remainder // 2 - left_extra = remainder - right_extra - - #divide the rest part - left = 0 - right = 0 - while left_extra > 0: - layers_per_stage[middle_stages - left] += 1 - left_extra -= 1 - left += 1 - while right_extra > 0: - layers_per_stage[middle_stages + right + 1] += 1 - right_extra -= 1 - right += 1 + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 return layers_per_stage - def convert_into_accumulated(self) -> List[int]: - acc = 0 - layers_per_stage_accumulated = [] - for num in self.layers_per_stage: - acc += num - layers_per_stage_accumulated.append(acc) - return layers_per_stage_accumulated - def bert_pretraining_model_forward( self, From 8b30a0223088f9f4b1efdd577aee7d8f6e604aba Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 10:22:05 +0800 Subject: [PATCH 07/27] fix bugs --- colossalai/pipeline/policy/bert.py | 91 ++++++++++++++++- colossalai/pipeline/policy/bloom.py | 153 ++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+), 4 deletions(-) create mode 100644 colossalai/pipeline/policy/bloom.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 15be48b47b4e..6f912d2c6b80 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -11,7 +11,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel +from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -288,8 +288,8 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: return layers_per_stage -def bert_pretraining_model_forward( - self, +def bert_for_pretraining_forward( + self: BertForPreTraining, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -304,4 +304,87 @@ def bert_pretraining_model_forward( hidden_states: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - pass + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertForPreTrainingPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.model) + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py new file mode 100644 index 000000000000..8dffcd8f9af5 --- /dev/null +++ b/colossalai/pipeline/policy/bloom.py @@ -0,0 +1,153 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) From 585eb9d9470d3f56a6d6c84c02bca406187cffb0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 10:42:03 +0800 Subject: [PATCH 08/27] fix name confilt --- .../test_policy/{test_bert.py => test_bert_model.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_pipeline/test_policy/{test_bert.py => test_bert_model.py} (100%) diff --git a/tests/test_pipeline/test_policy/test_bert.py b/tests/test_pipeline/test_policy/test_bert_model.py similarity index 100% rename from tests/test_pipeline/test_policy/test_bert.py rename to tests/test_pipeline/test_policy/test_bert_model.py From 27fb80409570f3aa3c7ae4e96544e2b3c0e53c43 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 16:53:20 +0800 Subject: [PATCH 09/27] add bloom model and policy ,revise the base class of policy --- colossalai/pipeline/policy/base.py | 23 +++- colossalai/pipeline/policy/bert.py | 86 ++++++------- colossalai/pipeline/policy/bloom.py | 110 ++++++++++++---- .../test_policy/test_bert_model.py | 4 +- .../test_policy/test_bloom_model.py | 119 ++++++++++++++++++ 5 files changed, 268 insertions(+), 74 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index ad595a04b1b0..9bfce15a83ab 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -1,13 +1,14 @@ from typing import Any, Dict, List, Optional, Tuple -from colossalai.lazy import LazyTensor from torch import Tensor from torch.nn import Module, Parameter +from colossalai.lazy import LazyTensor from colossalai.pipeline.stage_manager import PipelineStageManager class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -93,7 +94,8 @@ def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: """ raise NotImplementedError - def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + def parallelize_model(self, + module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: """Parallelize model for pipeline parallel Args: @@ -106,3 +108,20 @@ def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[ self.replace_forward(module) shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 6f912d2c6b80..002814e9014e 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -22,25 +22,26 @@ def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - #labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + # labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + # this is from the previous stage + hidden_states: Optional[torch.FloatTensor] = None, ): - #TODO: add explaination of the output here. + # TODO: add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -93,6 +94,7 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -144,7 +146,7 @@ def bert_model_forward( else: encoder_extended_attention_mask = None - #inherit from bert_layer + # inherit from bert_layer all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -156,12 +158,12 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - #calculate the num_layers + # calculate the num_layers num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages start_layer = stage_manager.stage * num_layers_per_stage end_layer = (stage_manager.stage + 1) * num_layers_per_stage - #layer_outputs + # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): if stage_manager.is_first_stage() and idx == 0: @@ -206,12 +208,13 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + \ + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - #end of a stage loop + # end of a stage loop sequence_output = layer_outputs[0] if layer_outputs is not None else None if stage_manager.is_last_stage(): @@ -219,7 +222,7 @@ def custom_forward(*inputs): if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] - #output of non-first and non-last stages: + # output of non-first and non-last stages: if not return_dict: return tuple(v for v in [ hidden_states, @@ -229,7 +232,7 @@ def custom_forward(*inputs): all_cross_attentions, ] if v is not None) - #return dict is not supported at this moment + # return dict is not supported at this moment return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -243,8 +246,9 @@ def custom_forward(*inputs): class BertModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = super().distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -254,9 +258,9 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) + hold_layers.extend( + module.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. + stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) @@ -270,23 +274,6 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - def bert_for_pretraining_forward( self: BertForPreTraining, @@ -356,9 +343,10 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.bert.embeddings) num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) + hold_layers.extend( + module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - + 1] if self.stage_manager. + stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) if self.stage_manager.is_last_stage(): hold_layers.append(module.cls) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index 8dffcd8f9af5..25b5039760bf 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from types import MethodType from typing import Dict, List, Optional, Tuple, Union @@ -14,6 +15,8 @@ from .base import Policy +logger = logging.get_logger(__name__) + def bloom_model_forward( self: BloomModel, @@ -26,6 +29,8 @@ def bloom_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -44,29 +49,45 @@ def bloom_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + # extra recording tensor should be generated in the first stage presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -77,11 +98,13 @@ def bloom_model_forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values[0][0].shape[2] # source_len seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) @@ -90,13 +113,19 @@ def bloom_model_forward( alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + # causal_mask is constructed every stage and its input is passed through different stages causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # calculate the num_layers + num_layers_per_stage = len(self.h) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -130,24 +159,63 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] + if use_cache is True: presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) - # Add last hidden state - hidden_states = self.ln_f(hidden_states) + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + # TODO: deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + +class BloomModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BloomModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.word_embeddings) + hold_layers.append(module.word_embeddings_layernorm) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.h[num_layers_per_stage_accumulated[self.stage_manager.stage - + 1] if self.stage_manager. + stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + if self.stage_manager.is_last_stage(): + hold_layers.append(module.ln_f) + + return hold_layers + + def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index c92f7f6c34c0..b757f6813153 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -27,7 +27,7 @@ def check_bert_model_forward(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() @@ -72,7 +72,7 @@ def check_bert_model_policy(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py new file mode 100644 index 000000000000..5ba92d734590 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bloom_model.py @@ -0,0 +1,119 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bloom import BloomConfig, BloomModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bloom_model_forward(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32) + if stage_manager.is_first_stage(): + attention_mask = torch.ones_like(x) + output = bloom_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bloom_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bloom_model_policy(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2) + assert model_policy.layers_per_stage == [1, 1] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bloom model forward and bloom model policy""" + test_bloom_model_forward() + test_bloom_model_policy() From 3ea0ba4627e218224f79dff6ab5aa47683d616e3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 18:06:18 +0800 Subject: [PATCH 10/27] revise --- colossalai/pipeline/policy/base.py | 3 ++- colossalai/pipeline/policy/bert.py | 11 ++++++----- colossalai/pipeline/policy/bloom.py | 11 ++++++----- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index 9bfce15a83ab..8da70dd43362 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -109,7 +109,8 @@ def parallelize_model(self, shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: """ divide layers into stages """ diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 002814e9014e..0ec30d41129c 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -248,7 +248,7 @@ class BertModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -257,11 +257,12 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend( - module.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. - stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] + end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + hold_layers.extend(module.encoder.layer[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index 25b5039760bf..56337b26f333 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -193,7 +193,7 @@ class BloomModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = super().distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) def get_hold_layers(self, module: BloomModel) -> List[Module]: """ @@ -203,10 +203,11 @@ def get_hold_layers(self, module: BloomModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.word_embeddings) hold_layers.append(module.word_embeddings_layernorm) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.h[num_layers_per_stage_accumulated[self.stage_manager.stage - - 1] if self.stage_manager. - stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] + end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + hold_layers.extend(module.h[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.ln_f) From edb02b268df05529669eff4dc8b9f97c0ca99be3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 18:36:13 +0800 Subject: [PATCH 11/27] revision --- colossalai/pipeline/policy/base.py | 15 ++++++++++++++- colossalai/pipeline/policy/bert.py | 13 ++----------- colossalai/pipeline/policy/bloom.py | 9 +++------ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index c390e436b5a1..9736f1004fe4 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -1,14 +1,15 @@ from typing import Any, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor from torch.nn import Module, Parameter from colossalai.lazy import LazyTensor - from colossalai.pipeline.stage_manager import PipelineStageManager class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -126,3 +127,15 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage + + @staticmethod + def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index bc4d6e549762..a1efe238573c 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -22,7 +22,6 @@ def bert_model_forward( - self: BertModel, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -95,7 +94,6 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') @@ -213,7 +211,6 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -225,7 +222,6 @@ def custom_forward(*inputs): if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] - # output of non-first and non-last stages: if not return_dict: return tuple(v for v in [ @@ -236,7 +232,6 @@ def custom_forward(*inputs): all_cross_attentions, ] if v is not None) - # return dict is not supported at this moment return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -262,11 +257,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] - end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] - + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) hold_layers.extend(module.encoder.layer[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) @@ -280,6 +271,7 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + def bert_for_pretraining_forward( self: BertForPreTraining, input_ids: Optional[torch.Tensor] = None, @@ -352,7 +344,6 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager. stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) - if self.stage_manager.is_last_stage(): hold_layers.append(module.cls) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index ebd086df67a8..71d2913fc3aa 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -15,9 +15,9 @@ from .base import Policy - logger = logging.get_logger(__name__) + def bloom_model_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, @@ -187,7 +187,7 @@ def custom_forward(*inputs): attentions=all_self_attentions, ) - + class BloomModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): @@ -203,10 +203,8 @@ def get_hold_layers(self, module: BloomModel) -> List[Module]: if self.stage_manager.is_first_stage(): hold_layers.append(module.word_embeddings) hold_layers.append(module.word_embeddings_layernorm) - num_layers_per_stage_accumulated = np.insert(np.cumsum(self.layers_per_stage), 0, 0) - start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage] - end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage + 1] + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) hold_layers.extend(module.h[start_idx:end_idx]) if self.stage_manager.is_last_stage(): @@ -220,4 +218,3 @@ def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) - From 369df2cf4fef581e08f1da501f7ae070bb8f57d8 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 4 Jul 2023 19:20:10 +0800 Subject: [PATCH 12/27] add bert_for_pretraining --- colossalai/pipeline/policy/bert.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index a1efe238573c..8cd0fadd167f 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -290,8 +290,8 @@ def bert_for_pretraining_forward( ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( + outputs = bert_model_forward( + self.bert, input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -304,7 +304,8 @@ def bert_for_pretraining_forward( ) sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + if stage_manager.is_last_stage(): + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) total_loss = None if labels is not None and next_sentence_label is not None: @@ -339,12 +340,12 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.bert.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend( - module.bert.encoder.layer[num_layers_per_stage_accumulated[self.stage_manager.stage - - 1] if self.stage_manager. - stage > 0 else 0:num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) hold_layers.append(module.cls) return hold_layers From 0319c8bc4f803578cedbfa54726ee9fec9eae650 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 5 Jul 2023 12:23:57 +0800 Subject: [PATCH 13/27] add bert_for_pretraining forward and policy --- colossalai/pipeline/policy/bert.py | 112 +++++++++-------- .../test_bert_for_pretraining_model.py | 118 ++++++++++++++++++ 2 files changed, 178 insertions(+), 52 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 8cd0fadd167f..d8b665ec6c24 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -285,51 +285,76 @@ def bert_for_pretraining_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + hidden_states = outputs[0] if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + else: + if not return_dict: + return tuple(v for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) class BertForPreTrainingPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager self.layers_per_stage = self.distribute_layers(num_layers, num_stages) @@ -355,22 +380,5 @@ def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor pass def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.model) - - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage + module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.forward) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py new file mode 100644 index 000000000000..4d764704ccba --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertForPreTraining + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_for_pretraining_forward(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 2: + attention_mask = torch.ones_like(x) + output = bert_for_pretraining_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + elif stage_manager.stage == 1: + attention_mask = torch.ones((2, 12, 3, 3)) + output = bert_for_pretraining_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_for_pretraining_policy(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer), 2) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_for_pretraining_forward() + test_bert_for_pretraining_policy() From 29ef3807accadd42023f1bc8ba2880756fe8e858 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 10:35:38 +0800 Subject: [PATCH 14/27] fix typos --- colossalai/pipeline/policy/bert.py | 280 ++++++++++++++---- .../test_bert_for_pretraining_model.py | 8 +- .../test_policy/test_bert_lmhead_model.py | 118 ++++++++ .../test_policy/test_bert_model.py | 4 +- 4 files changed, 340 insertions(+), 70 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bert_lmhead_model.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index d8b665ec6c24..85cb0b0af585 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,9 +10,15 @@ BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel -from transformers.utils import logging +from transformers.models.bert.modeling_bert import ( + BertForPreTraining, + BertForPreTrainingOutput, + BertLMHeadModel, + BertModel, +) +from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -21,25 +27,38 @@ logger = logging.get_logger(__name__) +class BertModelIntermediateOutput(ModelOutput): + """ + Class for the intermediate output of bert model and bert-based model + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the previous stage. + NOTE: This is different from the base model. + """ + + hidden_states: torch.FloatTensor = None + attention_mask: Optional[torch.Tensor] = None + + def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - # this is from the previous stage - hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage ): # TODO: add explaination of the output here. r""" @@ -85,10 +104,6 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -119,14 +134,29 @@ def bert_model_forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): hidden_states = self.embeddings( input_ids=input_ids, @@ -135,18 +165,8 @@ def bert_model_forward( inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - # inherit from bert_layer + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -221,34 +241,34 @@ def custom_forward(*inputs): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) - # output of non-first and non-last stages: + # output of non-first and non-last stages: must be a dict if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - # return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + logger.warning_once('The output of intermediate stage should always be a dict') + + return BertModelIntermediateOutput(hidden_states=hidden_states,) # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__( + self, + stage_manager: PipelineStageManager, + num_layers: int, + ): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -287,7 +307,7 @@ def bert_for_pretraining_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, -) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: +): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -317,6 +337,7 @@ def bert_for_pretraining_forward( all_self_attentions = None all_cross_attentions = None hidden_states = outputs[0] + if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) @@ -342,21 +363,16 @@ def bert_for_pretraining_forward( else: if not return_dict: - return tuple(v for v in [ - hidden_states, - past_key_values, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) + logger.warning_once('The output of intermediate stage should always be a dict') + return BertModelIntermediateOutput(hidden_states=hidden_states,) class BertForPreTrainingPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: """ @@ -382,3 +398,139 @@ def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), module.forward) + + +def bert_lmhead_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + hidden_states = outputs[0] + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + if not return_dict: + return BertModelIntermediateOutput(hidden_states=hidden_states) + + +class BertLMHeadModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) + + def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 4d764704ccba..b170b52163c3 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -37,7 +37,7 @@ def check_bert_for_pretraining_forward(): x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 2: + if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_for_pretraining_forward(self=model, input_ids=x, @@ -46,8 +46,8 @@ def check_bert_for_pretraining_forward(): print(output[0].shape) assert output[0].shape == (2, 3, 768) print('start the training') - elif stage_manager.stage == 1: - attention_mask = torch.ones((2, 12, 3, 3)) + else: + attention_mask = torch.ones((2, 3)) output = bert_for_pretraining_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -83,7 +83,7 @@ def check_bert_for_pretraining_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer), 2) + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers: diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py new file mode 100644 index 000000000000..04a6aff80ff1 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertLMHeadModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_lmhead_forward(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_lmhead_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_lmhead_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_lmhead_policy(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_lmhead_forward() + test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index cf5dc95feb8c..5903434d97b8 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -43,7 +43,7 @@ def check_bert_model_forward(): assert output[0].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2, 12, 3, 3)) + attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -78,7 +78,7 @@ def check_bert_model_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers: From 5cd2478db4f6159e8a12ad66a630c7d24d2b0395 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 12:10:15 +0800 Subject: [PATCH 15/27] cancel warning --- colossalai/pipeline/policy/bert.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 85cb0b0af585..85bd35962386 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -252,10 +252,9 @@ def custom_forward(*inputs): ) # output of non-first and non-last stages: must be a dict - if not return_dict: - logger.warning_once('The output of intermediate stage should always be a dict') - - return BertModelIntermediateOutput(hidden_states=hidden_states,) + else: + # intermediate stage always return dict + return BertModelIntermediateOutput(hidden_states=hidden_states,) # The layer partition policy for bertmodel @@ -362,8 +361,7 @@ def bert_for_pretraining_forward( ) else: - if not return_dict: - logger.warning_once('The output of intermediate stage should always be a dict') + # intermediate stage always return dict return BertModelIntermediateOutput(hidden_states=hidden_states,) @@ -502,8 +500,8 @@ def bert_lmhead_forward(self: BertLMHeadModel, cross_attentions=outputs.cross_attentions, ) else: - if not return_dict: - return BertModelIntermediateOutput(hidden_states=hidden_states) + # intermediate stage always return dict + return BertModelIntermediateOutput(hidden_states=hidden_states) class BertLMHeadModelPolicy(Policy): From ef528e6d7d61f416ff47dab0aa082462f3e73b64 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 14:17:49 +0800 Subject: [PATCH 16/27] change the imediate output to default dict --- colossalai/pipeline/policy/bert.py | 33 +++++++------------ .../test_bert_for_pretraining_model.py | 4 +-- .../test_policy/test_bert_lmhead_model.py | 4 +-- .../test_policy/test_bert_model.py | 4 +-- 4 files changed, 17 insertions(+), 28 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 85bd35962386..ec6ab91b9365 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -27,20 +27,6 @@ logger = logging.get_logger(__name__) -class BertModelIntermediateOutput(ModelOutput): - """ - Class for the intermediate output of bert model and bert-based model - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the previous stage. - NOTE: This is different from the base model. - """ - - hidden_states: torch.FloatTensor = None - attention_mask: Optional[torch.Tensor] = None - - def bert_model_forward( self: BertModel, input_ids: Optional[torch.Tensor] = None, @@ -254,7 +240,9 @@ def custom_forward(*inputs): # output of non-first and non-last stages: must be a dict else: # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states,) + return { + 'hidden_states': hidden_states, + } # The layer partition policy for bertmodel @@ -288,7 +276,7 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: pass def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) def bert_for_pretraining_forward( @@ -335,8 +323,6 @@ def bert_for_pretraining_forward( all_hidden_states = None all_self_attentions = None all_cross_attentions = None - hidden_states = outputs[0] - if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) @@ -359,10 +345,13 @@ def bert_for_pretraining_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states,) + return { + 'hidden_states': hidden_states, + } class BertForPreTrainingPolicy(Policy): @@ -473,7 +462,6 @@ def bert_lmhead_forward(self: BertLMHeadModel, all_hidden_states = None all_self_attentions = None all_cross_attentions = None - hidden_states = outputs[0] if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -500,8 +488,9 @@ def bert_lmhead_forward(self: BertLMHeadModel, cross_attentions=outputs.cross_attentions, ) else: + hidden_states = outputs.get('hidden_states') # intermediate stage always return dict - return BertModelIntermediateOutput(hidden_states=hidden_states) + return {'hidden_states': hidden_states} class BertLMHeadModelPolicy(Policy): diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index b170b52163c3..afbea49c1829 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -43,8 +43,8 @@ def check_bert_for_pretraining_forward(): input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py index 04a6aff80ff1..d41eddc74dff 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -43,8 +43,8 @@ def check_bert_lmhead_forward(): input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index 5903434d97b8..92485072a5e4 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -39,8 +39,8 @@ def check_bert_model_forward(): if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: attention_mask = torch.ones((2, 3)) From e3e6c3bd6ae42646658a91b3571c408e1916b78e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 6 Jul 2023 14:23:02 +0800 Subject: [PATCH 17/27] change the default output of get_shared_params --- colossalai/pipeline/policy/bert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index ec6ab91b9365..abce504e9d61 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -273,7 +273,7 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) @@ -380,7 +380,7 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), @@ -517,7 +517,7 @@ def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) From 6793bc34befd70c46cacccab1890afd353c3c81b Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 17:53:16 +0800 Subject: [PATCH 18/27] add chatglm --- colossalai/shardformer/modeling/chatglm.py | 214 ++++++++++++++++++ .../modeling/chatglm2_6b/modeling_chatglm.py | 0 colossalai/shardformer/policies/chatglm.py | 0 tests/kit/model_zoo/transformers/chatglm.py | 46 ++++ .../test_model/test_shard_chatglm.py | 109 +++++++++ .../test_model/test_shard_chatglm_pipeline.py | 87 +++++++ 6 files changed, 456 insertions(+) create mode 100644 colossalai/shardformer/modeling/chatglm.py create mode 100644 colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py create mode 100644 colossalai/shardformer/policies/chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm.py create mode 100644 tests/test_shardformer/test_model/test_shard_chatglm.py create mode 100644 tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py new file mode 100644 index 000000000000..5557dd379a75 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm.py @@ -0,0 +1,214 @@ +""" PyTorch ChatGLM model. """ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, +) + + +class ChatGLMPipelineForwards: + ''' + This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. + ''' + + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + + logger = logging.get_logger(__name__) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if stage_manager.is_first_stage(): + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + hidden_states = inputs_embeds + else: + seq_length, batch_size = hidden_states.shape[:2] + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], + dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + layer_ret = layer(hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[idx], + use_cache=use_cache) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if not past_key_values: + past_key_values = [None for _ in range(self.num_layers)] + + presents = () if use_cache else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + start_idx, end_idx = stage_index[0], stage_index[1] + for idx in range(start_idx, end_idx): + layer = self.encoder._get_layer(idx) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.encoder.gradient_checkpointing and self.encoder.training: + layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, + past_key_values[idx], use_cache) + else: + layer_ret = layer(hidden_states, + full_attention_mask, + rotary_pos_emb, + kv_cache=past_key_values[idx], + use_cache=use_cache) + + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + # final layer_norm + if self.encoder.post_layer_norm: + hidden_states = self.encoder.final_layernorm(hidden_states) + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + return {'hidden_states': hidden_states} + + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( + self.transformer, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + return transformer_outputs diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py new file mode 100644 index 000000000000..96306fa54ff5 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -0,0 +1,46 @@ +import torch +import transformers + +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + +from ..registry import ModelAttribute, model_zoo + +# ================================ +# Register single-sentence ChatGLM +# ================================ + + +def data_gen(): + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_chatglm_model = lambda x: x.last_hidden_state +loss_fn = lambda x: x.loss +config = ChatGLMConfig(num_layers=1, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + rmsnorm=False, + original_rope=True, + use_cache=True) + +model_zoo.register(name='transformers_chatglm', + model_fn=lambda: ChatGLMModel(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name="transformers_chatglm_for_conditional_generation", + model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py new file mode 100644 index 000000000000..7f6ce65b4255 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -0,0 +1,109 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if org_model.__class__.__name__ == 'ChatGLMModel': + chatglm_model = org_model + shard_chatglm_model = sharded_model + else: + chatglm_model = org_model.transformer + shard_chatglm_model = sharded_model.transformer + + # check attention grad + org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + # check embedding weights + org_grad = chatglm_model.embedding.word_embeddings.weight.grad + shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad + shard_weight = shard_chatglm_model.embedding.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + if name == "transformers_chatglm": + sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + else: + sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda() + + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 2) + + +if __name__ == "__main__": + test_chatglm() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py new file mode 100644 index 000000000000..f4ccf3296b8f --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py @@ -0,0 +1,87 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create new model + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids = inputs['input_ids'] + hidden_size = 64 + batch_size, seq_len = input_ids.shape + hidden_state_shape = (seq_len, batch_size, hidden_size) + if name == "transformers_chatglm": + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy()) + if stage_manager.is_last_stage(): + hidden_states = torch.randn(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + outputs = sharded_model(**inputs) + + if stage_manager.is_last_stage(): + assert outputs[0].shape == hidden_state_shape + else: + assert outputs['hidden_states'].shape == hidden_state_shape + + if name == "transformers_chatglm_for_conditional_generation": + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init, + ChatGLMForConditionalGenerationPolicy()) + if stage_manager.is_last_stage(): + hidden_states = torch.randn(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + outputs = sharded_model(**inputs) + + if stage_manager.is_last_stage(): + assert outputs[0].shape == (batch_size, seq_len, 65024) + else: + assert outputs['hidden_states'].shape == hidden_state_shape + + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 4) + + +if __name__ == "__main__": + test_chatglm() From d9289dc5164572409845824c91a8ce8939e1cbdf Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 18:17:19 +0800 Subject: [PATCH 19/27] add --- .../chatglm2_6b/configuration_chatglm.py | 58 +++++++++++++++++++ .../modeling/chatglm2_6b/modeling_chatglm.py | 0 colossalai/shardformer/policies/chatglm.py | 0 tests/kit/model_zoo/transformers/__init__.py | 1 + 4 files changed, 59 insertions(+) create mode 100644 colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py delete mode 100644 colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py delete mode 100644 colossalai/shardformer/policies/chatglm.py diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py new file mode 100644 index 000000000000..3e78732be2da --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py @@ -0,0 +1,58 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a298767d12e7..fa4bbe1b998f 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,6 +1,7 @@ from .albert import * from .bert import * from .bloom import * +from .chatglm import * from .gpt import * from .llama import * from .opt import * From 68dd1018bacf0c7b17d93e517b73c5c24a57d497 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 18:21:05 +0800 Subject: [PATCH 20/27] chatglm --- .../modeling/chatglm2_6b/modeling_chatglm.py | 1372 +++++++++++++++++ 1 file changed, 1372 insertions(+) create mode 100644 colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py new file mode 100644 index 000000000000..bae6d425878d --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -0,0 +1,1372 @@ +""" +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. +""" +""" PyTorch ChatGLM model. """ + +import copy +import math +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != "darwin": + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2, + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool, + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + # Per attention head and per partition values. + self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = (self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = (config.hidden_size // + config.num_attention_heads if config.kv_channels is None else config.kv_channels) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype, + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels, + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], + beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs["attention_mask"] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, + ): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") + logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.") + + # 2. Set generation parameters if not already defined + logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) + stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation(outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs, + ) + return self From 593d347082a812922c03daf746cb6a3f13d37355 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 18:44:45 +0800 Subject: [PATCH 21/27] chatglm --- colossalai/shardformer/modeling/chatglm.py | 35 +--- colossalai/shardformer/policies/chatglm.py | 212 ++++++++++++++++++++ tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/test_shardformer/test_model/_utils.py | 7 +- 4 files changed, 223 insertions(+), 33 deletions(-) create mode 100644 colossalai/shardformer/policies/chatglm.py diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py index 5557dd379a75..0bb8bdc58218 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, LayerNorm from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging @@ -22,6 +22,7 @@ class ChatGLMPipelineForwards: This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. ''' + @staticmethod def chatglm_model_forward( self: ChatGLMModel, input_ids, @@ -37,13 +38,11 @@ def chatglm_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ): - logger = logging.get_logger(__name__) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') @@ -54,16 +53,13 @@ def chatglm_model_forward( if use_cache: logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') use_cache = False - if stage_manager.is_first_stage(): batch_size, seq_length = input_ids.shape - if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) hidden_states = inputs_embeds else: seq_length, batch_size = hidden_states.shape[:2] - if self.pre_seq_len is not None: if past_key_values is None: past_key_values = self.get_prompt(batch_size=batch_size, @@ -72,7 +68,6 @@ def chatglm_model_forward( if attention_mask is not None: attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1) - if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) @@ -81,26 +76,16 @@ def chatglm_model_forward( if position_ids is not None: rotary_pos_emb = rotary_pos_emb[position_ids] else: - layer_ret = layer(hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[idx], - use_cache=use_cache) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() if not past_key_values: past_key_values = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.encoder.gradient_checkpointing and self.encoder.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False - all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] @@ -108,7 +93,6 @@ def chatglm_model_forward( layer = self.encoder._get_layer(idx) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.encoder.gradient_checkpointing and self.encoder.training: layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache) @@ -118,13 +102,11 @@ def chatglm_model_forward( rotary_pos_emb, kv_cache=past_key_values[idx], use_cache=use_cache) - hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if stage_manager.is_last_stage(): # final layer_norm if self.encoder.post_layer_norm: @@ -132,7 +114,6 @@ def chatglm_model_forward( if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, @@ -142,6 +123,7 @@ def chatglm_model_forward( else: return {'hidden_states': hidden_states} + @staticmethod def chatglm_for_conditional_generation_forward( self: ChatGLMForConditionalGeneration, input_ids: Optional[torch.Tensor] = None, @@ -160,10 +142,8 @@ def chatglm_for_conditional_generation_forward( stage_index: Optional[List[int]] = None, ): logger = logging.get_logger(__name__) - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) - transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( self.transformer, input_ids=input_ids, @@ -184,25 +164,20 @@ def chatglm_for_conditional_generation_forward( hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() - loss = None if labels is not None: lm_logits = lm_logits.to(torch.float32) - # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - lm_logits = lm_logits.to(hidden_states.dtype) loss = loss.to(hidden_states.dtype) - if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=lm_logits, diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py new file mode 100644 index 000000000000..79733f893a3a --- /dev/null +++ b/colossalai/shardformer/policies/chatglm.py @@ -0,0 +1,212 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor +from transformers.modeling_outputs import BaseModelOutputWithPast + +import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, +) + +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] + + +class ChatGLMPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.padded_vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + + policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + + policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration + if self.shard_config.enable_fused_normalization: + if not self.model.config.rmsnorm: + + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription(suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=GLMBlock) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="encoder.final_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=ChatGLMModel) + + else: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), + SubModuleReplacementDescription(suffix="post_attention_layernorm", + target_module=col_nn.FusedRMSNorm) + ], + policy=policy, + target_key=GLMBlock) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="encoder.final_layernorm", + target_module=col_nn.FusedRMSNorm) + ], + policy=policy, + target_key=ChatGLMModel) + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'ChatGLMModel': + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embedding) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + + # rotary_pos_emb is needed for all stages + held_layers.append(module.rotary_pos_emb) + + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'ChatGLMModel': + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +class ChatGLMModelPolicy(ChatGLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ChatGLMModel, + new_forward=ChatGLMPipelineForwards.chatglm_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMModel.""" + return [] + + +class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): + + def module_policy(self): + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration, + new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.transformer.output_layer) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMForConditionalGenerationModel.""" + return [] diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2320c725d444..2430102115c1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,11 +1,13 @@ import copy from contextlib import nullcontext +from typing import Optional import torch from torch.nn import Module from colossalai.lazy import LazyInitContext from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.auto_policy import Policy def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): @@ -28,7 +30,8 @@ def build_pipeline_model(model_fn, stage_manager=None, enable_fused_normalization=False, enable_tensor_parallelism=False, - use_lazy_init: bool = False): + use_lazy_init: bool = False, + policy: Optional[Policy] = None): ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -43,7 +46,7 @@ def build_pipeline_model(model_fn, pipeline_stage_manager=stage_manager) shard_former = ShardFormer(shard_config=shard_config) - sharded_model, shared_params = shard_former.optimize(model_copy) + sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy) return org_model.cuda(), sharded_model.cuda() From c14dc16d1999d30084dd60b4c2d34a0f99ac8b64 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 19:08:40 +0800 Subject: [PATCH 22/27] finish chatglm --- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/chatglm.py | 4 ++-- tests/test_shardformer/test_model/test_shard_chatglm.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 96306fa54ff5..1b81fef65518 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -21,8 +21,8 @@ def data_gen(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: x.last_hidden_state -loss_fn = lambda x: x.loss +loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum() +loss_fn = lambda x: x.logits.sum() config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 7f6ce65b4255..4f8046b1a692 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) if name == "transformers_chatglm": - sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy()) else: - sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda() + sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()) + sharded_model.cuda() check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From d451a176681b118a431c7b8263733e1e101a8b6d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 1 Aug 2023 19:12:21 +0800 Subject: [PATCH 23/27] deletes --- .../test_bert_for_pretraining_model.py | 118 ------------------ .../test_policy/test_bert_lmhead_model.py | 118 ------------------ .../test_policy/test_bert_model.py | 113 ----------------- .../test_policy/test_t5_pipeline_utils.py | 39 ------ 4 files changed, 388 deletions(-) delete mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_lmhead_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_model.py delete mode 100644 tests/test_pipeline/test_policy/test_t5_pipeline_utils.py diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py deleted file mode 100644 index afbea49c1829..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ /dev/null @@ -1,118 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertForPreTraining - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_for_pretraining_forward(): - configuration = BertConfig() - model = BertForPreTraining(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - output = bert_for_pretraining_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output['hidden_states'].shape) - assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') - else: - attention_mask = torch.ones((2, 3)) - output = bert_for_pretraining_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 30522) - print('end the training') - print(output) - - # assert output[1].shape == (2, 768) - - -def check_bert_for_pretraining_policy(): - configuration = BertConfig() - model = BertForPreTraining(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_for_pretraining_forward() - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_for_pretraining_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_for_pretraining_forward(): - spawn(run_dist_model, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_for_pretraining_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_for_pretraining_forward() - test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py deleted file mode 100644 index d41eddc74dff..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ /dev/null @@ -1,118 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertLMHeadModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_lmhead_forward(): - configuration = BertConfig() - model = BertLMHeadModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - output = bert_lmhead_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output['hidden_states'].shape) - assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') - else: - attention_mask = torch.ones((2, 3)) - output = bert_lmhead_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 30522) - print('end the training') - print(output) - - # assert output[1].shape == (2, 768) - - -def check_bert_lmhead_policy(): - configuration = BertConfig() - model = BertLMHeadModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_forward() - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_lmhead_forward(): - spawn(run_dist_model, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_lmhead_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_lmhead_forward() - test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py deleted file mode 100644 index 92485072a5e4..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ /dev/null @@ -1,113 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert.modeling_bert import BertModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_model_forward(): - model = BertModel.from_pretrained('bert-base-uncased') - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output['hidden_states'].shape) - assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') - else: - attention_mask = torch.ones((2, 3)) - output = bert_model_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) - print('end the training') - print(output) - - # assert output[1].shape == (2, 768) - - -def check_bert_model_policy(): - model = BertModel.from_pretrained('bert-base-uncased') - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_forward() - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_model_forward(): - spawn(run_dist_model, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_model_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert model forward and bert model policy""" - test_bert_model_forward() - test_bert_model_policy() diff --git a/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py b/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py deleted file mode 100644 index 0cbb852b97a0..000000000000 --- a/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -from colossalai.shardformer.policies.t5 import T5BasePolicy - - -def test_t5_pipeline_distribution(): - num_test_cases = 8 - test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] - } - - for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage - - -def test_t5_pipeline_layers(): - num_test_cases = 4 - test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] - } - - for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) - - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, - decoder_starting_stage) - assert start_idx == predicted_start - assert end_idx == predicted_end From c6bbb90a7e4d84d7410dd1f1a684c9b8049e46db Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 3 Aug 2023 15:20:42 +0800 Subject: [PATCH 24/27] fix rmsnorm --- .../shardformer/modeling/chatglm2_6b/modeling_chatglm.py | 3 +-- tests/kit/model_zoo/transformers/chatglm.py | 5 +++-- .../test_model/test_shard_chatglm_pipeline.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index bae6d425878d..792b54e39ba5 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -223,14 +223,13 @@ class RMSNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype)) self.eps = eps def forward(self, hidden_states: torch.Tensor): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - return (self.weight * hidden_states).to(input_dtype) diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 1b81fef65518..a6b5b3ee9afe 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -27,9 +27,10 @@ def data_gen(): padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, - rmsnorm=False, + rmsnorm=True, original_rope=True, - use_cache=True) + use_cache=True, + torch_dtype=torch.float32) model_zoo.register(name='transformers_chatglm', model_fn=lambda: ChatGLMModel(config, empty_init=False), diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py index f4ccf3296b8f..8db252bf528b 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py @@ -46,9 +46,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_ inputs['input_ids'] = None inputs['hidden_states'] = hidden_states outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): assert outputs[0].shape == hidden_state_shape + else: assert outputs['hidden_states'].shape == hidden_state_shape @@ -61,7 +61,6 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_ inputs['input_ids'] = None inputs['hidden_states'] = hidden_states outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): assert outputs[0].shape == (batch_size, seq_len, 65024) else: From cf0b554c3923086c305efa1c5597a5e6a33e380b Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 3 Aug 2023 17:26:54 +0800 Subject: [PATCH 25/27] chatglm --- .../test_shardformer/test_model/test_shard_chatglm_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py index 8db252bf528b..ee474ac7be3b 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py @@ -31,7 +31,7 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_ stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # create new model + # create new model for test inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids = inputs['input_ids'] From 0beae704cffe1205eddfb21ad94754bab30f77bd Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 4 Aug 2023 11:19:12 +0800 Subject: [PATCH 26/27] fix chatglm shard --- colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py | 2 ++ tests/kit/model_zoo/torchrec/__init__.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index 792b54e39ba5..a21ee0231422 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -223,6 +223,8 @@ class RMSNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() + self.elementwise_affine = True + self.normalized_shape = normalized_shape self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype)) self.eps = eps diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * From b868b6a2a0b7751586aada220fdf2a1563104161 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 4 Aug 2023 11:55:43 +0800 Subject: [PATCH 27/27] init --- tests/kit/model_zoo/torchrec/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import *