From 8dee68a0a22568dbeed6d4563372b25e1e825fb0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 13:41:14 +0800 Subject: [PATCH 01/23] bloom policy --- colossalai/shardformer/policies/bloom.py | 220 ++++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 662ff5b4977a..7452b611bed0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,11 +1,26 @@ +import warnings +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + class BloomPolicy(Policy): @@ -110,7 +125,40 @@ def postprocess(self): class BloomModelPolicy(BloomPolicy): - pass + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.bloom.modeling_bloom import BloomModel + if self.pipeline_stage_manager: + module_policy[BloomModel] = ModulePolicyDescription( + method_replacement={"forward": partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager)}) + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if self.stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + + start_idx, end_idx = self.get_stage_index(layers_per_stage, self.stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + + if self.stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + pass class BloomForCausalLMPolicy(BloomPolicy): @@ -181,3 +229,173 @@ def module_policy(self): class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 pass + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if 
deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] # source_len + + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + # causal_mask is constructed every stage and its input is passed through different stages + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + # calculate the num_layers + num_layers_per_stage = len(self.h) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + # TODO: deal with all_hidden_states, all_self_attentions, presents + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) From 9cb30051339477ffd9fd73d1687e0fa9a5c1e33e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:14:26 +0800 Subject: [PATCH 02/23] llama pipeline forward and tests --- .../shardformer/policies/auto_policy.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/llama.py | 200 +++++++++++++++++- tests/test_shardformer/test_model/_utils.py | 1 + .../test_model/test_shard_llama_pipeline.py | 92 ++++++++ 5 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_llama_pipeline.py diff --git a/colossalai/shardformer/policies/auto_policy.py 
b/colossalai/shardformer/policies/auto_policy.py index 640b61b579bd..0ad9a3e95a0e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -45,7 +45,7 @@ class PolicyLocation: # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaPolicy"), + PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"), "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index eacd0b449ad4..2b2c003ffb04 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -193,7 +193,7 @@ def get_held_layers(self) -> List[Module]: module = self.model stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b10e07560d22..c6fe5c395aa7 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,11 +1,28 @@ -from typing import Dict, Union +import math +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.utils import ModelOutput, logging +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -106,6 +123,39 @@ def postprocess(self): return self.model +class LlamaModelPolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: + # set None as default + module_policy[LlamaModel] = ModulePolicyDescription( + method_replacement={'forward': partial(llama_model_forward, stage_manager=self.pipeline_stage_manager)}) + return module_policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, 
stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bert model""" + return [] + + class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): @@ -144,3 +194,151 @@ def module_policy(self): } policy.update(new_item) return policy + + +def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + # TODO: we should recive the attn mask of 1st stage and send it to the other stages + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx]): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {'hidden_states': 
hidden_states, 'attention_mask': attention_mask} diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index de8cb65d21d0..f26c6622da7e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -39,6 +39,7 @@ def build_pipeline_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, pipeline_stage_manager=stage_manager) + shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py new file mode 100644 index 000000000000..90a908e04f2a --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_bert +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_llama': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + + layers_per_stage = Policy.distribute_layers(len(org_model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) + # print(output[0].shape) + assert output[0].shape == (2, 3, 128) + + 
torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 4) + + +if __name__ == "__main__": + test_bert() From a2144142d6c31eb3a9e0a73411f6596ea2a72d8d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:44:57 +0800 Subject: [PATCH 03/23] fix the output and attention_mask --- colossalai/shardformer/policies/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c6fe5c395aa7..85bb8c37cc07 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -268,7 +268,7 @@ def llama_model_forward( # for the other stages, hidden_states is the output of the previous stage # TODO: we should recive the attn mask of 1st stage and send it to the other stages if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device) attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length) @@ -341,4 +341,4 @@ def custom_forward(*inputs): attentions=all_self_attns, ) # always return dict for imediate stage - return {'hidden_states': hidden_states, 'attention_mask': attention_mask} + return {'hidden_states': hidden_states} From b625f2fc16302c8e60741ef66264abd45c866bc4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:51:02 +0800 Subject: [PATCH 04/23] fix name --- .../test_model/test_shard_llama_pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 90a908e04f2a..9b2a3f0ba0c1 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -26,8 +26,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [False]) @parameterize('enable_tensor_parallelism', [False]) @parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_bert -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +#TODO: merge this into test_shard_llama +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 RANK_TO_COORDINATE = { @@ -75,18 +75,18 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz torch.cuda.empty_cache() -def check_bert(rank, world_size, port): +def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_bert_test() + run_llama_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_bert(): - spawn(check_bert, 4) +def test_llama(): + spawn(check_llama, 4) if __name__ == "__main__": - test_bert() + test_llama() From 67e6a717c5b99bb84baf7de97c440db36b82e36a Mon Sep 17 
00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 17:31:06 +0800 Subject: [PATCH 05/23] bind argument to policy --- .../test_model/test_shard_llama_pipeline.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 9b2a3f0ba0c1..8f27eb914e3b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -52,23 +52,15 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la if name == 'transformers_llama': org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - - layers_per_stage = Policy.distribute_layers(len(org_model.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) + stage_manager=stage_manager) # print(output[0].shape) assert output[0].shape == (2, 3, 128) From dfe65dc2b14800d3db3349104ecdc9a6bf8f6e27 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 17:32:18 +0800 Subject: [PATCH 06/23] Revert "bloom policy" This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0. 
This policy should be revert and copied to feature/bloom --- colossalai/shardformer/policies/bloom.py | 220 +---------------------- 1 file changed, 1 insertion(+), 219 deletions(-) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7452b611bed0..662ff5b4977a 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,26 +1,11 @@ -import warnings -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch import torch.nn as nn -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom.modeling_bloom import BloomModel -from transformers.utils import logging import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -logger = logging.get_logger(__name__) - class BloomPolicy(Policy): @@ -125,40 +110,7 @@ def postprocess(self): class BloomModelPolicy(BloomPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - module_policy = super().module_policy() - from transformers.models.bloom.modeling_bloom import BloomModel - if self.pipeline_stage_manager: - module_policy[BloomModel] = ModulePolicyDescription( - method_replacement={"forward": partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager)}) - - def get_held_layers(self) -> List[Module]: - """ - get pipeline layers for current stage - """ - module = self.model - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if self.stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - held_layers.append(module.word_embeddings_layernorm) - - start_idx, end_idx = self.get_stage_index(layers_per_stage, self.stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - - if self.stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - - return held_layers - - def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: - '''no shared params in bloommodel''' - pass + pass class BloomForCausalLMPolicy(BloomPolicy): @@ -229,173 +181,3 @@ def module_policy(self): class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 pass - - -def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - **deprecated_arguments, -) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so 
defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # add warnings here - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - # case: First stage of training - if stage_manager.is_first_stage(): - # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") - use_cache = False - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len - - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - # calculate the num_layers - num_layers_per_stage = len(self.h) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) - - if stage_manager.is_last_stage(): - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - # TODO: deal with all_hidden_states, all_self_attentions, presents - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - # attention_mask is not returned ; presents = past_key_values - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) From 9c88910dd56e37f4d8b6ba9d0a2b4dfa9a4af73e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 17:34:16 +0800 Subject: [PATCH 07/23] revert the bloom changes --- colossalai/shardformer/policies/llama.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 85bb8c37cc07..9494b686f9aa 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -133,8 +133,12 @@ def module_policy(self): from transformers.models.llama.modeling_llama import LlamaModel if self.pipeline_stage_manager: # 
set None as default - module_policy[LlamaModel] = ModulePolicyDescription( - method_replacement={'forward': partial(llama_model_forward, stage_manager=self.pipeline_stage_manager)}) + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={ + 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) + }) return module_policy def get_held_layers(self) -> List[Module]: @@ -266,7 +270,6 @@ def llama_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - # TODO: we should recive the attn mask of 1st stage and send it to the other stages if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device) attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, From 48c41dcffe278fde7e15d64119d209520ef330fb Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 18:35:15 +0800 Subject: [PATCH 08/23] cancel unneeded inputs --- colossalai/shardformer/policies/llama.py | 92 ++++++++++++++++++- .../test_model/test_shard_llama_pipeline.py | 9 +- 2 files changed, 96 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 9494b686f9aa..b3a41f8ab00b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -12,8 +12,9 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -345,3 +346,92 @@ def custom_forward(*inputs): ) # always return dict for imediate stage return {'hidden_states': hidden_states} + + +def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if stage_manager.is_last_stage(): + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 8f27eb914e3b..81c183d3230e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -54,13 +54,14 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la enable_tensor_parallelism, use_lazy_init) if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + output = sharded_model(input_ids=x, attention_mask=attention_mask) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model(hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) # print(output[0].shape) assert output[0].shape == (2, 3, 128) From 2a9097bdfde0db68f1a5bd8e7a34e7a024e47574 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 11 Jul 2023 10:37:32 +0800 Subject: [PATCH 09/23] gpt --- colossalai/shardformer/policies/llama.py | 145 
++++++++++++++++++++++- tests/kit/model_zoo/transformers/gpt.py | 2 +- 2 files changed, 141 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b3a41f8ab00b..b2b6470188a4 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -6,15 +6,16 @@ import torch import torch.nn as nn from torch import Tensor -from torch.nn import CrossEntropyLoss, Module +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -394,8 +395,20 @@ def llama_for_causal_lm_forward( if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs = llama_model_forward( + self.model, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -405,12 +418,18 @@ def llama_for_causal_lm_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None - hidden_states = outputs[0] if stage_manager.is_last_stage(): + hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None if labels is not None: # Shift so that tokens < n predict n @@ -435,3 +454,119 @@ def llama_for_causal_lm_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: 
Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + 
attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b9e0310780af..ac70138e3f8f 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -52,7 +52,7 @@ def data_gen_for_sequence_classification(): loss_fn = lambda x: x.loss config = transformers.GPT2Config(n_layer=2, - n_head=4, + n_head=2, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, From 279a34e89fe30c53e1c6afc1bc3488af5c78a182 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 11 Jul 2023 11:17:16 +0800 Subject: [PATCH 10/23] finish llama --- colossalai/shardformer/policies/llama.py | 77 +++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b2b6470188a4..c692935957ff 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -158,7 +158,7 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in bert model""" + """No shared params in llama model""" return [] @@ -179,8 +179,47 @@ def module_policy(self): ]) } policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + new_item = { + LlamaForCausalLM: + ModulePolicyDescription( + method_replacement={ + 'forward': + partial( + llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index) + }) + } + policy.update(new_item) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.model.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.model.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.model.norm) + held_layers.append(module.lmhead) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + llama_model = self.model.model + if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight): + # tie weights + return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}] + return [] + class LlamaForSequenceClassificationPolicy(LlamaPolicy): @@ -199,8 +238,44 @@ def module_policy(self): ]) } policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + new_item = { + LlamaForSequenceClassification: + ModulePolicyDescription( + method_replacement={ + 'forward': + partial(llama_for_sequence_classification_forward, + 
stage_manager=stage_manager, + stage_index=stage_index) + }) + } + policy.update(new_item) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.model.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.model.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.model.norm) + held_layers.append(module.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] + def llama_model_forward( self: LlamaModel, From 4ac3aac06c9ce34eed0e37765d3827116824e30d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 11 Jul 2023 14:16:07 +0800 Subject: [PATCH 11/23] causal lm and sequence classification --- .../shardformer/policies/base_policy.py | 18 +++++++ colossalai/shardformer/policies/llama.py | 48 +++++++++---------- .../test_model/test_shard_llama_pipeline.py | 28 +++++------ 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index aac86eb20a56..68fde0115de6 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -162,6 +162,24 @@ def append_or_create_submodule_replacement( return policy + def append_or_create_method_replacement( + self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new method replacement description to the policy for the given key. + + Args: + description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + if target_key in policy: + policy[target_key].method_replacement.update(description) + else: + policy[target_key] = ModulePolicyDescription(method_replacement=description) + + return policy + def get_held_layers(self) -> List[Module]: """Get layers that should be held in current stage. This method should be implemented by subclass. 
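Note on the helper introduced in the base_policy.py diff above: `append_or_create_method_replacement` merges a `method_replacement` dict into an existing `ModulePolicyDescription` for `target_key`, or creates a new description when none is registered, so a pipeline forward replacement no longer overwrites sub-module replacements added earlier via `policy.update(...)`. A minimal usage sketch from inside a policy's `module_policy()`, assuming `pipeline_stage_manager` is set; `partial`, `LlamaForCausalLM`, `llama_for_causal_lm_forward`, and the stage variables are the ones defined or imported in the surrounding diffs, not new names:

    # bind the pipeline stage information to the replacement forward ahead of time
    method_replacement = {
        'forward': partial(llama_for_causal_lm_forward,
                           stage_manager=stage_manager,
                           stage_index=stage_index)
    }
    # merges into the existing ModulePolicyDescription for LlamaForCausalLM if one
    # is already present (e.g. the lm_head sub-module replacement), otherwise
    # creates a fresh description holding only this method replacement
    self.append_or_create_method_replacement(description=method_replacement,
                                             policy=policy,
                                             target_key=LlamaForCausalLM)

The llama.py diff below switches the causal-LM and sequence-classification policies from `policy.update(new_item)` to this helper.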
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c692935957ff..2210aac950f1 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -131,17 +131,20 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.llama.modeling_llama import LlamaModel if self.pipeline_stage_manager: # set None as default stage_manager = self.pipeline_stage_manager layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={ + method_replacement = { 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) - }) - return module_policy + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaModel) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -179,22 +182,19 @@ def module_policy(self): ]) } policy.update(new_item) - # to be confirmed + if self.pipeline_stage_manager: # set None as default stage_manager = self.pipeline_stage_manager layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - new_item = { - LlamaForCausalLM: - ModulePolicyDescription( - method_replacement={ - 'forward': - partial( - llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index) - }) + method_replacement = { + 'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index) } - policy.update(new_item) + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaForCausalLM) + return policy def get_held_layers(self) -> List[Module]: @@ -209,7 +209,7 @@ def get_held_layers(self) -> List[Module]: held_layers.extend(module.model.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.model.norm) - held_layers.append(module.lmhead) + held_layers.append(module.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -244,17 +244,15 @@ def module_policy(self): stage_manager = self.pipeline_stage_manager layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - new_item = { - LlamaForSequenceClassification: - ModulePolicyDescription( - method_replacement={ - 'forward': - partial(llama_for_sequence_classification_forward, - stage_manager=stage_manager, - stage_index=stage_index) - }) + method_replacement = { + 'forward': + partial(llama_for_sequence_classification_forward, + stage_manager=stage_manager, + stage_index=stage_index) } - policy.update(new_item) + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaForSequenceClassification) return policy def get_held_layers(self) -> List[Module]: diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 81c183d3230e..8fd9ed099478 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ 
b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la x = torch.randint(0, 1000, (2, 3)).cuda() hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_llama': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - # print(output[0].shape) - assert output[0].shape == (2, 3, 128) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0] is not None torch.cuda.empty_cache() From 807974f25b2ea80f55dae9ddd97f43bd1cb018dd Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 11 Jul 2023 14:31:55 +0800 Subject: [PATCH 12/23] revision --- colossalai/shardformer/policies/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2210aac950f1..a3ea807269bb 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -194,7 +194,6 @@ def module_policy(self): self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaForCausalLM) - return policy def get_held_layers(self) -> List[Module]: From 3d406c702b38540cf56b7195bf66db2a2702f0d7 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 12 Jul 2023 12:03:11 +0800 Subject: [PATCH 13/23] add pure pipeline test --- colossalai/pipeline/p2p.py | 5 +- colossalai/pipeline/schedule/one_f_one_b.py | 1 - .../test_model/test_pure_pipeline.py | 163 ++++++++++++++++++ 3 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_pure_pipeline.py diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 203b7439d7ef..2fd135d5475d 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any], my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. 
if my_rank == src: - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + if torch.__version__ >= "1.13.0": + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list]) + else: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) object_sizes_tensor = torch.cat(size_list) else: object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index a8933bfbb4da..6ed3055d689b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -205,7 +205,6 @@ def forward_backward_step(self, # the backward pass. input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if last_iteration: diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py new file mode 100644 index 000000000000..a82c39406dfa --- /dev/null +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -0,0 +1,163 @@ +import random +from contextlib import nullcontext +from typing import Any, Callable, Iterator, List, Optional, Tuple + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class PipelineOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module): + super().__init__(optim) + params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group['params'] if p in params] + new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + # TODO: support amp + + +class PipelinedModel(ModelWrapper): + + def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + shardformer = ShardFormer(shard_config) + module, self.shared_params = shardformer.optimize(module) + self.shared_param_process_groups = [] + super().__init__(module) + + +def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0): + sampler = DistributedSampler( + dataset, + #rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + 
random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + ) + + +def execute_pipeline( + data_iter: Iterator, + model: PipelinedModel, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: PipelineOptimizer, + return_loss: bool = True, + return_outputs: bool = False, + schedule: OneForwardOneBackwardSchedule = None, +) -> dict: + # return loss or outputs if needed + outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) + return outputs + + +class data_iter(): + + def __getitem__(self, x): + return torch.randint(0, 100, (4, 128)).cuda() + + +def loss(x, y): + return (x[0].float().mean() - y[0].float().mean()) + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + PP_DIM = 0 + PP_SIZE = 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + from datasets import load_dataset + + #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi") + pg_mesh = ProcessGroupMesh(PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + num_microbatches = 2 + org_model = model_fn().cuda() + optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) + #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4) + schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager) + pipelined_model = PipelinedModel(org_model, shard_config, stage_manager) + pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) + data_it = iter(data_iter()) + results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule) + assert results['loss'] is not None + assert results['outputs'] is None + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 2) + + +if __name__ == "__main__": + test_llama() From 4c45a3fb3d4493b73f23a3826c9c0b524d294c4f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 12 Jul 2023 19:12:47 +0800 Subject: [PATCH 14/23] finish some bert models --- colossalai/shardformer/policies/bert.py | 147 +++++++++++++++++++++++- 1 file changed, 145 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 2b2c003ffb04..67a0642f9355 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -11,6 +11,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + SequenceClassifierOutput, ) from transformers.models.bert.modeling_bert import ( 
BertForMaskedLM, @@ -807,14 +808,63 @@ def bert_for_masked_lm_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, ): - #-> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - pass + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} def bert_for_next_sentence_prediction_forward( @@ -916,3 +966,96 @@ def bert_for_next_sentence_prediction_forward( hidden_states = outputs.get('hidden_states') # intermediate stage always return dict return {'hidden_states': hidden_states} + + +def bert_for_next_sentence_prediction_forward( + self: BertForNextSentencePrediction, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + **kwargs, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + ) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} From a693291189741a3949e47002bdec3530e1ee74f5 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 13 Jul 2023 15:38:10 +0800 Subject: [PATCH 15/23] finish all bert models --- colossalai/shardformer/policies/bert.py | 757 +++++++++++++++++++---- colossalai/shardformer/policies/llama.py | 6 +- 2 files changed, 652 insertions(+), 111 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 67a0642f9355..5d30ff84631d 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -11,13 +11,20 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, SequenceClassifierOutput, + TokenClassifierOutput, ) from transformers.models.bert.modeling_bert import ( BertForMaskedLM, + BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForPreTrainingOutput, + 
BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, BertLMHeadModel, BertModel, ) @@ -32,9 +39,9 @@ logger = logging.get_logger(__name__) __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy', 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', - 'BertForMultipleChoicePolicy' + 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy' ] @@ -181,13 +188,20 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel if self.pipeline_stage_manager: # set None as default - module_policy[BertModel] = ModulePolicyDescription( - method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) - return module_policy + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': partial(bert_model_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertModel) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -215,15 +229,27 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForPreTraining + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': partial(bert_for_pretraining_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForPreTraining) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage""" module = self.model stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) held_layers = [] if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) @@ -238,7 +264,13 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + model = self.model + if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight + }] return [] def postprocess(self): @@ 
-257,9 +289,20 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertLMHeadModel + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': partial(bert_lm_head_model_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertLMHeadModel) + return policy def get_held_layers(self) -> List[Module]: """ @@ -268,7 +311,7 @@ def get_held_layers(self) -> List[Module]: module = self.model held_layers = [] stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) @@ -279,7 +322,13 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + bert_model = self.model.bert + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] return [] def postprocess(self): @@ -298,9 +347,47 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForMaskedLM + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': partial(bert_for_masked_lm_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForMaskedLM) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> 
List[Dict[int, Tensor]]: + bert_model = self.model.bert + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] + return [] def postprocess(self): if self.shard_config.enable_tensor_parallelism: @@ -320,7 +407,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForSequenceClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -332,8 +419,45 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(bert_for_sequence_classification_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForSequenceClassification) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForTokenClassification @@ -345,7 +469,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -357,8 +481,43 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(bert_for_token_classification_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForTokenClassification) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), 
stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForNextSentencePrediction @@ -367,6 +526,47 @@ class BertForNextSentencePredictionPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(bert_for_next_sentence_prediction_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForNextSentencePrediction) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): @@ -377,7 +577,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -389,28 +589,107 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(bert_for_multipile_choice_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForMultipleChoice) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), 
stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +class BertForQuestionAnsweringPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForQuestionAnswering + policy = super().module_policy() + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(bert_for_question_answering_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BertForQuestionAnswering) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = 
None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + stage_index: Optional[List[int]] = None, ): # TODO: add explaination of the output here. r""" @@ -529,14 +808,10 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - # calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - + start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -625,6 +900,7 @@ def bert_for_pretraining_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. @@ -638,18 +914,21 @@ def bert_for_pretraining_forward( logger.warning_once('return_dict is not supported for pipeline models at the moment') return_dict = False - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + ) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -685,23 +964,23 @@ def bert_for_pretraining_forward( } -def bert_lmhead_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): +def 
bert_lm_head_model_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -807,6 +1086,7 @@ def bert_for_masked_lm_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -841,6 +1121,7 @@ def bert_for_masked_lm_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, + stage_index=stage_index, ) if stage_manager.is_last_stage(): @@ -881,6 +1162,7 @@ def bert_for_next_sentence_prediction_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, **kwargs, ): #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -931,18 +1213,19 @@ def bert_for_next_sentence_prediction_forward( return_dict = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) if stage_manager.is_last_stage(): pooled_output = outputs[1] seq_relationship_scores = self.cls(pooled_output) @@ -968,8 +1251,8 @@ def bert_for_next_sentence_prediction_forward( return {'hidden_states': hidden_states} -def bert_for_next_sentence_prediction_forward( - self: BertForNextSentencePrediction, +def bert_for_sequence_classification_forward( + self: BertForSequenceClassification, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -982,7 +1265,7 @@ def bert_for_next_sentence_prediction_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, - **kwargs, + stage_index: 
Optional[List[int]] = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1003,20 +1286,19 @@ def bert_for_next_sentence_prediction_forward( return_dict = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - ) + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -1059,3 +1341,260 @@ def bert_for_next_sentence_prediction_forward( else: hidden_states = outputs.get('hidden_states') return {'hidden_states': hidden_states} + + +def bert_for_token_classification_forward( + self: BertForTokenClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_multipile_choice_forward( + self: BertForMultipleChoice, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + # the num_choices is only used for the last stage + num_choices = hidden_states[1] + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_question_answering_forward( + self: BertForQuestionAnswering, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + # NOTE: the arg start_position and end_position are used only for the last stage + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a3ea807269bb..b3757452c314 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -212,11 +212,13 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" llama_model = self.model.model if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight): # tie weights - return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: 
self.model.lm_head.weight}] + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + }] return [] From 1714cfc9fd64b199944d91cab812cd5641c6b426 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 14 Jul 2023 17:43:15 +0800 Subject: [PATCH 16/23] finish bert tests --- .../shardformer/policies/auto_policy.py | 2 + colossalai/shardformer/policies/bert.py | 54 +++++++++++-------- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/bert.py | 6 +++ .../test_model/test_shard_bert_pipeline.py | 28 +++++++--- 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0ad9a3e95a0e..ccdb33b2efe5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -42,6 +42,8 @@ class PolicyLocation: PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), + "transformers.models.bert.modeling_bert.BertForQuestionAnswering": + PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"), # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5d30ff84631d..d7c38607238b 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -11,6 +11,7 @@ BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + MultipleChoiceModelOutput, NextSentencePredictorOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, @@ -869,8 +870,9 @@ def custom_forward(*inputs): return (sequence_output, pooled_output) + layer_outputs[1:] # return dict is not supported at this moment else: - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, @@ -964,23 +966,26 @@ def bert_for_pretraining_forward( } -def bert_lm_head_model_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): +def bert_lm_head_model_forward( + self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, 
+ encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1034,7 +1039,8 @@ def bert_lm_head_model_forward(self: BertLMHeadModel, output_hidden_states=output_hidden_states, return_dict=return_dict, stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -1226,6 +1232,7 @@ def bert_for_next_sentence_prediction_forward( hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index) + if stage_manager.is_last_stage(): pooled_output = outputs[1] seq_relationship_scores = self.cls(pooled_output) @@ -1453,6 +1460,11 @@ def bert_for_multipile_choice_forward( return_dict = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # in our pipeline design,input ids are copied for every stage and shouldn't be none + # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] + if stage_manager.is_last_stage(): + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None @@ -1476,8 +1488,6 @@ def bert_for_multipile_choice_forward( stage_index=stage_index, ) if stage_manager.is_last_stage(): - # the num_choices is only used for the last stage - num_choices = hidden_states[1] pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index d2d3de7b7bee..fa4875308b55 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -150,3 +150,9 @@ def data_gen_for_mcq(): output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_question_answering', + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git 
a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 9cca5ec8bc51..a4785b1f54f0 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_bert': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if name == 'transformers_bert_for_mcq': + x = torch.randint(0, 1000, (2, 3, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + if stage_manager.stage == 0: + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (6, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() + output = sharded_model(input_ids=x, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print('model_name. ', name, 'output. ', output) + else: + x = torch.randint(0, 1000, (2, 3)).cuda() + # one batch, 2 single sentences, each sentence has 3 tokens + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output[0].shape) - assert output[0].shape == (2, 3, 128) + print('model_name. ', name, 'output. ', output[0].shape) torch.cuda.empty_cache() From 5ecf2f7852462bb8b365bdb34a813e270cbb93ef Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 14 Jul 2023 18:07:19 +0800 Subject: [PATCH 17/23] fix bugs --- tests/test_shardformer/test_model/test_shard_bert_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index a4785b1f54f0..4feaf982aa37 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -61,7 +61,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - print('model_name. ', name, 'output. 
', output) + assert output[0].shape == (2, 3) else: x = torch.randint(0, 1000, (2, 3)).cuda() # one batch, 2 single sentences, each sentence has 3 tokens @@ -75,7 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - print('model_name. ', name, 'output. ', output[0].shape) + assert output[0].shape[0] == 2 torch.cuda.empty_cache() From 90c7e18cf3d986766166a1f57b7096041ac09a8a Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 14 Jul 2023 18:17:23 +0800 Subject: [PATCH 18/23] fix bugs --- .../{test_bert_lmhead_model.py => test_bert_lm_head_model.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/test_pipeline/test_policy/{test_bert_lmhead_model.py => test_bert_lm_head_model.py} (99%) diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py similarity index 99% rename from tests/test_pipeline/test_policy/test_bert_lmhead_model.py rename to tests/test_pipeline/test_policy/test_bert_lm_head_model.py index b14dadf29e3c..2b042481e548 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py @@ -7,7 +7,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn From fa7f27009a46ba38a9e15f2f563c49c88cecc8e1 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 17 Jul 2023 10:59:38 +0800 Subject: [PATCH 19/23] fix test pipeline --- .../test_bert_for_pretraining_model.py | 17 +++++++--- .../test_policy/test_bert_lm_head_model.py | 31 +++++++++++-------- .../test_policy/test_bert_model.py | 15 ++++++--- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 97d7d2fa538a..571bc12fd3c7 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -7,6 +7,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -35,15 +36,20 @@ def check_bert_for_pretraining_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_for_pretraining_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) + 
output = bert_for_pretraining_forward( + self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index, + ) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) @@ -52,7 +58,8 @@ def check_bert_for_pretraining_forward(): output = bert_for_pretraining_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 30522) # assert output[1].shape == (2, 768) diff --git a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py index 2b042481e548..cd47f7a33c4b 100644 --- a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py @@ -7,12 +7,13 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_lmhead_forward(): +def check_bert_lm_head_model_forward(): configuration = BertConfig() model = BertLMHeadModel(configuration) DP_DIM, PP_DIM = 0, 1 @@ -35,24 +36,28 @@ def check_bert_lmhead_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_lmhead_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) + + output = bert_lm_head_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) - output = bert_lmhead_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) + output = bert_lm_head_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 30522) @@ -93,7 +98,7 @@ def check_bert_lmhead_policy(): def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_forward() + check_bert_lm_head_model_forward() def run_dist_policy(rank, world_size, port): @@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() -def test_bert_lmhead_forward(): +def test_bert_lm_head_model_forward(): spawn(run_dist_model, 4) @@ -115,5 +120,5 @@ def test_bert_lmhead_policy(): if __name__ == "__main__": """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_lmhead_forward() + test_bert_lm_head_model_forward() test_bert_lmhead_policy() diff --git 
a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index f5a443309cb2..92bbcdc54d2b 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -6,6 +6,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -34,20 +35,25 @@ def check_bert_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output['hidden_states'].shape) + output = bert_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 768) @@ -112,4 +118,3 @@ def test_bert_model_policy(): """test the bert model forward and bert model policy""" #test_bert_model_forward() test_bert_model_policy() - # this test need config to run From 7a168dac832e7ec44567d0469531c5853c860f34 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 17 Jul 2023 11:37:15 +0800 Subject: [PATCH 20/23] fix data gen for qa --- tests/kit/model_zoo/transformers/bert.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index fa4875308b55..8bc284502087 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -87,6 +87,13 @@ def data_gen_for_mcq(): return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) +def data_gen_for_qa(): + # generating data for question answering + # no need for labels + data = data_gen() + return data + + # define output transform function output_transform_fn = lambda x: x @@ -152,7 +159,7 @@ def data_gen_for_mcq(): model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_question_answering', model_fn=lambda: transformers.BertForQuestionAnswering(config), - data_gen_fn=data_gen_for_lm, + data_gen_fn=data_gen_for_qa, output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) From 4a2b7a3363bb48223da93c90fc2eb97a190619ee Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 17 Jul 2023 15:12:55 +0800 Subject: [PATCH 21/23] update the set pipeline forward --- colossalai/shardformer/policies/bert.py | 145 +++++------------- .../test_bert_for_pretraining_model.py | 2 - 
.../test_policy/test_bert_model.py | 1 + 3 files changed, 41 insertions(+), 107 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index d7c38607238b..682a97c467f9 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,13 +1,12 @@ from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( - BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -181,6 +180,25 @@ def add_lm_head_policy(self, base_policy): def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + # BertModel class BertModelPolicy(BertPolicy): @@ -191,17 +209,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel - if self.pipeline_stage_manager: - # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(bert_model_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertModel) + self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy) return policy def get_held_layers(self) -> List[Module]: @@ -233,17 +241,7 @@ def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining - if self.pipeline_stage_manager: - # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(bert_for_pretraining_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForPreTraining) + self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) return policy 
def get_held_layers(self) -> List[Module]: @@ -293,16 +291,7 @@ def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(bert_lm_head_model_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertLMHeadModel) + self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) return policy def get_held_layers(self) -> List[Module]: @@ -351,16 +340,7 @@ def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(bert_for_masked_lm_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForMaskedLM) + self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) return policy def get_held_layers(self) -> List[Module]: @@ -422,19 +402,9 @@ def module_policy(self): } policy.update(addon_module) - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(bert_for_sequence_classification_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForSequenceClassification) + self.set_pipeline_forward(model_cls=BertForSequenceClassification, + new_forward=bert_for_sequence_classification_forward, + policy=policy) return policy @@ -484,17 +454,9 @@ def module_policy(self): } policy.update(addon_module) - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(bert_for_token_classification_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForTokenClassification) + self.set_pipeline_forward(model_cls=BertForTokenClassification, + new_forward=bert_for_token_classification_forward, + policy=policy) return policy @@ -530,20 +492,9 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertForNextSentencePrediction - - if 
self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(bert_for_next_sentence_prediction_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForNextSentencePrediction) + self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, + new_forward=bert_for_next_sentence_prediction_forward, + policy=policy) return policy @@ -592,17 +543,9 @@ def module_policy(self): } policy.update(addon_module) - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(bert_for_multipile_choice_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForMultipleChoice) + self.set_pipeline_forward(model_cls=BertForMultipleChoice, + new_forward=bert_for_multiple_choice_forward, + policy=policy) return policy @@ -637,17 +580,9 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForQuestionAnswering policy = super().module_policy() - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(bert_for_question_answering_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BertForQuestionAnswering) + self.set_pipeline_forward(model_cls=BertForQuestionAnswering, + new_forward=bert_for_question_answering_forward, + policy=policy) return policy @@ -1425,7 +1360,7 @@ def bert_for_token_classification_forward( return {'hidden_states': hidden_states} -def bert_for_multipile_choice_forward( +def bert_for_multiple_choice_forward( self: BertForMultipleChoice, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 571bc12fd3c7..6a8d7b636375 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -50,7 +50,6 @@ def check_bert_for_pretraining_forward(): stage_manager=stage_manager, stage_index=stage_index, ) - print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) else: @@ -60,7 +59,6 @@ def check_bert_for_pretraining_forward(): attention_mask=attention_mask, stage_manager=stage_manager, stage_index=stage_index) - print(output[0].shape) assert output[0].shape == (2, 3, 30522) # assert output[1].shape == (2, 768) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index 
92bbcdc54d2b..f116bc761aa7 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -13,6 +13,7 @@ def check_bert_model_forward(): + # this test may crash for internet reasons model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 From 364e30db417b9f86bc2b50e9effa1eba5ddd27ca Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 17 Jul 2023 15:19:11 +0800 Subject: [PATCH 22/23] shared params --- colossalai/shardformer/policies/bert.py | 45 +++++++++++++------------ 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 682a97c467f9..1af26f50484c 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -264,16 +264,17 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: model = self.model - if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): - #tie weights - return [{ - 0: model.bert.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight - }] + if self.pipeline_stage_manager: + if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -313,16 +314,17 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert - if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -362,16 +364,17 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert - if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: 
bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) From df4ff4798620fecf35310ab8a880a996082fd15e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 17 Jul 2023 16:10:53 +0800 Subject: [PATCH 23/23] fix bugs --- tests/kit/model_zoo/transformers/bert.py | 6 +++++- tests/test_shardformer/test_model/_utils.py | 1 - tests/test_shardformer/test_model/test_pure_pipeline.py | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 8bc284502087..1993af51ad63 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -89,8 +89,12 @@ def data_gen_for_mcq(): def data_gen_for_qa(): # generating data for question answering - # no need for labels + # no need for labels and use start and end position instead data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data['start_positions'] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data['end_positions'] = end_positions return data diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f26c6622da7e..825d6df6bb5e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, # prepare input data = data_gen_fn() data = {k: v.cuda() for k, v in data.items()} - # switch to train mode original_model.train() sharded_model.train() diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index a82c39406dfa..24cda193a5e6 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -141,7 +141,8 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) data_it = iter(data_iter()) results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule) - assert results['loss'] is not None + if stage_manager.is_last_stage(): + assert results['loss'] is not None assert results['outputs'] is None torch.cuda.empty_cache()